wip: 增加了workflow_template_generate的api接口
This commit is contained in:
parent
2432bc9e3b
commit
95ec019b5a
|
|
@ -0,0 +1,27 @@
|
|||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from pretor.utils.access import TokenData, Accessor
|
||||
|
||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||
|
||||
@resource_router.post("/workflow_template")
|
||||
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
||||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await global_state_machine.workflow_template_generate.remote(workflow_template)
|
||||
return {"message": "创建成功"}
|
||||
|
|
@ -14,6 +14,7 @@
|
|||
|
||||
import ray
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
|
@ -22,14 +23,22 @@ from pretor.core.database.postgres import PostgresDatabase
|
|||
from pretor.api.platform.event import PretorEvent
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.tool_plugin.base_tool import BaseToolData
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
def __init__(self, postgres_database: PostgresDatabase):
|
||||
|
||||
self.event_dict: Dict[int, PretorEvent] = {}
|
||||
self.global_provider_manager = ProviderManager(postgres_database)
|
||||
self.global_tool_manager = GlobalToolManager()
|
||||
self.global_workflow_template_manager = WorkflowManager()
|
||||
|
||||
self.postgres_database = postgres_database
|
||||
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
|
|
@ -124,4 +133,23 @@ class GlobalStateMachine:
|
|||
Provider对象,返回注册在self.global_provider_manager.provider_register的供应商
|
||||
"""
|
||||
provider = self.global_provider_manager.provider_register.get(provider_title)
|
||||
return provider
|
||||
return provider
|
||||
|
||||
|
||||
###以下为global_tool_manager方法
|
||||
def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]:
|
||||
"""
|
||||
获取工具表方法
|
||||
Args:
|
||||
agent_name: agent的名字
|
||||
|
||||
Returns:
|
||||
返回该agent的tool,类型为dict
|
||||
"""
|
||||
tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {})
|
||||
return tool_list
|
||||
|
||||
|
||||
###以下为workflow_template_manager方法
|
||||
def workflow_template_generator(self, workflow_template: WorkflowTemplate) -> None:
|
||||
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
|
@ -17,11 +17,10 @@ from pretor.core.workflow.workflow_template_generator.workflow_template import W
|
|||
|
||||
class WorkflowTemplateGenerator:
|
||||
@staticmethod
|
||||
def generate_workflow_template(name: str, desc: str, steps: list) -> None:
|
||||
workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps)
|
||||
def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate:
|
||||
output_dir = Path("pretor.workflow_template")
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
output_file = output_dir / f"{name}_workflow_template.json"
|
||||
output_file = output_dir / f"{workflow_template.name}_workflow_template.json"
|
||||
with output_file.open("w", encoding="utf-8") as f:
|
||||
f.write(workflow_template.model_dump_json(indent=4))
|
||||
f.write(workflow_template.model_dump_json(indent=4))
|
||||
|
|
|
|||
|
|
@ -16,6 +16,7 @@ import json
|
|||
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
|
||||
class WorkflowManager:
|
||||
def __init__(self):
|
||||
|
|
@ -35,9 +36,9 @@ class WorkflowManager:
|
|||
except KeyError:
|
||||
logger.warning(f"{workflow_template_file}不符合workflow_template格式")
|
||||
|
||||
|
||||
def generate_workflow_template(self, name: str, desc: str, steps: list) -> None:
|
||||
def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None:
|
||||
try:
|
||||
self.workflow_template_generator.generate_workflow_template(name=name, desc=desc, steps=steps)
|
||||
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
|
||||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate workflow template")
|
||||
Loading…
Reference in New Issue