From b934ee2e321e4cf529fdb0ca48be787af4694dff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=9D=E5=A4=95?= Date: Mon, 27 Apr 2026 19:20:16 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Enhance=20skill=20management,=20add?= =?UTF-8?q?=20tool=20integrations,=20and=20overhaul=20Chat=20UI=20(#44)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * feat: restructure skills, introduce tools, and enhance chat UI - Split worker_individual.py into separate component files: base, ordinary, special, and skill individuals. - Update skill download and resolution paths to absolute references matching viceroy capabilities, correcting tmp and docker access issues. - Introduce `GET /api/v1/resource/tool` and dynamic File Tool for agents to read file content. - Update frontend Resource view to display tools instead of resource stubs. - Convert Dashboard to Chat view, splitting chat interface to support standard chat or workflow deployment by appending prompt prefixes. Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> * feat: restructure skills, introduce tools, and enhance chat UI - Split worker_individual.py into separate component files: base, ordinary, special, and skill individuals. - Update skill download and resolution paths to absolute references matching viceroy capabilities, correcting tmp and docker access issues. - Introduce `GET /api/v1/resource/tool` and dynamic File Tool for agents to read file content. - Update frontend Resource view to display tools instead of resource stubs. - Convert Dashboard to Chat view, splitting chat interface to support standard chat or workflow deployment by appending prompt prefixes. Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> --- frontend/src/components/Chat/ChatPanel.tsx | 27 ++- frontend/src/components/Layout/Sidebar.tsx | 2 +- .../components/Resource/ResourceLayout.tsx | 10 +- .../components/Resource/ResourceSettings.tsx | 13 -- .../src/components/Resource/ToolSettings.tsx | 64 +++++++ pretor/api/resource.py | 11 +- .../global_state_machine/skill_manager.py | 6 +- .../tool_plugin/file_reader/__init__.py | 3 + .../tool_plugin/file_reader/file_reader.py | 42 +++++ pretor/worker_individual/__init__.py | 11 ++ pretor/worker_individual/base_individual.py | 70 ++++++++ .../worker_individual/ordinary_individual.py | 43 +++++ pretor/worker_individual/skill_individual.py | 110 ++++++++++++ .../worker_individual/special_individual.py | 43 +++++ pretor/worker_individual/worker_cluster.py | 6 +- pretor/worker_individual/worker_individual.py | 157 ------------------ pyproject.toml | 1 + .../global_state_machine_test.py | 12 +- tests/core/workflow/workflow_test.py | 17 +- uv.lock | 2 + 20 files changed, 451 insertions(+), 199 deletions(-) delete mode 100644 frontend/src/components/Resource/ResourceSettings.tsx create mode 100644 frontend/src/components/Resource/ToolSettings.tsx create mode 100644 pretor/plugin/tool_plugin/file_reader/__init__.py create mode 100644 pretor/plugin/tool_plugin/file_reader/file_reader.py create mode 100644 pretor/worker_individual/base_individual.py create mode 100644 pretor/worker_individual/ordinary_individual.py create mode 100644 pretor/worker_individual/skill_individual.py create mode 100644 pretor/worker_individual/special_individual.py delete mode 100644 pretor/worker_individual/worker_individual.py diff --git a/frontend/src/components/Chat/ChatPanel.tsx b/frontend/src/components/Chat/ChatPanel.tsx index 2c63b6d..ab0127e 100644 --- a/frontend/src/components/Chat/ChatPanel.tsx +++ b/frontend/src/components/Chat/ChatPanel.tsx @@ -77,8 +77,9 @@ export function ChatPanel() { try { // Assuming a token might be needed, apiClient should handle it if set + const promptModifier = mode === 'deploy' ? '[DEPLOY TASK] ' : ''; const response = await apiClient.post('/api/v1/adapter/client', { - message: userMessage.text + message: promptModifier + userMessage.text }); const aiMessage: ChatMessage = { @@ -113,11 +114,29 @@ export function ChatPanel() { } }; + const [mode, setMode] = useState<'chat' | 'deploy'>('chat'); + return (
-
- -

Pretor Assistant

+
+
+ +

Pretor Assistant

+
+
+ + +
{/* Chat History */} diff --git a/frontend/src/components/Layout/Sidebar.tsx b/frontend/src/components/Layout/Sidebar.tsx index 1923dcf..28d0bdc 100644 --- a/frontend/src/components/Layout/Sidebar.tsx +++ b/frontend/src/components/Layout/Sidebar.tsx @@ -20,7 +20,7 @@ export function Sidebar({ currentView, setCurrentView }: SidebarProps) { diff --git a/frontend/src/components/Resource/ResourceLayout.tsx b/frontend/src/components/Resource/ResourceLayout.tsx index 889c610..17e78d1 100644 --- a/frontend/src/components/Resource/ResourceLayout.tsx +++ b/frontend/src/components/Resource/ResourceLayout.tsx @@ -1,6 +1,6 @@ import { Wrench, Database, FileCode } from 'lucide-react'; import { SkillSettings } from './SkillSettings'; -import { ResourceSettings } from './ResourceSettings'; +import { ToolSettings } from './ToolSettings'; import { WorkflowTemplateSettings } from './WorkflowTemplateSettings'; interface ResourceLayoutProps { @@ -32,11 +32,11 @@ export function ResourceLayout({ resourceTab, setResourceTab }: ResourceLayoutPr Workflow Templates
@@ -45,7 +45,7 @@ export function ResourceLayout({ resourceTab, setResourceTab }: ResourceLayoutPr
{resourceTab === 'skill' && } {resourceTab === 'workflow_template' && } - {resourceTab === 'resource' && } + {resourceTab === 'tool' && }
); diff --git a/frontend/src/components/Resource/ResourceSettings.tsx b/frontend/src/components/Resource/ResourceSettings.tsx deleted file mode 100644 index 0f9691d..0000000 --- a/frontend/src/components/Resource/ResourceSettings.tsx +++ /dev/null @@ -1,13 +0,0 @@ -export function ResourceSettings() { - return ( -
-
-

Resource Management

-

Manage external and internal resources.

-
-
- Resource management configuration coming soon... -
-
- ); -} \ No newline at end of file diff --git a/frontend/src/components/Resource/ToolSettings.tsx b/frontend/src/components/Resource/ToolSettings.tsx new file mode 100644 index 0000000..cebb317 --- /dev/null +++ b/frontend/src/components/Resource/ToolSettings.tsx @@ -0,0 +1,64 @@ +import { useState, useEffect } from 'react'; +import { Package } from 'lucide-react'; +import apiClient from '../../api/client'; + +export function ToolSettings() { + const [tools, setTools] = useState([]); + const [loading, setLoading] = useState(true); + + useEffect(() => { + fetchTools(); + }, []); + + const fetchTools = async () => { + try { + setLoading(true); + const response = await apiClient.get('/api/v1/resource/tool'); + const toolsData = response.data.tools || []; + setTools(toolsData); + } catch (err) { + console.error('Failed to fetch tools:', err); + } finally { + setLoading(false); + } + }; + + return ( +
+
+

Installed Tools

+

Manage agent tools and functions.

+
+ +
+
+
+

Available Tools

+

List of installed tools available for agents.

+
+
+ +
+ {loading ? ( +
Loading tools...
+ ) : tools.length === 0 ? ( +
No tools installed yet.
+ ) : ( +
+ {tools.map((tool) => ( +
+
+
+ +
+ {tool} +
+
+ ))} +
+ )} +
+
+
+ ); +} diff --git a/pretor/api/resource.py b/pretor/api/resource.py index 141c425..e39789a 100644 --- a/pretor/api/resource.py +++ b/pretor/api/resource.py @@ -53,9 +53,12 @@ async def install_skill(skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): global_state_machine = ray_actor_hook("global_state_machine").global_state_machine # noinspection PyUnresolvedReferences + import os + skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")) + os.makedirs(skill_output_dir, exist_ok=True) await viceroy.install_skill_async(url = skill.repo_url, path = skill.path, - output = "./pretor/plugin/tool_plugin") + output = skill_output_dir) if skill.path: skill_name = skill.path.split("/")[-1] else: @@ -75,3 +78,9 @@ async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allow # Note: this only removes it from the state machine manager. await global_state_machine.remove_skill.remote( skill_name) return {"message": "success"} + +@resource_router.get("/tool") +async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): + global_state_machine = ray_actor_hook("global_state_machine").global_state_machine + tools = await global_state_machine.get_tool_list.remote("default") + return {"tools": list(tools.keys())} diff --git a/pretor/core/global_state_machine/skill_manager.py b/pretor/core/global_state_machine/skill_manager.py index a5a6d98..6021a83 100644 --- a/pretor/core/global_state_machine/skill_manager.py +++ b/pretor/core/global_state_machine/skill_manager.py @@ -24,7 +24,8 @@ class GlobalSkillManager: def __init__(self): self.skill_mapper = defaultdict(tuple) - skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" + import os + skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill"))) if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir(): return for item in skill_plugin_dir.iterdir(): @@ -46,7 +47,8 @@ class GlobalSkillManager: def add_skill(self, skill_name: str) -> None: """Add a skill to the manager by reading its skill.json from the path""" - skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" + import os + skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill"))) item = skill_plugin_dir / skill_name if item.is_dir() and not item.name.startswith((".", "__")): json_path = item / "skill.json" diff --git a/pretor/plugin/tool_plugin/file_reader/__init__.py b/pretor/plugin/tool_plugin/file_reader/__init__.py new file mode 100644 index 0000000..30143e6 --- /dev/null +++ b/pretor/plugin/tool_plugin/file_reader/__init__.py @@ -0,0 +1,3 @@ +from .file_reader import FileReaderData, file_reader + +__all__ = ["FileReaderData", "file_reader"] diff --git a/pretor/plugin/tool_plugin/file_reader/file_reader.py b/pretor/plugin/tool_plugin/file_reader/file_reader.py new file mode 100644 index 0000000..9ec80be --- /dev/null +++ b/pretor/plugin/tool_plugin/file_reader/file_reader.py @@ -0,0 +1,42 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic_ai import RunContext +from pretor.plugin.tool_plugin.base_tool import BaseToolData +import os + +class FileReaderData(BaseToolData): + name: str = "file_reader" + description: str = "读取本地文件的内容" + +def file_reader(ctx: RunContext, filepath: str) -> str: + """读取本地文件内容的工具。 + + Args: + filepath: 目标文件的绝对路径或相对路径。 + + Returns: + 如果文件存在并可读,返回文件内容;否则返回错误信息。 + """ + if not os.path.exists(filepath): + return f"Error: 文件 {filepath} 不存在。" + if not os.path.isfile(filepath): + return f"Error: {filepath} 不是一个文件。" + + try: + with open(filepath, 'r', encoding='utf-8') as f: + content = f.read() + return content + except Exception as e: + return f"Error: 读取文件失败,原因:{str(e)}" diff --git a/pretor/worker_individual/__init__.py b/pretor/worker_individual/__init__.py index e69de29..565a31a 100644 --- a/pretor/worker_individual/__init__.py +++ b/pretor/worker_individual/__init__.py @@ -0,0 +1,11 @@ +from pretor.worker_individual.base_individual import BaseIndividual +from pretor.worker_individual.skill_individual import SkillIndividual +from pretor.worker_individual.ordinary_individual import OrdinaryIndividual +from pretor.worker_individual.special_individual import SpecialIndividual + +__all__ = [ + "BaseIndividual", + "SkillIndividual", + "OrdinaryIndividual", + "SpecialIndividual", +] diff --git a/pretor/worker_individual/base_individual.py b/pretor/worker_individual/base_individual.py new file mode 100644 index 0000000..d719770 --- /dev/null +++ b/pretor/worker_individual/base_individual.py @@ -0,0 +1,70 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pydantic_ai import Agent, RunContext +from pydantic import Field +from pretor.adapter.model_adapter.agent_factory import AgentFactory +from pretor.core.global_state_machine.model_provider.base_provider import Provider +from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel +from pretor.utils.ray_hook import ray_actor_hook + +from pretor.utils.logger import get_logger +logger = get_logger('worker_individual') + +class WorkerIndividualResponse(ResponseModel): + output: str = Field(..., description="Worker执行任务的输出结果") + +class WorkerIndividualDeps(DepsModel): + task_event: dict + +class WorkerIndividualInput(InputModel): + task_event: dict + +class BaseIndividual: + """ + Worker Individual 的基类 + """ + + def __init__(self, agent_config: dict): + self.agent_config = agent_config + self.agent_id = agent_config.get("agent_id") + self.agent: Agent | None = None + + async def _init_agent(self, agent_name: str, system_prompt: str): + global_state_machine = ray_actor_hook("global_state_machine").global_state_machine + provider_title = self.agent_config.get("provider_title", "openai") # default fallback + model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback + + provider: Provider = await global_state_machine.get_provider.remote( provider_title) + agent_factory = AgentFactory() + self.agent = agent_factory.create_agent( + provider=provider, + model_id=model_id, + output_type=WorkerIndividualResponse, + system_prompt=system_prompt, + deps_type=WorkerIndividualDeps, + agent_name=agent_name + ) + + @self.agent.system_prompt + async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]): + prompt = system_prompt + "\n\n" + prompt += ( + f"=== 当前任务上下文 ===\n" + f"{ctx.deps.task_event}\n" + ) + return prompt + + async def run(self, task_event: dict) -> dict: + raise NotImplementedError("子类必须实现 run 方法") diff --git a/pretor/worker_individual/ordinary_individual.py b/pretor/worker_individual/ordinary_individual.py new file mode 100644 index 0000000..2386ac9 --- /dev/null +++ b/pretor/worker_individual/ordinary_individual.py @@ -0,0 +1,43 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.utils.logger import get_logger + +logger = get_logger('ordinary_individual') + +class OrdinaryIndividual(BaseIndividual): + """ + 普通子个体:普通的 agent。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。") + await self._init_agent("ordinary_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}") + raise diff --git a/pretor/worker_individual/skill_individual.py b/pretor/worker_individual/skill_individual.py new file mode 100644 index 0000000..cde9776 --- /dev/null +++ b/pretor/worker_individual/skill_individual.py @@ -0,0 +1,110 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.utils.logger import get_logger +import os +import json +from pydantic_ai import Tool +import importlib.util + +logger = get_logger('skill_individual') + +class SkillIndividual(BaseIndividual): + """ + 专家子个体:拥有专业 skill 的 agent。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def _load_skill_tools(self): + """动态加载已绑定的 skill 工具。""" + tools = [] + bound_skill = self.agent_config.get("bound_skill", "") + # bound_skill can be string or dict {"skill_name": ["file1", "file2"]} + skill_mapper = {} + if isinstance(bound_skill, str) and bound_skill: + try: + skill_mapper = json.loads(bound_skill) + except json.JSONDecodeError: + pass + elif isinstance(bound_skill, dict): + skill_mapper = bound_skill + + skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")) + + for skill_name, _ in skill_mapper.items(): + skill_path = os.path.join(skill_base_dir, skill_name) + metadata_path = os.path.join(skill_path, "metadata.json") + if not os.path.exists(metadata_path): + continue + + try: + with open(metadata_path, 'r', encoding='utf-8') as f: + metadata = json.load(f) + except Exception as e: + logger.error(f"Failed to load metadata for skill {skill_name}: {e}") + continue + + if "functions" in metadata: + for func_info in metadata["functions"]: + # Ensure path is absolute + script_path = func_info.get("file_path", "") + if not os.path.isabs(script_path): + script_path = os.path.join(skill_path, script_path) + + if not os.path.exists(script_path): + logger.warning(f"Skill script not found: {script_path}") + continue + + func_name = func_info.get("name") + try: + # Dynamically load the python module + spec = importlib.util.spec_from_file_location(func_name, script_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + + func = getattr(module, func_name) + if callable(func): + # Convert to PydanticAI Tool + tool = Tool(func, name=func_name, description=func_info.get("docstring", "")) + tools.append(tool) + logger.info(f"Loaded skill tool: {func_name} from {skill_name}") + except Exception as e: + logger.error(f"Failed to load function {func_name} from {script_path}: {e}") + + return tools + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", + "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。") + await self._init_agent("skill_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + + tools = await self._load_skill_tools() + + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps, + tools=tools if tools else None + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}") + raise diff --git a/pretor/worker_individual/special_individual.py b/pretor/worker_individual/special_individual.py new file mode 100644 index 0000000..27ee49d --- /dev/null +++ b/pretor/worker_individual/special_individual.py @@ -0,0 +1,43 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.utils.logger import get_logger + +logger = get_logger('special_individual') + +class SpecialIndividual(BaseIndividual): + """ + 特殊子个体:执行特殊任务的 agent,如生成语音、视频等。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。") + await self._init_agent("special_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}") + raise diff --git a/pretor/worker_individual/worker_cluster.py b/pretor/worker_individual/worker_cluster.py index e0e2013..dfa3117 100644 --- a/pretor/worker_individual/worker_cluster.py +++ b/pretor/worker_individual/worker_cluster.py @@ -18,8 +18,10 @@ import asyncio from collections import OrderedDict from ray.util.queue import Queue from pretor.utils.ray_hook import ray_actor_hook -from pretor.worker_individual.worker_individual import BaseIndividual, SkillIndividual, OrdinaryIndividual, \ - SpecialIndividual +from pretor.worker_individual.base_individual import BaseIndividual +from pretor.worker_individual.skill_individual import SkillIndividual +from pretor.worker_individual.ordinary_individual import OrdinaryIndividual +from pretor.worker_individual.special_individual import SpecialIndividual from pretor.utils.logger import get_logger diff --git a/pretor/worker_individual/worker_individual.py b/pretor/worker_individual/worker_individual.py deleted file mode 100644 index 36d3029..0000000 --- a/pretor/worker_individual/worker_individual.py +++ /dev/null @@ -1,157 +0,0 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -from pydantic_ai import Agent, RunContext -from pydantic import Field -from pretor.adapter.model_adapter.agent_factory import AgentFactory -from pretor.core.global_state_machine.model_provider.base_provider import Provider -from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel -from pretor.utils.ray_hook import ray_actor_hook - - -from pretor.utils.logger import get_logger -logger = get_logger('worker_individual') - -class WorkerIndividualResponse(ResponseModel): - output: str = Field(..., description="Worker执行任务的输出结果") - - -class WorkerIndividualDeps(DepsModel): - task_event: dict - - -class WorkerIndividualInput(InputModel): - task_event: dict - - -class BaseIndividual: - """ - Worker Individual 的基类 - """ - - def __init__(self, agent_config: dict): - self.agent_config = agent_config - self.agent_id = agent_config.get("agent_id") - self.agent: Agent | None = None - - async def _init_agent(self, agent_name: str, system_prompt: str): - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - provider_title = self.agent_config.get("provider_title", "openai") # default fallback - model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback - - provider: Provider = await global_state_machine.get_provider.remote( provider_title) - agent_factory = AgentFactory() - self.agent = agent_factory.create_agent( - provider=provider, - model_id=model_id, - output_type=WorkerIndividualResponse, - system_prompt=system_prompt, - deps_type=WorkerIndividualDeps, - agent_name=agent_name - ) - - @self.agent.system_prompt - async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]): - prompt = system_prompt + "\n\n" - prompt += ( - f"=== 当前任务上下文 ===\n" - f"{ctx.deps.task_event}\n" - ) - return prompt - - async def run(self, task_event: dict) -> dict: - raise NotImplementedError("子类必须实现 run 方法") - - -class SkillIndividual(BaseIndividual): - """ - 专家子个体:拥有专业 skill 的 agent。 - """ - - def __init__(self, agent_config: dict): - super().__init__(agent_config) - - async def run(self, task_event: dict) -> dict: - if self.agent is None: - system_prompt = self.agent_config.get("prompt", - "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。") - await self._init_agent("skill_individual", system_prompt) - - deps = WorkerIndividualDeps(task_event=task_event) - self.agent.retries = 3 - # In actual usage, tools could be dynamically loaded here based on agent_config - # tool = get_tool("skill_individual") - try: - result = await self.agent.run( - f"请执行以下任务:\n{task_event}", - deps=deps - # tools=tool - ) - return {"output": result.data.output} - except Exception as e: - logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}") - raise - - -class OrdinaryIndividual(BaseIndividual): - """ - 普通子个体:普通的 agent。 - """ - - def __init__(self, agent_config: dict): - super().__init__(agent_config) - - async def run(self, task_event: dict) -> dict: - if self.agent is None: - system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。") - await self._init_agent("ordinary_individual", system_prompt) - - deps = WorkerIndividualDeps(task_event=task_event) - self.agent.retries = 3 - try: - result = await self.agent.run( - f"请执行以下任务:\n{task_event}", - deps=deps - ) - return {"output": result.data.output} - except Exception as e: - logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}") - raise - - -class SpecialIndividual(BaseIndividual): - """ - 特殊子个体:执行特殊任务的 agent,如生成语音、视频等。 - """ - - def __init__(self, agent_config: dict): - super().__init__(agent_config) - - async def run(self, task_event: dict) -> dict: - if self.agent is None: - system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。") - await self._init_agent("special_individual", system_prompt) - - deps = WorkerIndividualDeps(task_event=task_event) - self.agent.retries = 3 - try: - result = await self.agent.run( - f"请执行以下任务:\n{task_event}", - deps=deps - ) - return {"output": result.data.output} - except Exception as e: - logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}") - raise diff --git a/pyproject.toml b/pyproject.toml index bc28687..3778c59 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "pwdlib[argon2,bcrypt]>=0.3.0", "pydantic-ai>=1.73.0", "pyfiglet>=1.0.4", + "pyjwt>=2.12.1", "python-ulid>=3.1.0", "ray[default,serve]>=2.54.0", "rich>=14.3.3", diff --git a/tests/core/global_state_machine/global_state_machine_test.py b/tests/core/global_state_machine/global_state_machine_test.py index 9efbdf9..f894199 100644 --- a/tests/core/global_state_machine/global_state_machine_test.py +++ b/tests/core/global_state_machine/global_state_machine_test.py @@ -50,23 +50,23 @@ def gsm(mock_postgres): def test_add_delete_get_event(gsm): event = MagicMock(spec=PretorEvent) - event.event_id = 123 + event.trace_id = "123" gsm.add_event(event) assert getattr(event, 'pending_queue', None) is not None assert getattr(event, 'receive_queue', None) is not None - retrieved = gsm.get_event(123) + retrieved = gsm.get_event("123") assert retrieved == event - gsm.delete_event(123) - assert gsm.get_event(123) is None + gsm.delete_event("123") + assert gsm.get_event("123") is None def test_update_attachment_and_workflow(gsm): event = MagicMock(spec=PretorEvent) - event.event_id = "abc" + event.trace_id = "abc" gsm.add_event(event) gsm.update_attachment("abc", {"k": "v"}) @@ -80,7 +80,7 @@ def test_update_attachment_and_workflow(gsm): @pytest.mark.asyncio async def test_queues(gsm): event = MagicMock(spec=PretorEvent) - event.event_id = "q_event" + event.trace_id = "q_event" # To use await put/get, we must actually use real asyncio queues for the mock event event.pending_queue = asyncio.Queue() event.receive_queue = asyncio.Queue() diff --git a/tests/core/workflow/workflow_test.py b/tests/core/workflow/workflow_test.py index b052019..7ed7842 100644 --- a/tests/core/workflow/workflow_test.py +++ b/tests/core/workflow/workflow_test.py @@ -10,41 +10,43 @@ def test_worker_group(): def test_work_step(): ws = WorkStep( step=1, + name="step1", node="control_node", action="coding", desc="Write some code" ) assert ws.step == 1 + assert ws.name == "step1" assert ws.node == "control_node" assert ws.action == "coding" assert ws.desc == "Write some code" assert ws.status == "waiting" def test_pretor_workflow_validation_success(): - ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1") - ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") + ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") + ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={}) wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) assert wf.title == "wf1" def test_pretor_workflow_validation_error_step_discontinuous(): - ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1") - ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2") + ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") + ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="工作链步数不连续"): PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) def test_pretor_workflow_validation_error_jump_out_of_bounds(): lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue") - ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg) - ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") + ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) + ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"): PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) def test_pretor_workflow_validation_error_jump_format_error(): lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue") - ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg) + ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="LogicGate 格式错误"): PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"}) @@ -53,4 +55,3 @@ def test_workflow_status(): status = WorkflowStatus() assert status.step == 1 assert status.status == "waiting_llm_working" - assert status.demand is None diff --git a/uv.lock b/uv.lock index 8204bc7..283da53 100644 --- a/uv.lock +++ b/uv.lock @@ -3074,6 +3074,7 @@ dependencies = [ { name = "pydantic-ai", version = "1.75.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "pydantic-ai", version = "1.84.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "pyfiglet" }, + { name = "pyjwt" }, { name = "python-ulid" }, { name = "ray", extra = ["default", "serve"] }, { name = "rich" }, @@ -3105,6 +3106,7 @@ requires-dist = [ { name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" }, { name = "pydantic-ai", specifier = ">=1.73.0" }, { name = "pyfiglet", specifier = ">=1.0.4" }, + { name = "pyjwt", specifier = ">=2.12.1" }, { name = "python-ulid", specifier = ">=3.1.0" }, { name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" }, { name = "rich", specifier = ">=14.3.3" },