From 95ec019b5a4acc910ab39752da15a98f40ccf5d7 Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Mon, 13 Apr 2026 22:44:20 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E5=A2=9E=E5=8A=A0=E4=BA=86workflow=5Fte?= =?UTF-8?q?mplate=5Fgenerate=E7=9A=84api=E6=8E=A5=E5=8F=A3?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pretor/api/resource.py | 27 +++++++++++++++++ .../global_state_machine.py | 30 ++++++++++++++++++- .../workflow_template_generator.py | 7 ++--- .../workflow/workflow_template_manager.py | 7 +++-- 4 files changed, 63 insertions(+), 8 deletions(-) create mode 100644 pretor/api/resource.py diff --git a/pretor/api/resource.py b/pretor/api/resource.py new file mode 100644 index 0000000..2d882d9 --- /dev/null +++ b/pretor/api/resource.py @@ -0,0 +1,27 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate +from pretor.utils.ray_hook import ray_actor_hook +from fastapi import APIRouter, Depends +from pretor.utils.access import TokenData, Accessor + +resource_router = APIRouter(prefix="/api/v1/resource") + +@resource_router.post("/workflow_template") +async def create_workflow_template(workflow_template: WorkflowTemplate, + _: TokenData = Depends(Accessor.get_current_user)): + global_state_machine = ray_actor_hook("global_state_machine") + await global_state_machine.workflow_template_generate.remote(workflow_template) + return {"message": "创建成功"} \ No newline at end of file diff --git a/pretor/core/global_state_machine/global_state_machine.py b/pretor/core/global_state_machine/global_state_machine.py index 9d1b4ac..25acc79 100644 --- a/pretor/core/global_state_machine/global_state_machine.py +++ b/pretor/core/global_state_machine/global_state_machine.py @@ -14,6 +14,7 @@ import ray from pretor.core.global_state_machine.provider_manager import ProviderManager +from pretor.core.global_state_machine.tool_manager import GlobalToolManager from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs import httpx from loguru import logger @@ -22,14 +23,22 @@ from pretor.core.database.postgres import PostgresDatabase from pretor.api.platform.event import PretorEvent import asyncio from pretor.core.workflow.workflow import PretorWorkflow +from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate +from pretor.core.workflow.workflow_template_manager import WorkflowManager +from pretor.tool_plugin.base_tool import BaseToolData @ray.remote class GlobalStateMachine: def __init__(self, postgres_database: PostgresDatabase): + self.event_dict: Dict[int, PretorEvent] = {} self.global_provider_manager = ProviderManager(postgres_database) + self.global_tool_manager = GlobalToolManager() + self.global_workflow_template_manager = WorkflowManager() + self.postgres_database = postgres_database + ###以下方法为event_dict方法 def add_event(self, event: PretorEvent) -> None: event.pending_queue = asyncio.Queue() @@ -124,4 +133,23 @@ class GlobalStateMachine: Provider对象,返回注册在self.global_provider_manager.provider_register的供应商 """ provider = self.global_provider_manager.provider_register.get(provider_title) - return provider \ No newline at end of file + return provider + + + ###以下为global_tool_manager方法 + def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]: + """ + 获取工具表方法 + Args: + agent_name: agent的名字 + + Returns: + 返回该agent的tool,类型为dict + """ + tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {}) + return tool_list + + + ###以下为workflow_template_manager方法 + def workflow_template_generator(self, workflow_template: WorkflowTemplate) -> None: + self.global_workflow_template_manager.generate_workflow_template(workflow_template) \ No newline at end of file diff --git a/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py b/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py index 3959ad9..b210dc0 100644 --- a/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py +++ b/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py @@ -17,11 +17,10 @@ from pretor.core.workflow.workflow_template_generator.workflow_template import W class WorkflowTemplateGenerator: @staticmethod - def generate_workflow_template(name: str, desc: str, steps: list) -> None: - workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps) + def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate: output_dir = Path("pretor.workflow_template") if not output_dir.exists(): output_dir.mkdir(parents=True) - output_file = output_dir / f"{name}_workflow_template.json" + output_file = output_dir / f"{workflow_template.name}_workflow_template.json" with output_file.open("w", encoding="utf-8") as f: - f.write(workflow_template.model_dump_json(indent=4)) \ No newline at end of file + f.write(workflow_template.model_dump_json(indent=4)) diff --git a/pretor/core/workflow/workflow_template_manager.py b/pretor/core/workflow/workflow_template_manager.py index 9a870af..f6cdff2 100644 --- a/pretor/core/workflow/workflow_template_manager.py +++ b/pretor/core/workflow/workflow_template_manager.py @@ -16,6 +16,7 @@ import json from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator from pathlib import Path from loguru import logger +from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate class WorkflowManager: def __init__(self): @@ -35,9 +36,9 @@ class WorkflowManager: except KeyError: logger.warning(f"{workflow_template_file}不符合workflow_template格式") - - def generate_workflow_template(self, name: str, desc: str, steps: list) -> None: + def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None: try: - self.workflow_template_generator.generate_workflow_template(name=name, desc=desc, steps=steps) + workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template) + self.workflow_templates_registry[workflow_template.name] = workflow_template.desc except Exception as e: logger.exception("Failed to generate workflow template") \ No newline at end of file