Enhance skill management, add tool integrations, and overhaul Chat UI (#44)

* feat: restructure skills, introduce tools, and enhance chat UI

- Split worker_individual.py into separate component files: base, ordinary, special, and skill individuals.
- Update skill download and resolution paths to absolute references matching viceroy capabilities, correcting tmp and docker access issues.
- Introduce `GET /api/v1/resource/tool` and dynamic File Tool for agents to read file content.
- Update frontend Resource view to display tools instead of resource stubs.
- Convert Dashboard to Chat view, splitting chat interface to support standard chat or workflow deployment by appending prompt prefixes.

Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>

* feat: restructure skills, introduce tools, and enhance chat UI

- Split worker_individual.py into separate component files: base, ordinary, special, and skill individuals.
- Update skill download and resolution paths to absolute references matching viceroy capabilities, correcting tmp and docker access issues.
- Introduce `GET /api/v1/resource/tool` and dynamic File Tool for agents to read file content.
- Update frontend Resource view to display tools instead of resource stubs.
- Convert Dashboard to Chat view, splitting chat interface to support standard chat or workflow deployment by appending prompt prefixes.

Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>

---------

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>
This commit is contained in:
朝夕 2026-04-27 19:20:16 +08:00 committed by GitHub
parent c39b5eb8e2
commit b934ee2e32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 451 additions and 199 deletions

View File

@ -77,8 +77,9 @@ export function ChatPanel() {
try { try {
// Assuming a token might be needed, apiClient should handle it if set // Assuming a token might be needed, apiClient should handle it if set
const promptModifier = mode === 'deploy' ? '[DEPLOY TASK] ' : '';
const response = await apiClient.post('/api/v1/adapter/client', { const response = await apiClient.post('/api/v1/adapter/client', {
message: userMessage.text message: promptModifier + userMessage.text
}); });
const aiMessage: ChatMessage = { const aiMessage: ChatMessage = {
@ -113,11 +114,29 @@ export function ChatPanel() {
} }
}; };
const [mode, setMode] = useState<'chat' | 'deploy'>('chat');
return ( return (
<div className="flex-1 flex flex-col bg-slate-50"> <div className="flex-1 flex flex-col bg-slate-50">
<div className="h-14 border-b border-slate-200 bg-white flex items-center px-6 shadow-sm z-10"> <div className="h-14 border-b border-slate-200 bg-white flex items-center justify-between px-6 shadow-sm z-10">
<MessageSquare size={18} className="text-blue-600 mr-3" /> <div className="flex items-center">
<h1 className="font-semibold text-slate-800">Pretor Assistant</h1> <MessageSquare size={18} className="text-blue-600 mr-3" />
<h1 className="font-semibold text-slate-800">Pretor Assistant</h1>
</div>
<div className="flex space-x-2 bg-slate-100 p-1 rounded-lg">
<button
onClick={() => setMode('chat')}
className={`px-3 py-1 text-sm font-medium rounded-md transition-colors ${mode === 'chat' ? 'bg-white text-blue-600 shadow-sm' : 'text-slate-500 hover:text-slate-700'}`}
>
Chat
</button>
<button
onClick={() => setMode('deploy')}
className={`px-3 py-1 text-sm font-medium rounded-md transition-colors ${mode === 'deploy' ? 'bg-white text-blue-600 shadow-sm' : 'text-slate-500 hover:text-slate-700'}`}
>
Deploy Task
</button>
</div>
</div> </div>
{/* Chat History */} {/* Chat History */}

View File

@ -20,7 +20,7 @@ export function Sidebar({ currentView, setCurrentView }: SidebarProps) {
<button <button
onClick={() => setCurrentView('dashboard')} onClick={() => setCurrentView('dashboard')}
className={`p-1.5 rounded-lg transition-colors ${currentView === 'dashboard' ? 'text-blue-600 bg-blue-50' : 'text-slate-400 hover:text-blue-500 hover:bg-blue-50'}`} className={`p-1.5 rounded-lg transition-colors ${currentView === 'dashboard' ? 'text-blue-600 bg-blue-50' : 'text-slate-400 hover:text-blue-500 hover:bg-blue-50'}`}
title="Dashboard" title="Chat"
> >
<MessageSquare size={18} /> <MessageSquare size={18} />
</button> </button>

View File

@ -1,6 +1,6 @@
import { Wrench, Database, FileCode } from 'lucide-react'; import { Wrench, Database, FileCode } from 'lucide-react';
import { SkillSettings } from './SkillSettings'; import { SkillSettings } from './SkillSettings';
import { ResourceSettings } from './ResourceSettings'; import { ToolSettings } from './ToolSettings';
import { WorkflowTemplateSettings } from './WorkflowTemplateSettings'; import { WorkflowTemplateSettings } from './WorkflowTemplateSettings';
interface ResourceLayoutProps { interface ResourceLayoutProps {
@ -32,11 +32,11 @@ export function ResourceLayout({ resourceTab, setResourceTab }: ResourceLayoutPr
Workflow Templates Workflow Templates
</button> </button>
<button <button
onClick={() => setResourceTab('resource')} onClick={() => setResourceTab('tool')}
className={`w-full flex items-center px-4 py-3 text-sm font-medium rounded-xl transition-all ${resourceTab === 'resource' ? 'bg-blue-50 text-blue-600' : 'text-slate-600 hover:bg-slate-50 hover:text-slate-900'}`} className={`w-full flex items-center px-4 py-3 text-sm font-medium rounded-xl transition-all ${resourceTab === 'tool' ? 'bg-blue-50 text-blue-600' : 'text-slate-600 hover:bg-slate-50 hover:text-slate-900'}`}
> >
<Database size={18} className="mr-3" /> <Database size={18} className="mr-3" />
Resources Tools
</button> </button>
</div> </div>
</div> </div>
@ -45,7 +45,7 @@ export function ResourceLayout({ resourceTab, setResourceTab }: ResourceLayoutPr
<div className="flex-1 overflow-y-auto p-8"> <div className="flex-1 overflow-y-auto p-8">
{resourceTab === 'skill' && <SkillSettings />} {resourceTab === 'skill' && <SkillSettings />}
{resourceTab === 'workflow_template' && <WorkflowTemplateSettings />} {resourceTab === 'workflow_template' && <WorkflowTemplateSettings />}
{resourceTab === 'resource' && <ResourceSettings />} {resourceTab === 'tool' && <ToolSettings />}
</div> </div>
</div> </div>
); );

View File

@ -1,13 +0,0 @@
export function ResourceSettings() {
return (
<div className="max-w-4xl space-y-6">
<div className="mb-8">
<h1 className="text-2xl font-bold text-slate-800">Resource Management</h1>
<p className="text-slate-500 mt-1">Manage external and internal resources.</p>
</div>
<div className="bg-white rounded-xl shadow-sm border border-slate-200 overflow-hidden p-6 text-slate-500 text-sm">
Resource management configuration coming soon...
</div>
</div>
);
}

View File

@ -0,0 +1,64 @@
import { useState, useEffect } from 'react';
import { Package } from 'lucide-react';
import apiClient from '../../api/client';
export function ToolSettings() {
const [tools, setTools] = useState<string[]>([]);
const [loading, setLoading] = useState(true);
useEffect(() => {
fetchTools();
}, []);
const fetchTools = async () => {
try {
setLoading(true);
const response = await apiClient.get('/api/v1/resource/tool');
const toolsData = response.data.tools || [];
setTools(toolsData);
} catch (err) {
console.error('Failed to fetch tools:', err);
} finally {
setLoading(false);
}
};
return (
<div className="max-w-4xl space-y-6">
<div>
<h3 className="text-xl font-semibold text-slate-800">Installed Tools</h3>
<p className="text-slate-500 mt-1">Manage agent tools and functions.</p>
</div>
<div className="bg-white border border-slate-200 rounded-2xl shadow-sm overflow-hidden">
<div className="p-6 border-b border-slate-100 flex justify-between items-center bg-slate-50/50">
<div>
<h4 className="font-medium text-slate-800">Available Tools</h4>
<p className="text-sm text-slate-500">List of installed tools available for agents.</p>
</div>
</div>
<div className="p-6">
{loading ? (
<div className="text-slate-500 text-sm">Loading tools...</div>
) : tools.length === 0 ? (
<div className="text-slate-500 text-sm">No tools installed yet.</div>
) : (
<div className="grid grid-cols-1 md:grid-cols-2 gap-4">
{tools.map((tool) => (
<div key={tool} className="p-4 border border-slate-200 rounded-xl flex items-center justify-between hover:shadow-sm transition-shadow">
<div className="flex items-center">
<div className="w-10 h-10 bg-purple-50 rounded-lg flex items-center justify-center mr-3">
<Package size={20} className="text-purple-600" />
</div>
<span className="font-medium text-slate-800">{tool}</span>
</div>
</div>
))}
</div>
)}
</div>
</div>
</div>
);
}

View File

@ -53,9 +53,12 @@ async def install_skill(skill: Skill,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# noinspection PyUnresolvedReferences # noinspection PyUnresolvedReferences
import os
skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
os.makedirs(skill_output_dir, exist_ok=True)
await viceroy.install_skill_async(url = skill.repo_url, await viceroy.install_skill_async(url = skill.repo_url,
path = skill.path, path = skill.path,
output = "./pretor/plugin/tool_plugin") output = skill_output_dir)
if skill.path: if skill.path:
skill_name = skill.path.split("/")[-1] skill_name = skill.path.split("/")[-1]
else: else:
@ -75,3 +78,9 @@ async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allow
# Note: this only removes it from the state machine manager. # Note: this only removes it from the state machine manager.
await global_state_machine.remove_skill.remote( skill_name) await global_state_machine.remove_skill.remote( skill_name)
return {"message": "success"} return {"message": "success"}
@resource_router.get("/tool")
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
tools = await global_state_machine.get_tool_list.remote("default")
return {"tools": list(tools.keys())}

View File

@ -24,7 +24,8 @@ class GlobalSkillManager:
def __init__(self): def __init__(self):
self.skill_mapper = defaultdict(tuple) self.skill_mapper = defaultdict(tuple)
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir(): if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
return return
for item in skill_plugin_dir.iterdir(): for item in skill_plugin_dir.iterdir():
@ -46,7 +47,8 @@ class GlobalSkillManager:
def add_skill(self, skill_name: str) -> None: def add_skill(self, skill_name: str) -> None:
"""Add a skill to the manager by reading its skill.json from the path""" """Add a skill to the manager by reading its skill.json from the path"""
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
item = skill_plugin_dir / skill_name item = skill_plugin_dir / skill_name
if item.is_dir() and not item.name.startswith((".", "__")): if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json" json_path = item / "skill.json"

View File

@ -0,0 +1,3 @@
from .file_reader import FileReaderData, file_reader
__all__ = ["FileReaderData", "file_reader"]

View File

@ -0,0 +1,42 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import RunContext
from pretor.plugin.tool_plugin.base_tool import BaseToolData
import os
class FileReaderData(BaseToolData):
name: str = "file_reader"
description: str = "读取本地文件的内容"
def file_reader(ctx: RunContext, filepath: str) -> str:
"""读取本地文件内容的工具。
Args:
filepath: 目标文件的绝对路径或相对路径
Returns:
如果文件存在并可读返回文件内容否则返回错误信息
"""
if not os.path.exists(filepath):
return f"Error: 文件 {filepath} 不存在。"
if not os.path.isfile(filepath):
return f"Error: {filepath} 不是一个文件。"
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
return content
except Exception as e:
return f"Error: 读取文件失败,原因:{str(e)}"

View File

@ -0,0 +1,11 @@
from pretor.worker_individual.base_individual import BaseIndividual
from pretor.worker_individual.skill_individual import SkillIndividual
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
from pretor.worker_individual.special_individual import SpecialIndividual
__all__ = [
"BaseIndividual",
"SkillIndividual",
"OrdinaryIndividual",
"SpecialIndividual",
]

View File

@ -0,0 +1,70 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import Agent, RunContext
from pydantic import Field
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger
logger = get_logger('worker_individual')
class WorkerIndividualResponse(ResponseModel):
output: str = Field(..., description="Worker执行任务的输出结果")
class WorkerIndividualDeps(DepsModel):
task_event: dict
class WorkerIndividualInput(InputModel):
task_event: dict
class BaseIndividual:
"""
Worker Individual 的基类
"""
def __init__(self, agent_config: dict):
self.agent_config = agent_config
self.agent_id = agent_config.get("agent_id")
self.agent: Agent | None = None
async def _init_agent(self, agent_name: str, system_prompt: str):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(
provider=provider,
model_id=model_id,
output_type=WorkerIndividualResponse,
system_prompt=system_prompt,
deps_type=WorkerIndividualDeps,
agent_name=agent_name
)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前任务上下文 ===\n"
f"{ctx.deps.task_event}\n"
)
return prompt
async def run(self, task_event: dict) -> dict:
raise NotImplementedError("子类必须实现 run 方法")

View File

@ -0,0 +1,43 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
logger = get_logger('ordinary_individual')
class OrdinaryIndividual(BaseIndividual):
"""
普通子个体普通的 agent
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手请尽力完成给定的任务。")
await self._init_agent("ordinary_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
raise

View File

@ -0,0 +1,110 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
import os
import json
from pydantic_ai import Tool
import importlib.util
logger = get_logger('skill_individual')
class SkillIndividual(BaseIndividual):
"""
专家子个体拥有专业 skill agent
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def _load_skill_tools(self):
"""动态加载已绑定的 skill 工具。"""
tools = []
bound_skill = self.agent_config.get("bound_skill", "")
# bound_skill can be string or dict {"skill_name": ["file1", "file2"]}
skill_mapper = {}
if isinstance(bound_skill, str) and bound_skill:
try:
skill_mapper = json.loads(bound_skill)
except json.JSONDecodeError:
pass
elif isinstance(bound_skill, dict):
skill_mapper = bound_skill
skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
for skill_name, _ in skill_mapper.items():
skill_path = os.path.join(skill_base_dir, skill_name)
metadata_path = os.path.join(skill_path, "metadata.json")
if not os.path.exists(metadata_path):
continue
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
except Exception as e:
logger.error(f"Failed to load metadata for skill {skill_name}: {e}")
continue
if "functions" in metadata:
for func_info in metadata["functions"]:
# Ensure path is absolute
script_path = func_info.get("file_path", "")
if not os.path.isabs(script_path):
script_path = os.path.join(skill_path, script_path)
if not os.path.exists(script_path):
logger.warning(f"Skill script not found: {script_path}")
continue
func_name = func_info.get("name")
try:
# Dynamically load the python module
spec = importlib.util.spec_from_file_location(func_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
func = getattr(module, func_name)
if callable(func):
# Convert to PydanticAI Tool
tool = Tool(func, name=func_name, description=func_info.get("docstring", ""))
tools.append(tool)
logger.info(f"Loaded skill tool: {func_name} from {skill_name}")
except Exception as e:
logger.error(f"Failed to load function {func_name} from {script_path}: {e}")
return tools
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt",
"你是一个拥有专业技能的专家级AI助手请利用你的专业知识完成给定的任务。")
await self._init_agent("skill_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
tools = await self._load_skill_tools()
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps,
tools=tools if tools else None
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}")
raise

View File

@ -0,0 +1,43 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
logger = get_logger('special_individual')
class SpecialIndividual(BaseIndividual):
"""
特殊子个体执行特殊任务的 agent如生成语音视频等
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手负责处理特殊类型的任务。")
await self._init_agent("special_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
raise

View File

@ -18,8 +18,10 @@ import asyncio
from collections import OrderedDict from collections import OrderedDict
from ray.util.queue import Queue from ray.util.queue import Queue
from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.ray_hook import ray_actor_hook
from pretor.worker_individual.worker_individual import BaseIndividual, SkillIndividual, OrdinaryIndividual, \ from pretor.worker_individual.base_individual import BaseIndividual
SpecialIndividual from pretor.worker_individual.skill_individual import SkillIndividual
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
from pretor.worker_individual.special_individual import SpecialIndividual
from pretor.utils.logger import get_logger from pretor.utils.logger import get_logger

View File

@ -1,157 +0,0 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import Agent, RunContext
from pydantic import Field
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger
logger = get_logger('worker_individual')
class WorkerIndividualResponse(ResponseModel):
output: str = Field(..., description="Worker执行任务的输出结果")
class WorkerIndividualDeps(DepsModel):
task_event: dict
class WorkerIndividualInput(InputModel):
task_event: dict
class BaseIndividual:
"""
Worker Individual 的基类
"""
def __init__(self, agent_config: dict):
self.agent_config = agent_config
self.agent_id = agent_config.get("agent_id")
self.agent: Agent | None = None
async def _init_agent(self, agent_name: str, system_prompt: str):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(
provider=provider,
model_id=model_id,
output_type=WorkerIndividualResponse,
system_prompt=system_prompt,
deps_type=WorkerIndividualDeps,
agent_name=agent_name
)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前任务上下文 ===\n"
f"{ctx.deps.task_event}\n"
)
return prompt
async def run(self, task_event: dict) -> dict:
raise NotImplementedError("子类必须实现 run 方法")
class SkillIndividual(BaseIndividual):
"""
专家子个体拥有专业 skill agent
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt",
"你是一个拥有专业技能的专家级AI助手请利用你的专业知识完成给定的任务。")
await self._init_agent("skill_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
# In actual usage, tools could be dynamically loaded here based on agent_config
# tool = get_tool("skill_individual")
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
# tools=tool
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}")
raise
class OrdinaryIndividual(BaseIndividual):
"""
普通子个体普通的 agent
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手请尽力完成给定的任务。")
await self._init_agent("ordinary_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
raise
class SpecialIndividual(BaseIndividual):
"""
特殊子个体执行特殊任务的 agent如生成语音视频等
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手负责处理特殊类型的任务。")
await self._init_agent("special_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
raise

View File

@ -16,6 +16,7 @@ dependencies = [
"pwdlib[argon2,bcrypt]>=0.3.0", "pwdlib[argon2,bcrypt]>=0.3.0",
"pydantic-ai>=1.73.0", "pydantic-ai>=1.73.0",
"pyfiglet>=1.0.4", "pyfiglet>=1.0.4",
"pyjwt>=2.12.1",
"python-ulid>=3.1.0", "python-ulid>=3.1.0",
"ray[default,serve]>=2.54.0", "ray[default,serve]>=2.54.0",
"rich>=14.3.3", "rich>=14.3.3",

View File

@ -50,23 +50,23 @@ def gsm(mock_postgres):
def test_add_delete_get_event(gsm): def test_add_delete_get_event(gsm):
event = MagicMock(spec=PretorEvent) event = MagicMock(spec=PretorEvent)
event.event_id = 123 event.trace_id = "123"
gsm.add_event(event) gsm.add_event(event)
assert getattr(event, 'pending_queue', None) is not None assert getattr(event, 'pending_queue', None) is not None
assert getattr(event, 'receive_queue', None) is not None assert getattr(event, 'receive_queue', None) is not None
retrieved = gsm.get_event(123) retrieved = gsm.get_event("123")
assert retrieved == event assert retrieved == event
gsm.delete_event(123) gsm.delete_event("123")
assert gsm.get_event(123) is None assert gsm.get_event("123") is None
def test_update_attachment_and_workflow(gsm): def test_update_attachment_and_workflow(gsm):
event = MagicMock(spec=PretorEvent) event = MagicMock(spec=PretorEvent)
event.event_id = "abc" event.trace_id = "abc"
gsm.add_event(event) gsm.add_event(event)
gsm.update_attachment("abc", {"k": "v"}) gsm.update_attachment("abc", {"k": "v"})
@ -80,7 +80,7 @@ def test_update_attachment_and_workflow(gsm):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_queues(gsm): async def test_queues(gsm):
event = MagicMock(spec=PretorEvent) event = MagicMock(spec=PretorEvent)
event.event_id = "q_event" event.trace_id = "q_event"
# To use await put/get, we must actually use real asyncio queues for the mock event # To use await put/get, we must actually use real asyncio queues for the mock event
event.pending_queue = asyncio.Queue() event.pending_queue = asyncio.Queue()
event.receive_queue = asyncio.Queue() event.receive_queue = asyncio.Queue()

View File

@ -10,41 +10,43 @@ def test_worker_group():
def test_work_step(): def test_work_step():
ws = WorkStep( ws = WorkStep(
step=1, step=1,
name="step1",
node="control_node", node="control_node",
action="coding", action="coding",
desc="Write some code" desc="Write some code"
) )
assert ws.step == 1 assert ws.step == 1
assert ws.name == "step1"
assert ws.node == "control_node" assert ws.node == "control_node"
assert ws.action == "coding" assert ws.action == "coding"
assert ws.desc == "Write some code" assert ws.desc == "Write some code"
assert ws.status == "waiting" assert ws.status == "waiting"
def test_pretor_workflow_validation_success(): def test_pretor_workflow_validation_success():
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1") ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={}) wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
assert wf.title == "wf1" assert wf.title == "wf1"
def test_pretor_workflow_validation_error_step_discontinuous(): def test_pretor_workflow_validation_error_step_discontinuous():
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1") ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="工作链步数不连续"): with pytest.raises(ValueError, match="工作链步数不连续"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_out_of_bounds(): def test_pretor_workflow_validation_error_jump_out_of_bounds():
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue") lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg) ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"): with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_format_error(): def test_pretor_workflow_validation_error_jump_format_error():
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue") lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg) ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="LogicGate 格式错误"): with pytest.raises(ValueError, match="LogicGate 格式错误"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"}) PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
@ -53,4 +55,3 @@ def test_workflow_status():
status = WorkflowStatus() status = WorkflowStatus()
assert status.step == 1 assert status.step == 1
assert status.status == "waiting_llm_working" assert status.status == "waiting_llm_working"
assert status.demand is None

View File

@ -3074,6 +3074,7 @@ dependencies = [
{ name = "pydantic-ai", version = "1.75.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" }, { name = "pydantic-ai", version = "1.75.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.14'" },
{ name = "pydantic-ai", version = "1.84.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" }, { name = "pydantic-ai", version = "1.84.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.14'" },
{ name = "pyfiglet" }, { name = "pyfiglet" },
{ name = "pyjwt" },
{ name = "python-ulid" }, { name = "python-ulid" },
{ name = "ray", extra = ["default", "serve"] }, { name = "ray", extra = ["default", "serve"] },
{ name = "rich" }, { name = "rich" },
@ -3105,6 +3106,7 @@ requires-dist = [
{ name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" }, { name = "pwdlib", extras = ["argon2", "bcrypt"], specifier = ">=0.3.0" },
{ name = "pydantic-ai", specifier = ">=1.73.0" }, { name = "pydantic-ai", specifier = ">=1.73.0" },
{ name = "pyfiglet", specifier = ">=1.0.4" }, { name = "pyfiglet", specifier = ">=1.0.4" },
{ name = "pyjwt", specifier = ">=2.12.1" },
{ name = "python-ulid", specifier = ">=3.1.0" }, { name = "python-ulid", specifier = ">=3.1.0" },
{ name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" }, { name = "ray", extras = ["default", "serve"], specifier = ">=2.54.0" },
{ name = "rich", specifier = ">=14.3.3" }, { name = "rich", specifier = ">=14.3.3" },