chore: initial commit for Pretor v0.1.0-alpha

正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。

核心功能实现:
1. 建立基于 BaseIndividual 的动态插件加载机制。
2. 实现三类核心 worker_individual 子个体。
3. 集成 Ray 框架支持分布式集群调度。
4. 基于 PostgreSQL 的全量持久化存储方案。
5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
2026-04-29 10:09:07 +08:00
commit d84212f780
163 changed files with 19251 additions and 0 deletions
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,77 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.anthropic import AnthropicProvider
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
from pretor.core.global_state_machine.model_provider import Provider
from pretor.utils.agent_model import ResponseModel, DepsModel
from pretor.utils.error import ModelNotExistError
class AgentFactory:
def __init__(self):
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
"claude": (AnthropicModel, AnthropicProvider),
"deepseek": (OpenAIChatModel, OpenAIProvider),}
def create_agent(self,
provider: Provider,
model_id: str,
output_type: ResponseModel,
system_prompt: str,
deps_type: DepsModel,
agent_name: str,
tools: list = None) -> Agent:
"""
create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象
Args:
provider: Provider对象,从global_state_machine中获取
model_id: 模型名
output_type: 输出格式
system_prompt: 系统提示词
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
agent_name: agent的名字
tools: 工具列表
Returns:
返回被实例化的pydantic-ai的Agent对象
"""
if model_id not in provider.provider_models:
raise ModelNotExistError("模型不存在")
if provider.provider_type not in self._models_mapping:
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
model_class, provider_class = self._models_mapping[provider.provider_type]
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
match provider.provider_type:
case "deepseek":
agent = DeepSeekReasonerAgent(model=model,
name=agent_name,
output_type=output_type,
deps_type=deps_type,
system_prompt=system_prompt,
tools=tools,
retries=3,
)
case _:
agent = Agent(model=model,
name=agent_name,
system_prompt=system_prompt,
output_type=output_type,
deps_type=deps_type,
tools=tools)
return agent
@@ -0,0 +1,150 @@
import re
import json
from typing import Type, TypeVar, Any, Generic
from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent, RunContext
from pydantic_ai.run import AgentRunResult
T = TypeVar('T', bound=BaseModel)
class AgentRunResultProxy:
def __init__(self, original, parsed):
self._original = original
self._parsed = parsed
def __getattr__(self, name):
if name == 'data':
return self._parsed
if name == 'output':
return self._parsed
return getattr(self._original, name)
class DeepSeekReasonerAgent(Generic[T]):
"""
专为 DeepSeek-V4/R1 设计的适配器。
将结构化输出降级为文本解析模式,并支持重试逻辑以确保系统兼容性。
"""
def __init__(
self,
model,
name,
output_type: Any = str,
system_prompt: str = "",
deps_type: Type[Any] = None,
tools: list = None,
retries: int = 3,
**kwargs
):
self.output_schema = output_type
self.has_custom_output = output_type is not str and output_type is not None
self.tools = tools or []
self.retries = retries
format_instruction = ""
if self.has_custom_output:
try:
from pydantic import TypeAdapter
schema_dict = TypeAdapter(self.output_schema).json_schema()
schema_str = json.dumps(schema_dict, ensure_ascii=False)
format_instruction = (
f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
f"{schema_str}"
)
except Exception:
pass
tool_instruction = ""
if self.tools:
tool_descs = []
for t in self.tools:
desc = getattr(t, '__name__', str(t))
if hasattr(t, '__doc__') and t.__doc__:
desc += f": {t.__doc__.strip()}"
tool_descs.append(f"- {desc}")
tool_instruction = (
"\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
"但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" +
"\n".join(tool_descs)
)
self.agent = Agent(
model=model,
name=name,
output_type=str, # Force native agent to return str to disable function calling
system_prompt=system_prompt + format_instruction + tool_instruction,
deps_type=deps_type,
**kwargs
)
def _parse_output(self, text: str) -> Any:
if not self.has_custom_output:
return text
match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
json_str = match.group(1).strip() if match else text
if not json_str.startswith('{') and not json_str.startswith('['):
start_obj = json_str.find('{')
start_arr = json_str.find('[')
start = -1
end = -1
if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
start = start_obj
end = json_str.rfind('}')
elif start_arr != -1:
start = start_arr
end = json_str.rfind(']')
if start != -1 and end != -1 and end > start:
json_str = json_str[start:end+1]
if not json_str:
raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
try:
from pydantic import TypeAdapter
adapter = TypeAdapter(self.output_schema)
return adapter.validate_json(json_str)
except ValidationError as e:
raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}")
except json.JSONDecodeError as e:
raise ValueError(f"返回的不是合法的 JSON{e}")
def __getattr__(self, item):
# Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent
return getattr(self.agent, item)
async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
# Custom retry loop
current_history = message_history or []
last_exception = None
for attempt in range(self.retries + 1):
result = await self.agent.run(
user_prompt,
deps=deps,
message_history=current_history,
**kwargs
)
raw_text = result.data if hasattr(result, 'data') else getattr(result, 'output', str(result))
try:
parsed_data = self._parse_output(raw_text)
# Proxy the result to inject the parsed data seamlessly
return AgentRunResultProxy(result, parsed_data)
except ValueError as e:
last_exception = e
# Prepare retry prompt
user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
# We need to maintain history manually so the model sees what it did wrong
# Actually, pydantic-ai manages history inside the result. Let's use the all_messages from result
if hasattr(result, 'all_messages'):
current_history = result.all_messages()
raise ValueError(f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}")
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+185
View File
@@ -0,0 +1,185 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Union
from pretor.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from pretor.utils.access import Accessor, TokenData
from pretor.core.database.table.individual import AgentType
from fastapi import HTTPException
from typing import Optional, List, Dict
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
class AgentRegister(BaseModel):
provider_title: str
model_id: str
individual_name: str
tools: Optional[List[str]] = None
class AgentLocalRegister(BaseModel):
path: str
individual_name: str
tools: Optional[List[str]] = None
@agent_router.get("")
async def get_system_nodes(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
postgres_database = ray_actor_hook("postgres_database").postgres_database
configs = await postgres_database.get_all_system_node_configs.remote()
return {"system_nodes": configs}
@agent_router.post("")
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
postgres_database = ray_actor_hook("postgres_database").postgres_database
if isinstance(agent_register, AgentLocalRegister):
pass
elif isinstance(agent_register, AgentRegister):
try:
# Persist configuration
await postgres_database.upsert_system_node_config.remote(
agent_register.individual_name,
agent_register.provider_title,
agent_register.model_id,
agent_register.tools
)
# Load agent into state machine
match agent_register.individual_name:
case "supervisory_node":
node = ray_actor_hook("supervisory_node").supervisory_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
case "consciousness_node":
node = ray_actor_hook("consciousness_node").consciousness_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
case "control_node":
node = ray_actor_hook("control_node").control_node
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
case _:
pass
except Exception as e:
raise HTTPException(status_code=500, detail=f"加载节点失败: {str(e)}")
return {"message": "创建成功"}
class WorkerIndividualCreate(BaseModel):
agent_name: str
agent_type: AgentType
description: str
provider_title: str
model_id: str
system_prompt: str
output_template: dict
bound_skill: Dict[str, List[str]]
workspace: List[str]
tools: Optional[List[str]] = None
class WorkerIndividualUpdate(BaseModel):
agent_name: Optional[str] = None
agent_type: Optional[AgentType] = None
description: Optional[str] = None
provider_title: Optional[str] = None
model_id: Optional[str] = None
system_prompt: Optional[str] = None
output_template: Optional[dict] = None
bound_skill: Optional[Dict[str, List[str]]] = None
workspace: Optional[List[str]] = None
tools: Optional[List[str]] = None
@agent_router.post("/worker")
async def create_worker_individual(worker_data: WorkerIndividualCreate,
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
postgres_database = ray_actor_hook("postgres_database").postgres_database
data_dict = worker_data.model_dump()
data_dict["owner_id"] = token_data.user_id
worker = await postgres_database.add_worker_individual.remote( **data_dict)
return {"message": "success", "agent_id": worker.agent_id}
@agent_router.get("/worker")
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database
workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id)
return {"workers": workers}
@agent_router.get("/worker/{agent_id}")
async def get_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
return worker
@agent_router.put("/worker/{agent_id}")
async def update_worker_individual(agent_id: str,
worker_data: WorkerIndividualUpdate,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
update_data = worker_data.model_dump(exclude_unset=True)
updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data)
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
try:
await global_state_machine.remove_individual.remote(agent_id)
except Exception:
pass
return {"message": "success", "worker": updated_worker}
@agent_router.post("/worker/{agent_id}/reload")
async def reload_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.remove_individual.remote(agent_id)
return {"message": "Worker will be reloaded on next use"}
@agent_router.delete("/worker/{agent_id}")
async def delete_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database").postgres_database
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
await postgres_database.delete_worker_individual.remote( agent_id=agent_id)
return {"message": "success"}
+88
View File
@@ -0,0 +1,88 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter
from fastapi import Depends
from pydantic import BaseModel
from pretor.utils.access import Accessor, TokenData
from fastapi.concurrency import run_in_threadpool
from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
from pretor.utils.error import UserNotExistError
auth_router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
class UserRegister(BaseModel):
user_name: str
password: str
@auth_router.post("/register")
async def create_user(user_register: UserRegister):
postgres_database = ray_actor_hook("postgres_database").postgres_database
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
user = await postgres_database.add_user.remote( user_register.user_name, hashed_password)
return {"message": "success", "user_id": user.user_id}
class UserLogin(BaseModel):
user_name: str
password: str
@auth_router.post("/login")
async def login_user(user_login: UserLogin):
postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.login_user.remote( user_login.user_name)
if not user:
raise UserNotExistError()
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
return {"message":"success", "token":token}
class ChangeAuthorityRequest(BaseModel):
user_id: str
new_authority: UserAuthority
@auth_router.put("/authority")
async def change_authority(
request: ChangeAuthorityRequest,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
):
"""
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
"""
postgres_database = ray_actor_hook("postgres_database").postgres_database
user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority)
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority}
@auth_router.get("/list")
async def get_user_list(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
):
"""
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
"""
postgres_database = ray_actor_hook("postgres_database").postgres_database
users = await postgres_database.get_all_users.remote()
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]}
@auth_router.delete("/{user_id}")
async def delete_user(
user_id: str,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
):
"""
Delete a user. Only accessible by SUPER_ADMINISTRATOR.
"""
postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.delete_user_by_id.remote( user_id=user_id)
return {"message": "success"}
+19
View File
@@ -0,0 +1,19 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter
cluster_router = APIRouter(prefix="/api/v1/cluster", tags=["cluster"])
# Monitor websocket API temporarily removed
+16
View File
@@ -0,0 +1,16 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .frontend import client_router
__all__ = ["client_router"]
+36
View File
@@ -0,0 +1,36 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
from pydantic import BaseModel, Field, ConfigDict
from ulid import ULID
from typing import Any, Dict
from pretor.core.workflow.workflow import PretorWorkflow
import asyncio
class PretorEvent(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
platform: str = Field(description="消息来源的平台")
user_id: str = Field(description="用户id")
user_name: str = Field(description="用户名")
create_time: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc).isoformat()),
description="事件创建时间")
message: str = Field(description="用户发来的消息")
attachment: Dict[str, str] | None = Field(default=None,description="附件")
#--------------------------------------------------------------------------------------------------------------
context: Dict[str, Any] = Field(default_factory=dict, description="事件上下文内容,可包含工作流模板等信息")
workflow: PretorWorkflow | None = Field(default=None,description="工作流")
pending_queue: asyncio.Queue[str] | None= Field(default=None,description="待处理队列")
receive_queue: asyncio.Queue[str] | None = Field(default=None,description="待接收队列")
+63
View File
@@ -0,0 +1,63 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
from pydantic import BaseModel
from pretor.utils.access import Accessor, TokenData
from pretor.api.platform.event import PretorEvent
from pretor.utils.ray_hook import ray_actor_hook
import os
import shutil
from pretor.utils.logger import get_logger
logger = get_logger('frontend')
client_router = APIRouter(prefix="/api/v1/adapter/client", tags=["client"])
class Message(BaseModel):
message: str
@client_router.post("")
async def create_message(message: Message,
token_data: TokenData = Depends(Accessor.get_current_user)):
logger.info("收到消息,来源:客户端")
logger.debug(f"消息内容:{message.message}")
event = PretorEvent(platform="client",
user_id=str(token_data.user_id),
user_name=token_data.username,
message=message.message)
supervisory_node = ray_actor_hook("supervisory_node").supervisory_node
message = await supervisory_node.working.remote(event)
if message == "任务已创建":
return {"message": event.trace_id}
elif message == "未知相应类型":
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="模型回复错误")
else:
return {"message": message}
@client_router.post("/upload")
async def upload_file(file: UploadFile = File(...),
token_data: TokenData = Depends(Accessor.get_current_user)):
try:
upload_dir = "uploads"
os.makedirs(upload_dir, exist_ok=True)
file_path = os.path.join(upload_dir, file.filename)
with open(file_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
logger.info(f"用户 {token_data.username} 上传了文件: {file.filename}")
return {"filename": file.filename, "message": f"File {file.filename} uploaded successfully"}
except Exception as e:
logger.error(f"文件上传失败: {e}")
raise HTTPException(status_code=500, detail="文件上传失败")
+54
View File
@@ -0,0 +1,54 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import Literal
from pretor.utils.access import TokenData, Accessor
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
from typing import Dict
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.utils.ray_hook import ray_actor_hook
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
class ProviderRegister(BaseModel):
provider_type: Literal["openai", "claude", "deepseek"]
provider_title: str
provider_url: str
provider_apikey: str
@provider_router.post("")
async def create_provider(provider_register: ProviderRegister,
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type,
provider_title=provider_register.provider_title,
provider_url=provider_register.provider_url,
provider_apikey=provider_register.provider_apikey,
provider_owner=token_data.user_id)
@provider_router.get("/list")
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
return {"provider_list": provider_list}
@provider_router.delete("/{provider_title}")
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.delete_provider.remote(provider_title=provider_title)
return {"message": "success"}
+89
View File
@@ -0,0 +1,89 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
import viceroy
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends
from pretor.utils.access import TokenData
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
resource_router = APIRouter(prefix="/api/v1/resource")
@resource_router.post("/workflow_template")
async def create_workflow_template(workflow_template: WorkflowTemplate,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template)
return {"message": "创建成功"}
@resource_router.get("/workflow_template")
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
templates = await global_state_machine.get_all_workflow_templates.remote()
return {"templates": templates}
@resource_router.delete("/workflow_template/{template_name}")
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.delete_workflow_template.remote( template_name)
return {"message": "success"}
class Skill(BaseModel):
repo_url: str
path: str | None
@resource_router.post("/skill")
async def install_skill(skill: Skill,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# noinspection PyUnresolvedReferences
import os
skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
os.makedirs(skill_output_dir, exist_ok=True)
await viceroy.install_skill_async(url = skill.repo_url,
path = skill.path,
output = skill_output_dir)
if skill.path:
skill_name = skill.path.split("/")[-1]
else:
skill_name = skill.repo_url.split("/")[-1]
await global_state_machine.add_skill.remote( skill_name)
return {"message": "创建成功"}
@resource_router.get("/skill")
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
skills = await global_state_machine.get_skill_list.remote()
return {"skills": skills}
@resource_router.delete("/skill/{skill_name}")
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# Note: this only removes it from the state machine manager.
await global_state_machine.remove_skill.remote( skill_name)
return {"message": "success"}
@resource_router.get("/tool")
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
tool_mapper = await global_state_machine.get_tool_mapper.remote()
all_tool_names = set()
for scope_tools in tool_mapper.values():
all_tool_names.update(scope_tools.keys())
return {"tools": list(all_tool_names)}
+99
View File
@@ -0,0 +1,99 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Request, HTTPException
from fastapi.responses import StreamingResponse
import asyncio
workflow_router = APIRouter(prefix="/api/v1/workflow", tags=["workflow"])
@workflow_router.get("/list")
async def get_workflow_list():
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
events = await global_state_machine.list_events.remote()
return events
@workflow_router.get("/{trace_id}")
async def get_workflow_detail(trace_id: str):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
event = await global_state_machine.get_event.remote(trace_id)
if not event:
raise HTTPException(status_code=404, detail="Workflow not found")
workflow = event.workflow
if not workflow:
return {
"event_id": trace_id,
"workflow_title": None,
"status": "waiting",
"user_name": event.user_name,
"message": event.message,
"create_time": event.create_time,
"steps": [],
}
steps = []
for step in workflow.work_link:
steps.append({
"step": step.step,
"name": step.name,
"node": step.node,
"action": step.action,
"desc": step.desc,
"status": step.status,
"agent_id": step.agent_id,
})
return {
"event_id": trace_id,
"workflow_title": workflow.title,
"status": workflow.status.status,
"command": workflow.command,
"current_step": workflow.status.step,
"user_name": event.user_name,
"message": event.message,
"create_time": event.create_time,
"steps": steps,
}
@workflow_router.get("/sse/{trace_id}")
async def get_workflow_sse(trace_id: str, request: Request):
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
async def event_generator():
try:
while True:
if await request.is_disconnected():
break
# You might also want to send the workflow state periodically or when updated
# Here we just wait for pending messages and send them
message = await global_state_machine.get_pending.remote(trace_id)
# Ensure the message is formatted as SSE
yield f"data: {message}\n\n"
except asyncio.CancelledError:
pass
return StreamingResponse(event_generator(), media_type="text/event-stream")
@workflow_router.post("/reply/{trace_id}")
async def post_workflow_reply(trace_id: str, request: Request):
data = await request.json()
reply_msg = data.get("message", "")
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.put_received.remote(trace_id, reply_msg)
return {"status": "ok"}
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+127
View File
@@ -0,0 +1,127 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Dict
from fastapi import FastAPI, WebSocket, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import FileResponse, JSONResponse
from pretor.api.platform.frontend import client_router
from pretor.api.auth import auth_router
from pretor.api.provider import provider_router
from pretor.api.resource import resource_router
from pretor.api.cluster import cluster_router
from pretor.api.agent import agent_router
from pretor.utils.error import (
DemandError, ModelNotExistError, UserError,
UserNotExistError, UserPasswordError, ProviderError,
ProviderNotExistError, WorkflowError, WorkflowExit
)
from ray import serve
app = FastAPI()
app.include_router(client_router) # 客户端路径
app.include_router(auth_router) # 用户路径
app.include_router(provider_router) # 供应商路径
app.include_router(resource_router) # 资源路径
app.include_router(cluster_router) # 集群信息路径
app.include_router(agent_router) # agent路径
@app.exception_handler(UserNotExistError)
async def user_not_exist_handler(request: Request, exc: UserNotExistError):
return JSONResponse(status_code=404, content={"message": "用户不存在"})
@app.exception_handler(UserPasswordError)
async def user_password_handler(request: Request, exc: UserPasswordError):
return JSONResponse(status_code=401, content={"message": "密码错误"})
@app.exception_handler(UserError)
async def user_error_handler(request: Request, exc: UserError):
return JSONResponse(status_code=400, content={"message": "用户相关错误"})
@app.exception_handler(ProviderNotExistError)
async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError):
return JSONResponse(status_code=404, content={"message": "服务提供商不存在"})
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
return JSONResponse(status_code=400, content={"message": "服务提供商错误"})
@app.exception_handler(ModelNotExistError)
async def model_not_exist_handler(request: Request, exc: ModelNotExistError):
return JSONResponse(status_code=404, content={"message": "模型不存在"})
@app.exception_handler(DemandError)
async def demand_error_handler(request: Request, exc: DemandError):
return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"})
@app.exception_handler(WorkflowExit)
async def workflow_exit_handler(request: Request, exc: WorkflowExit):
return JSONResponse(status_code=400, content={"message": "工作流已退出"})
@app.exception_handler(WorkflowError)
async def workflow_error_handler(request: Request, exc: WorkflowError):
return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
frontend_dir = os.path.join(base_dir, "frontend", "dist")
if os.path.exists(frontend_dir):
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets")
@app.get("/favicon.svg", include_in_schema=False)
async def serve_favicon():
return FileResponse(os.path.join(frontend_dir, "favicon.svg"))
@app.get("/icons.svg", include_in_schema=False)
async def serve_icons():
return FileResponse(os.path.join(frontend_dir, "icons.svg"))
@app.get("/{full_path:path}", include_in_schema=False)
async def serve_frontend(full_path: str):
# 【重要安全修复】避免拦截不存在的 API 路由。如果是调用了不存在的 /api/ 接口,直接返回 404,不返回前端页面
if full_path.startswith("api/"):
return JSONResponse(status_code=404, content={"detail": "API endpoint not found"})
index_path = os.path.join(frontend_dir, "index.html")
if os.path.exists(index_path):
return FileResponse(index_path)
return JSONResponse(status_code=404, content={"detail": "Frontend build not found"})
else:
import logging
logging.getLogger("pretor").warning(f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount.")
@serve.deployment
@serve.ingress(app)
class PretorGateway:
gateway: Dict[str, WebSocket]
def __init__(self):
self.gateway = {}
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,39 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlalchemy.exc import IntegrityError, OperationalError
from pydantic import ValidationError
from pretor.utils.error import UserNotExistError
from pretor.utils.logger import get_logger
logger = get_logger('database_exception')
def database_exception(func):
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except ValidationError as e:
logger.error(f"对象校验失败:{e}")
raise e
except IntegrityError as e:
logger.error(f"数据库完整性错误 (如重复记录): {e}")
raise e
except OperationalError as e:
logger.error(f"数据库连接异常: {e}")
raise e
except UserNotExistError as e:
logger.error(f"更改密码失败,用户不存在:{e}")
except Exception as e:
logger.exception(f"未预期的数据库错误: {e}")
raise e
return wrapper
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+83
View File
@@ -0,0 +1,83 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.database.table.individual import WorkerIndividual
from sqlmodel import select
from typing import List, Optional
from pretor.core.database.database_exception import database_exception
from ulid import ULID
class IndividualDatabase:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@database_exception
async def add_worker_individual(self, **kwargs) -> WorkerIndividual:
async with self.async_session_maker() as session:
agent_id = str(ULID())
individual = WorkerIndividual(agent_id=agent_id, **kwargs)
session.add(individual)
await session.commit()
await session.refresh(individual)
return individual
@database_exception
async def get_worker_individual(self, agent_id: str) -> Optional[WorkerIndividual]:
async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
results = await session.execute(statement)
return results.scalar_one_or_none()
@database_exception
async def get_worker_individual_list(self, owner_id: str) -> List[WorkerIndividual]:
async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.owner_id == owner_id)
results = await session.execute(statement)
return list(results.scalars().all())
@database_exception
async def update_worker_individual(self, agent_id: str, **kwargs) -> Optional[WorkerIndividual]:
async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
results = await session.execute(statement)
individual = results.scalar_one_or_none()
if not individual:
return None
for key, value in kwargs.items():
if value is not None:
setattr(individual, key, value)
session.add(individual)
await session.commit()
await session.refresh(individual)
return individual
@database_exception
async def delete_worker_individual(self, agent_id: str) -> bool:
async with self.async_session_maker() as session:
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
results = await session.execute(statement)
individual = results.scalar_one_or_none()
if not individual:
return False
session.delete(individual)
await session.commit()
return True
@database_exception
async def get_all_worker_individual(self) -> List[WorkerIndividual]:
async with self.async_session_maker() as session:
statement = select(WorkerIndividual)
results = await session.execute(statement)
return list(results.scalars().all())
+68
View File
@@ -0,0 +1,68 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlmodel import SQLModel, Field, select
from typing import Optional, List
import json
class WorkflowRecord(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
workflow_id: str = Field(index=True)
workflow_data_json: str
class MemoryRecord(SQLModel, table=True):
id: Optional[int] = Field(default=None, primary_key=True)
memory_text: str
embedding: List[float] = Field(sa_column_kwargs={"type_": "VECTOR"}) # Requires pgvector extension setup in DB
class MemoryRAG:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
async def save_workflow(self, workflow_id: str, workflow_data: dict):
async with self.async_session_maker() as session:
record = WorkflowRecord(
workflow_id=workflow_id,
workflow_data_json=json.dumps(workflow_data)
)
session.add(record)
await session.commit()
await session.refresh(record)
return record
async def get_workflow(self, workflow_id: str):
async with self.async_session_maker() as session:
statement = select(WorkflowRecord).where(WorkflowRecord.workflow_id == workflow_id)
results = await session.execute(statement)
record = results.scalar_one_or_none()
if record:
return json.loads(record.workflow_data_json)
return None
async def add_memory(self, memory_text: str, embedding: List[float]):
async with self.async_session_maker() as session:
record = MemoryRecord(memory_text=memory_text, embedding=embedding)
session.add(record)
await session.commit()
await session.refresh(record)
return record
async def retrieve_memory(self, query_embedding: List[float], limit: int = 5):
# Requires pgvector specific operations; simplified retrieval simulation here
async with self.async_session_maker() as session:
# A true pgvector query would use an ORDER BY using `<->` operator
# e.g. statement = select(MemoryRecord).order_by(MemoryRecord.embedding.l2_distance(query_embedding)).limit(limit)
statement = select(MemoryRecord).limit(limit)
results = await session.execute(statement)
return results.all()
+64
View File
@@ -0,0 +1,64 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List
from pretor.core.database.table.provider import Provider
from sqlmodel import select
from pretor.core.database.database_exception import database_exception
class ProviderDatabase:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@database_exception
async def get_provider(self) -> List[Provider]:
async with self.async_session_maker() as session:
statement = select(Provider)
results = await session.execute(statement)
results = results.scalars().all()
providers = [Provider(provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type) for provider in results]
return providers
@database_exception
async def add_provider(self, **kwargs) -> None:
async with self.async_session_maker() as session:
provider = Provider(**kwargs)
session.add(provider)
await session.commit()
@database_exception
async def delete_provider(self, provider_id: str) -> None:
async with self.async_session_maker() as session:
provider = await session.get(Provider, provider_id)
if provider is not None:
session.delete(provider)
await session.commit()
@database_exception
async def update_provider(self, provider_id: str, **kwargs) -> Provider:
async with self.async_session_maker() as session:
provider = await session.get(Provider, provider_id)
if provider is not None:
for key, value in kwargs.items():
setattr(provider, key, value)
session.add(provider)
await session.commit()
await session.refresh(provider)
return provider
return None
@@ -0,0 +1,54 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.database.table.system_node import SystemNodeConfig
from sqlmodel import select
from typing import List, Optional
from pretor.core.database.database_exception import database_exception
class SystemNodeDatabase:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@database_exception
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: Optional[List[str]] = None) -> SystemNodeConfig:
async with self.async_session_maker() as session:
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name)
results = await session.execute(statement)
config = results.scalar_one_or_none()
if config:
config.provider_title = provider_title
config.model_id = model_id
if tools is not None:
config.tools = tools
else:
config = SystemNodeConfig(node_name=node_name, provider_title=provider_title, model_id=model_id, tools=tools)
session.add(config)
await session.commit()
await session.refresh(config)
return config
@database_exception
async def get_all_system_node_configs(self) -> List[SystemNodeConfig]:
async with self.async_session_maker() as session:
statement = select(SystemNodeConfig)
results = await session.execute(statement)
return list(results.scalars().all())
@database_exception
async def get_system_node_config(self, node_name: str) -> Optional[SystemNodeConfig]:
async with self.async_session_maker() as session:
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name)
results = await session.execute(statement)
return results.scalar_one_or_none()
+135
View File
@@ -0,0 +1,135 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.database.table.user import User
from sqlmodel import select
from pretor.utils.error import UserNotExistError, UserPasswordError
from pretor.core.database.database_exception import database_exception
from pretor.core.database.table.user import UserAuthority
from pretor.utils.access import Accessor
class AuthDatabase:
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@database_exception
async def add_user(self, user_name: str, hashed_password: str) -> User:
from ulid import ULID
async with self.async_session_maker() as session:
# Check if any users exist
statement = select(User).limit(1)
results = await session.execute(statement)
existing_user = results.first()
authority = UserAuthority.USER
if existing_user is None:
authority = UserAuthority.SUPER_ADMINISTRATOR
user = User(
user_id=str(ULID()),
user_name=user_name,
hashed_password=hashed_password,
user_authority=authority
)
session.add(user)
await session.commit()
await session.refresh(user)
return user
@database_exception
async def change_password(self, user_name, old_password, new_password) -> User:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
if not Accessor.verify_password(old_password, user.hashed_password):
raise UserPasswordError()
user.hashed_password = new_password
session.add(user)
await session.commit()
await session.refresh(user)
return user
@database_exception
async def delete_user(self, user_name: str) -> None:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
session.delete(user)
await session.commit()
@database_exception
async def delete_user_by_id(self, user_id: str) -> None:
async with self.async_session_maker() as session:
user = await session.get(User, user_id)
if user is None:
raise UserNotExistError()
session.delete(user)
await session.commit()
@database_exception
async def login_user(self, user_name: str) -> str:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.execute(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
return user
@database_exception
async def get_all_users(self) -> list[User]:
async with self.async_session_maker() as session:
statement = select(User)
results = await session.execute(statement)
users = results.scalars().all()
return list(users)
@database_exception
async def get_user_authority(self, user_id: str) -> UserAuthority:
async with self.async_session_maker() as session:
user = await session.get(User, user_id)
if user is None:
raise UserNotExistError()
return user.user_authority
@database_exception
async def change_user_authority(self, user_id: str, new_authority: UserAuthority) -> User:
"""
Changes the authority level of a specific user.
Args:
user_id: The ID of the user whose authority is to be changed.
new_authority: The new authority level to assign to the user.
Returns:
User: The updated user object.
Raises:
UserNotExistError: If the specified user does not exist.
"""
async with self.async_session_maker() as session:
user = await session.get(User, user_id)
if user is None:
raise UserNotExistError()
user.user_authority = new_authority
session.add(user)
await session.commit()
await session.refresh(user)
return user
+140
View File
@@ -0,0 +1,140 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import asyncio
import ray
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel
from pretor.core.database.module.individual import IndividualDatabase
from pretor.core.database.module.user import AuthDatabase
from pretor.core.database.module.provider import ProviderDatabase
from pretor.core.database.module.system_node import SystemNodeDatabase
@ray.remote
class PostgresDatabase:
def __init__(self):
user = os.environ.get('POSTGRES_USER')
password = os.environ.get('POSTGRES_PASSWORD')
host = os.environ.get('POSTGRES_HOST')
port = os.environ.get('POSTGRES_PORT')
database = os.environ.get('POSTGRES_DB')
database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
self.async_engine = create_async_engine(database_url, echo=True)
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
self._auth_database = AuthDatabase(self.async_session_maker)
self._provider_database = ProviderDatabase(self.async_session_maker)
self._individual_database = IndividualDatabase(self.async_session_maker)
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
self.ready_event = asyncio.Event()
async def init_db(self) -> None:
try:
async with self.async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
except Exception as e:
# Provide a warning if the database is not accessible, allowing
# the app to start up for development/UI tests without crashing immediately.
print(f"Warning: Failed to initialize PostgreSQL database: {e}")
finally:
self.ready_event.set()
# Auth Database Methods
async def add_user(self, user_name: str, hashed_password: str):
await self.ready_event.wait()
return await self._auth_database.add_user(user_name, hashed_password)
async def change_password(self, user_name, old_password, new_password):
await self.ready_event.wait()
return await self._auth_database.change_password(user_name, old_password, new_password)
async def delete_user(self, user_name: str):
await self.ready_event.wait()
return await self._auth_database.delete_user(user_name)
async def delete_user_by_id(self, user_id: str):
await self.ready_event.wait()
return await self._auth_database.delete_user_by_id(user_id)
async def login_user(self, user_name: str):
await self.ready_event.wait()
return await self._auth_database.login_user(user_name)
async def get_all_users(self):
await self.ready_event.wait()
return await self._auth_database.get_all_users()
async def get_user_authority(self, user_id: str):
await self.ready_event.wait()
return await self._auth_database.get_user_authority(user_id)
async def change_user_authority(self, user_id: str, new_authority):
await self.ready_event.wait()
return await self._auth_database.change_user_authority(user_id, new_authority)
# Provider Database Methods
async def get_provider(self):
await self.ready_event.wait()
return await self._provider_database.get_provider()
async def add_provider_db(self, **kwargs):
await self.ready_event.wait()
return await self._provider_database.add_provider(**kwargs)
async def delete_provider_db(self, provider_id: str):
await self.ready_event.wait()
return await self._provider_database.delete_provider(provider_id)
async def update_provider_db(self, provider_id: str, **kwargs):
await self.ready_event.wait()
return await self._provider_database.update_provider(provider_id, **kwargs)
# System Node Database Methods
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: list[str] = None):
await self.ready_event.wait()
return await self._system_node_database.upsert_system_node_config(node_name, provider_title, model_id, tools)
async def get_all_system_node_configs(self):
await self.ready_event.wait()
return await self._system_node_database.get_all_system_node_configs()
# Individual Database Methods
async def add_worker_individual(self, **kwargs):
await self.ready_event.wait()
return await self._individual_database.add_worker_individual(**kwargs)
async def get_worker_individual(self, agent_id: str):
await self.ready_event.wait()
return await self._individual_database.get_worker_individual(agent_id)
async def get_worker_individual_list(self, owner_id: str):
await self.ready_event.wait()
return await self._individual_database.get_worker_individual_list(owner_id)
async def update_worker_individual(self, agent_id: str, **kwargs):
await self.ready_event.wait()
return await self._individual_database.update_worker_individual(agent_id, **kwargs)
async def delete_worker_individual(self, agent_id: str):
await self.ready_event.wait()
return await self._individual_database.delete_worker_individual(agent_id)
async def get_all_worker_individual(self):
await self.ready_event.wait()
return await self._individual_database.get_all_worker_individual()
+18
View File
@@ -0,0 +1,18 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.database.table.user import User
from pretor.core.database.table.provider import Provider
from pretor.core.database.table.individual import WorkerIndividual
__all__ = ["User", "Provider", "WorkerIndividual"]
+38
View File
@@ -0,0 +1,38 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlmodel import SQLModel, Field
from typing import List, Optional
from sqlalchemy import Column, JSON
from enum import Enum
class AgentType(str, Enum):
SKILL_INDIVIDUAL = "skill_individual"
ORDINARY_INDIVIDUAL = "ordinary_individual"
SPECIAL_INDIVIDUAL = "special_individual"
class WorkerIndividual(SQLModel, table=True):
__tablename__ = "worker_individual"
agent_id: str = Field(primary_key=True)
agent_name: str = Field(index=True)
agent_type: AgentType
description: str = Field(nullable=False)
provider_title: str
model_id: str
system_prompt: Optional[str]
output_template: Optional[dict] = Field(sa_column=Column(JSON),description="输出模板标识")
bound_skill: Optional[str] = Field(sa_column=Column(JSON))
workspace: Optional[List[str]] = Field(sa_column=Column(JSON))
tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
owner_id: str
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+32
View File
@@ -0,0 +1,32 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlmodel import SQLModel, Field
from typing import List
from sqlalchemy import Column, JSON
from typing import Optional
class Provider(SQLModel, table=True):
__tablename__ = "provider"
provider_id: str = Field(primary_key=True)
provider_title: str = Field(index=True)
provider_type: str
provider_url: Optional[str]
provider_apikey: Optional[str]
provider_models: List[str] = Field(sa_column=Column(JSON))
provider_owner: str
is_active: bool = Field(default=True, description="该服务商节点是否在线/启用")
+25
View File
@@ -0,0 +1,25 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlmodel import SQLModel, Field
from typing import List, Optional
from sqlalchemy import Column, JSON
class SystemNodeConfig(SQLModel, table=True):
__tablename__ = "system_node_config"
node_name: str = Field(primary_key=True)
provider_title: str
model_id: str
tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
+31
View File
@@ -0,0 +1,31 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from sqlmodel import SQLModel, Field
from enum import IntEnum
class UserAuthority(IntEnum):
SUPER_ADMINISTRATOR = 100
ADMINISTRATOR = 50
USER = 20
UNAUTHORIZED_USER = 10
GUEST = 0
class User(SQLModel, table=True):
__tablename__ = 'user'
user_id: str = Field(primary_key=True)
user_name: str = Field(index=True)
hashed_password: str
user_authority: UserAuthority = Field(default=UserAuthority.USER)
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,166 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from pretor.core.global_state_machine.provider_manager import ProviderManager
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
from typing import Dict
from pretor.core.database.postgres import PostgresDatabase
from pretor.api.platform.event import PretorEvent
import asyncio
from pretor.core.workflow.workflow import PretorWorkflow
from pretor.core.workflow.workflow_template_manager import WorkflowManager
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
@ray.remote
class GlobalStateMachine:
def __init__(self, postgres_database: PostgresDatabase):
import sys
print("GSM __init__ START", file=sys.stderr, flush=True)
self.event_dict: Dict[str, PretorEvent] = {}
print(" event_dict done", file=sys.stderr, flush=True)
self._global_provider_manager = ProviderManager(postgres_database)
print(" provider_manager done", file=sys.stderr, flush=True)
self._global_tool_manager = GlobalToolManager()
print(" tool_manager done", file=sys.stderr, flush=True)
self._global_workflow_template_manager = WorkflowManager()
print(" workflow_template_manager done", file=sys.stderr, flush=True)
self._global_skill_manager = GlobalSkillManager()
print(" skill_manager done", file=sys.stderr, flush=True)
self._global_individual_manager = GlobalIndividualManager()
print(" individual_manager done", file=sys.stderr, flush=True)
self.postgres_database = postgres_database
print("GSM __init__ DONE", file=sys.stderr, flush=True)
async def init_state_machine(self):
await self._global_provider_manager.init_provider_register(self.postgres_database)
await self._global_individual_manager.init_individual_register(self.postgres_database)
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
return await self._global_provider_manager.add_provider(
provider_type=provider_type,
provider_title=provider_title,
provider_url=provider_url,
provider_apikey=provider_apikey,
provider_owner=provider_owner,
postgres_database=self.postgres_database
)
# Provider Manager Methods
def get_provider_list(self):
return self._global_provider_manager.get_provider_list()
def get_provider(self, provider_title):
return self._global_provider_manager.get_provider(provider_title)
async def delete_provider(self, provider_title: str):
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
# Tool Manager Methods
def get_tool_mapper(self):
return self._global_tool_manager.tool_mapper
def get_tool_list(self, agent_name: str):
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools
# for a specific agent name (or scope)
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
# also include default tools
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools}
return merged_tools
# Workflow Template Manager Methods
def get_all_workflow_templates(self):
return self._global_workflow_template_manager.get_all_workflow_templates()
def add_workflow_template(self, template_name: str, workflow_template):
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
def delete_workflow_template(self, template_name: str):
return self._global_workflow_template_manager.delete_workflow_template(template_name)
def generate_workflow_template(self, workflow_template):
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
# Skill Manager Methods
def add_skill(self, skill_name: str):
return self._global_skill_manager.add_skill(skill_name)
def get_skill_list(self):
return self._global_skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
return self._global_skill_manager.remove_skill(skill_name)
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
return self._global_individual_manager.add_individual(agent_id, config)
def get_individual(self, agent_id: str):
return self._global_individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
return self._global_individual_manager.remove_individual(agent_id)
def list_individuals(self):
return self._global_individual_manager.list_individuals()
###以下方法为event_dict方法
def add_event(self, event: PretorEvent) -> None:
event.pending_queue = asyncio.Queue()
event.receive_queue = asyncio.Queue()
self.event_dict[event.trace_id] = event
def delete_event(self, trace_id: str) -> None:
del self.event_dict[trace_id]
def get_event(self, trace_id: str) -> PretorEvent:
return self.event_dict.get(trace_id, None)
def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None:
self.event_dict[trace_id].attachment = attachment
def update_workflow(self, trace_id: str, workflow: PretorWorkflow) -> None:
self.event_dict[trace_id].workflow = workflow
def get_workflow(self, trace_id: str) -> PretorWorkflow:
return self.event_dict[trace_id].workflow
def list_events(self) -> list[dict]:
result = []
for trace_id, event in self.event_dict.items():
workflow_title = event.workflow.title if event.workflow else None
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None
result.append({
"event_id": trace_id,
"workflow_title": workflow_title,
"status": workflow_status,
"user_name": event.user_name,
"message": event.message,
})
return result
async def put_pending(self, trace_id, item) -> None:
await self.event_dict[trace_id].pending_queue.put(item)
async def get_pending(self, trace_id) -> str:
return await self.event_dict[trace_id].pending_queue.get()
async def put_received(self, trace_id, item) -> None:
await self.event_dict[trace_id].receive_queue.put(item)
async def get_received(self, trace_id) -> str:
return await self.event_dict[trace_id].receive_queue.get()
@@ -0,0 +1,62 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict, Any
from pretor.utils.logger import get_logger
logger = get_logger('individual_manager')
class GlobalIndividualManager:
def __init__(self):
self._individuals: Dict[str, Dict[str, Any]] = {}
async def init_individual_register(self, postgres) -> None:
try:
try:
individuals = await postgres.get_all_worker_individual.remote()
for ind in individuals:
agent_id = getattr(ind, 'agent_id', None)
if agent_id:
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind)
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")
except AttributeError:
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
except Exception as e:
# 捕获因 Ray 调用目标方法不存在引发的异常
if "has no attribute 'get_all_worker_individual'" in str(e):
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
else:
raise e
except Exception as e:
logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None:
"""
注册一个 worker individual
config 可以包含 type, prompt, provider_title, model_id 等
"""
config["agent_id"] = agent_id
self._individuals[agent_id] = config
def get_individual(self, agent_id: str) -> Dict[str, Any]:
"""
获取一个 worker individual 的配置
"""
return self._individuals.get(agent_id, None)
def remove_individual(self, agent_id: str) -> None:
if agent_id in self._individuals:
del self._individuals[agent_id]
def list_individuals(self) -> Dict[str, Dict[str, Any]]:
return self._individuals
@@ -0,0 +1,19 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
@@ -0,0 +1,96 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List
from enum import Enum
class ProviderStatus(str, Enum):
UP = "up"
DOWN = "down"
class Provider(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
provider_models: List[str]
provider_type: str
provider_owner: str | None = None
provider_status: ProviderStatus = ProviderStatus.UP
class ProviderArgs(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
provider_owner: str
class BaseProvider(ABC):
@staticmethod
@abstractmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
"""
创建一个供应商,传入provider_args参数,打包为一个Provider对象
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
Returns:
返回一个Provider对象,由provider_manager管理
"""
pass
@staticmethod
@abstractmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
"""
加载模型列表
base_provider的字类应当按照供应商的api标准,向供应商的接口发送http请求从而或者供应商所提供的模型列表
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
Returns:
返回一个列表,为http请求获取的模型列表
"""
pass
@staticmethod
@abstractmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
"""
包装Provider对象并返回
将provider_args和_load_models获取的方法包装为provider对象
Args:
provider_args: 参数包,包含以下几个参数
provider_title: 供应商的别名
provider_url: 供应商的url
provider_apikey:供应商的apikey
provider_models: 模型列表,为该供应商包含的模型列表
Returns:
返回一个Provider对象
"""
pass
@@ -0,0 +1,60 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class ClaudeProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Anthropic 官方需要 version 头
headers = {
"x-api-key": provider_args.provider_apikey,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json"
}
# 如果是官方 API,通常使用 /v1/models (如果支持)
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
model_ids = [m["id"] for m in data.get("data", [])]
return sorted(model_ids)
else:
# 如果官方列表接口不可用,fallback 到已知常用模型
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="claude")
@@ -0,0 +1,59 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class DeepseekProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="deepseek")
@@ -0,0 +1,59 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class OpenAIProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="openai")
@@ -0,0 +1,86 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
from typing import Dict, Type
class ProviderManager:
"""
模型供应商管理器 (ProviderManager)。
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
"""
# --- 类属性显式标注 (IDE 友好) ---
provider_mapper: Dict[str, Type[Provider]]
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
provider_register: Dict[str, Provider]
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
def __init__(self, postgres):
self.provider_mapper = {"openai": OpenAIProvider,
"claude": ClaudeProvider,
"deepseek": DeepseekProvider}
self.provider_register = {}
async def init_provider_register(self, postgres) -> None:
providers = await postgres.get_provider.remote()
for provider in providers:
self.provider_register[provider.provider_title] = provider
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
from pretor.core.global_state_machine.model_provider import ProviderArgs
from pretor.utils.logger import get_logger
logger = get_logger('provider_manager')
import httpx
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
provider_url=provider_url,
provider_apikey=provider_apikey,
provider_owner=provider_owner)
try:
import ulid
provider_class = self.provider_mapper.get(provider_type, None)
if provider_class is None:
logger.warning(f"Provider type {provider_type} is not supported.")
return None
provider: Provider = await provider_class.create_provider(provider_args)
provider.provider_owner = provider_owner
self.provider_register[provider_title] = provider
await postgres_database.add_provider_db.remote(
provider_id=str(ulid.ULID()),
provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner)
logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
def get_provider_list(self):
return self.provider_register
def get_provider(self, provider_title):
return self.provider_register.get(provider_title)
async def delete_provider(self, provider_title: str, postgres_database) -> None:
if provider_title in self.provider_register:
provider = self.provider_register[provider_title]
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
del self.provider_register[provider_title]
@@ -0,0 +1,75 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple, Dict
from collections import defaultdict
import pathlib
import json
class GlobalSkillManager:
skill_mapper = Dict[str,Tuple[str]]
"""skill的存储表"""
def __init__(self):
self.skill_mapper = defaultdict(tuple)
import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
return
for item in skill_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json" # 拼接文件路径
if json_path.exists():
try:
with open(json_path, "r", encoding="utf-8") as f:
skill = json.load(f)
# 提取并映射
name = skill.get("name")
if name:
self.skill_mapper[name] = (
skill.get("description", ""),
skill.get("instructions", "")
)
except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}")
def add_skill(self, skill_name: str) -> None:
"""Add a skill to the manager by reading its skill.json from the path"""
import os
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
item = skill_plugin_dir / skill_name
if item.is_dir() and not item.name.startswith((".", "__")):
json_path = item / "skill.json"
if json_path.exists():
try:
with open(json_path, "r", encoding="utf-8") as f:
skill = json.load(f)
name = skill.get("name")
if name:
self.skill_mapper[name] = (
skill.get("description", ""),
skill.get("instructions", "")
)
except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {item.name} 失败: {e}")
def get_skill_list(self) -> dict:
"""Return all skills currently loaded."""
return self.skill_mapper
def remove_skill(self, skill_name: str) -> None:
"""Remove a skill from the manager mapping."""
if skill_name in self.skill_mapper:
del self.skill_mapper[skill_name]
@@ -0,0 +1,52 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import importlib
import inspect
from collections import defaultdict
from pretor.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type
from pretor.utils.logger import get_logger
logger = get_logger('tool_manager')
class GlobalToolManager:
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self):
self.tool_mapper = defaultdict(dict)
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
return
for item in tool_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith("__"):
plugin_name = item.name
module_name = f"pretor.plugin.tool_plugin.{plugin_name}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
# It's a valid tool class
action_scopes = obj.model_fields.get("action_scope").default
if not action_scopes:
self.tool_mapper["default"][plugin_name] = obj
else:
for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj
except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,16 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .consciousness_node import ConsciousnessNode
__all__ = ["ConsciousnessNode"]
@@ -0,0 +1,179 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from typing import Union, overload
from pretor.core.individual.consciousness_node.template import (ConsciousnessNodeDeps, ForSupervisoryNode, ForWorkflow,\
ForWorkflowEngine, ForWorkflowInput, ForSupervisoryInput, ForWorkflowEngineInput)
from pydantic_ai import Agent, RunContext
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.adapter.model_adapter.agent_factory import AgentFactory
@ray.remote
class ConsciousnessNode:
def __init__(self) -> None:
from pretor.utils.logger import get_logger
self.logger = get_logger('consciousness_node')
self.agent: None | Agent = None
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
"""
create_agent方法,将agent对象装配到ConsciousnessNode的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
Agent实例,
并挂载到self.agent属性
Args:
global_state_machine: 全局状态机
provider_title: 供应商名
model_id: 模型id
Returns:
无返回
"""
system_prompt: str = (
"你叫Pretor,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
"你是系统的'高级规划师''架构师',负责处理监控节点分配过来的复杂任务。\n"
"你的主要工作场景包括:\n"
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (PretorWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
"3. 总结报告 (Supervisory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForSupervisoryNode 格式。\n"
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
)
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,
output_type=output_type,
system_prompt=system_prompt,
deps_type=ConsciousnessNodeDeps,
agent_name="consciousness_node",
tools=callables)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[ConsciousnessNodeDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前任务上下文 ===\n"
f"- 当前指令 (Command): {ctx.deps.command}\n"
f"- 原始用户命令 (Original Command): {ctx.deps.original_command}\n"
)
if ctx.deps.workflow_template:
prompt += f"- 选定工作流模板 (Workflow Template): {ctx.deps.workflow_template}\n"
if ctx.deps.available_skills:
prompt += "\n=== 当前可用 Skill Individual ===\n"
prompt += "你可以直接将以下 Skill Individual 安排进工作流的步骤中(设置 node 为 skill_individual,并将 agent_id 设置为对应 Skill Individual 的真实 agent_id,不要用名称!),作为可调用的工具。\n"
for skill in ctx.deps.available_skills:
prompt += f"- 真实 agent_id: {skill.get('agent_id')}\n 名称: {skill['name']}\n 描述: {skill['description']}\n"
return prompt
async def working(self, payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]:
try:
result = await self._run(payload)
if isinstance(result, (ForWorkflowEngine, ForWorkflow, ForSupervisoryNode)):
return result
else:
self.logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
return None
except Exception:
self.logger.exception("ConsciousnessNode在执行working时发生严重错误")
return None
@overload
async def _run(self, payload: ForWorkflowEngineInput) -> ForWorkflowEngine:
"""
_run方法
该分支应当在supervisory_node简单处理用户命令后,工作流创建前调用!
Args:
payload: 应当包含workflow_template和event对象
Returns:
ForWorkflowEngine对象,将被放到全局状态机后丢入WorkflowEngine的异步队列
"""
pass
@overload
async def _run(self, payload: ForWorkflow) -> ForWorkflow:
"""
_run方法
该分支应当在workflow运行时,由WorkflowEngine进行调用!
Args:
payload: 应当包含workflow中的WorkStep对象
Returns:
ForWorkflow对象,作为ConsciousnessNode执行Workflow中的WorkStep的结果
"""
pass
@overload
async def _run(self, payload: ForSupervisoryInput) -> ForSupervisoryNode:
"""
_run方法
该分支应当在workflow运行完全结束后,由WorkflowEngine进行调用!
Args:
payload: 应当包含整个Workflow的情况
Returns:
ForSupervisory对象,作为ConsciousnessNode对于全工作流的技术性总结,返回给SupervisoryNode
"""
pass
async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]:
try:
self.agent.retries = 3
if isinstance(payload, ForWorkflowEngineInput):
deps = ConsciousnessNodeDeps(
original_command=payload.original_command,
workflow_template=payload.workflow_template,
command="拆解原始命令变成一个工作流",
available_skills=payload.available_skills
)
self.logger.debug("ConsciousnessNode: 开始生成工作流 (原生重试开启)")
prompt = "根据original_command制定严密的可执行workflow"
if payload.workflow_template:
prompt += ",可以学习并参考workflow_template的设计理念"
result = await self.agent.run(prompt, deps=deps)
return result.output
elif isinstance(payload, ForWorkflowInput):
deps = ConsciousnessNodeDeps(
original_command=payload.original_command,
command="完成workflow step中分配给意识节点的特定任务或指导"
)
self.logger.debug("ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)")
result = await self.agent.run(f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}",
deps=deps)
return result.output
elif isinstance(payload, ForSupervisoryInput):
deps = ConsciousnessNodeDeps(
original_command=payload.original_command,
command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告"
)
self.logger.debug("ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)")
result = await self.agent.run(f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}",
deps=deps)
return result.output
except Exception as e:
self.logger.exception(f"ConsciousnessNode 模型生成最终失败: {str(e)}")
raise RuntimeError(f"ConsciousnessNode 无法完成任务: {str(e)}") from e
@@ -0,0 +1,67 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
from pretor.utils.agent_model import ResponseModel, DepsModel, InputModel
from pydantic import Field
#意识节点回复类
class ConsciousnessNodeResponse(ResponseModel):
"""Consciousness response model,是意识节点所有回复类型的父类"""
pass
class ForWorkflowEngine(ConsciousnessNodeResponse):
"""生成workflow并放入WorkflowEngine"""
workflow: PretorWorkflow = Field(..., description="生成好的符合规范的完整工作流对象。")
reasoning: str = Field(..., description="生成此工作流的原因和思路简述。")
class ForWorkflow(ConsciousnessNodeResponse):
"""处理workflow中需要ConsciousnessNode的工作"""
output: str = Field(..., description="对当前工作流步骤的具体处理结果或指导意见。")
class ForSupervisoryNode(ConsciousnessNodeResponse):
"""工作流完成后进行校验并返回给SupervisoryNode"""
output: str = Field(..., description="为监控节点提供的全工作流执行情况的技术性总结报告。")
class ConsciousnessNodeDeps(DepsModel):
original_command: str
workflow_template: str | None = None
command: str
available_skills: list[dict] | None = None
class ConsciousnessNodeInput(InputModel):
pass
class ForWorkflowEngineInput(ConsciousnessNodeInput):
workflow_template: str | None = None
original_command: str
available_skills: list[dict] | None = None
class ForWorkflowInput(ConsciousnessNodeInput):
workflow_step: WorkStep
original_command: str
class ForSupervisoryInput(ConsciousnessNodeInput):
workflow: PretorWorkflow
original_command: str
@@ -0,0 +1,16 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .control_node import ControlNode
__all__ = ["ControlNode"]
@@ -0,0 +1,102 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from pydantic_ai import Agent, RunContext
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.individual.control_node.template import ForWorkflow, ForWorkflowInput, ControlNodeDeps
@ray.remote
class ControlNode:
def __init__(self):
from pretor.utils.logger import get_logger
self.logger = get_logger('control_node')
self.agent: Agent | None = None
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
"""
create_agent方法,将agent对象装配到Control的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
Agent实例,
并挂载到self.agent属性
Args:
global_state_machine: 全局状态机
provider_title: 供应商名
model_id: 模型id
Returns:
无返回
"""
system_prompt: str = (
"你叫Pretor,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
"你是系统的'执行者''车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
"你的工作职责是:\n"
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
)
output_type = ForWorkflow
from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,
output_type=output_type,
system_prompt=system_prompt,
deps_type=ControlNodeDeps,
agent_name="control_node",
tools=callables)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[ControlNodeDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前任务步骤上下文 ===\n"
f"- 步骤名称 (Name): {ctx.deps.workflow_step.name}\n"
f"- 步骤目标/描述 (Description): {ctx.deps.workflow_step.desc}\n"
f"- 前置输入(input: {ctx.deps.workflow_step.inputs}\n"
)
return prompt
async def working(self, payload: ForWorkflowInput) -> str:
try:
result: ForWorkflow = await self._run(payload)
return result
except Exception:
self.logger.exception("ControlNode在执行working时发生严重错误")
return None
async def _run(self, payload: ForWorkflowInput) -> ForWorkflow:
try:
self.agent.retries = 3
deps = ControlNodeDeps(
workflow_step=payload.workflow_step
)
self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)")
result = await self.agent.run(
f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}",
deps=deps
)
return result.output
except Exception as e:
self.logger.exception(f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}")
raise RuntimeError(f"ControlNode 执行步骤失败: {str(e)}") from e
@@ -0,0 +1,39 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import Field
from pretor.core.workflow.workflow import WorkStep
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
class ControlNodeResponse(ResponseModel):
"""控制节点回复的基类"""
pass
class ControlNodeInput(InputModel):
pass
class ControlNodeDeps(DepsModel):
workflow_step: WorkStep
# In the future, this can be dynamically populated with tools specific to the current task execution
class ForWorkflow(ControlNodeResponse):
output: str = Field(..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。")
class ForWorkflowInput(ControlNodeInput):
workflow_step: WorkStep
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,16 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .supervisory_node import SupervisoryNode
__all__ = ["SupervisoryNode"]
@@ -0,0 +1,192 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import ray
from typing import Union, overload
from pretor.api.platform.event import PretorEvent
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.model_provider import Provider
from pretor.core.individual.supervisory_node.template import ForConsciousnessNode, ForUser, SupervisoryNodeDeps, TerminationMessage
from pydantic_ai import RunContext, Agent
from pretor.utils.ray_hook import ray_actor_hook
@ray.remote
class SupervisoryNode:
def __init__(self) -> None:
from pretor.utils.logger import get_logger
self.logger = get_logger('supervisory_node')
self.agent: None | Agent = None
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
"""
create_agent方法,将agent对象装配到SupervisoryNode的属性内
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的Agent实例,
并挂载到self.agent属性
Args:
global_state_machine: 全局状态机
provider_title: 供应商名
model_id: 模型id
Returns:
无返回
"""
system_prompt: str = (
"你叫Pretor,是一个多智能体AI助手系统中的【监控节点 (Supervisory Node)】。\n"
"你是系统的'前台接待''大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
" 并使用 ForConsciousnessNode 格式。若提供的【可用模板列表】中有合适的模板请选用,若都不匹配则 workflow_template 设为 null。\n"
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
"请保持冷静、专业,并严格遵循上述路由规则。"
)
output_type = Union[ForConsciousnessNode, ForUser]
from pretor.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,
output_type=output_type,
system_prompt=system_prompt,
deps_type=SupervisoryNodeDeps,
agent_name="supervisory_node",
tools=callables)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[SupervisoryNodeDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前上下文 ===\n"
f"- 平台 (Platform): {ctx.deps.platform}\n"
f"- 用户名 (User): {ctx.deps.user_name}\n"
f"- 当前时间 (Time): {ctx.deps.time}\n"
f"- 可用工作流模板 (Available Templates): {ctx.deps.available_templates}\n"
)
# 修改 system_prompt 变量
prompt += (
"\n\n注意:你必须调用且只能调用一个函数(工具)来输出结果。"
"如果你想直接回复用户,请调用 ForUser;"
"如果你想移交给工作流,请调用 ForConsciousnessNode(若没有合适的模板,workflow_template 填 null)。"
"严禁返回纯文本,必须使用工具格式!"
)
if ctx.deps.error_history:
prompt += (
f"\n=== 错误重试指示 ===\n"
f"警告:前一次尝试失败,错误信息如下:\n{ctx.deps.error_history}\n"
f"请务必修正该错误并按照要求的 Pydantic 格式输出。"
)
return prompt
###工作函数
async def working(self, payload: Union[PretorEvent, TerminationMessage]) -> str:
"""
working方法,是节点唯一的调用方法,对于_run函数的结果进行判断并实现最终回复
Args:
payload: 消息载荷,包含所有信息
Returns:
str,监控节点对于用户的回复
"""
try:
result = await self._run(payload)
if isinstance(result, ForConsciousnessNode):
self.logger.info(f"SupervisoryNode: 任务已分配给工作流引擎处理,选用模板 [{result.workflow_template}]")
if isinstance(payload, PretorEvent):
payload.context["workflow_template"] = result.workflow_template
try:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
await global_state_machine.add_event.remote(payload)
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine
await workflow_running_engine.put_event.remote(payload)
except Exception as e:
self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}")
return "抱歉,任务提交失败,系统内部错误。"
return f"任务已创建,准备创建工作流。原因:{result.reasoning}"
elif isinstance(result, ForUser):
self.logger.info("SupervisoryNode: 直接向用户返回简单回复。")
return result.context
else:
self.logger.error(f"SupervisoryNode: 未知响应类型: {type(result)}")
return "抱歉,系统内部遇到未知错误,无法正确处理您的请求。"
except Exception:
self.logger.exception("SupervisoryNode在处理请求时发生未捕获的严重错误")
return "抱歉,监控节点处理请求时发生严重错误,请联系管理员。"
@overload
async def _run(self, payload: PretorEvent) -> Union[ForConsciousnessNode, ForUser]:
"""
_run方法
Args:
payload: PretorEvent的实例,是用户输入时对于消息的封装
Returns:
ForUser对象,监控节点对于用户进行的简单回答
ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点,并且给选择好的工作流模板
"""
...
@overload
async def _run(self, payload: TerminationMessage) -> ForUser:
"""
_run方法
Args:
payload: Termination的实例,是工作流结束后到达监控节点的最后结果
Returns:
ForUser对象,工作流结束后给用户的返回
"""
...
async def _run(self, payload: Union[PretorEvent, TerminationMessage]) -> Union[ForConsciousnessNode, ForUser]:
"""
_run方法,将payload转化为对llm发送的消息并发送
Args:
payload: 消息载荷
Returns:
ForConsciousnessNode对象,对意识节点发送的消息
ForUser对象,对用户发送到消息
"""
platform = payload.platform
user_name = payload.user_name
message = payload.message
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote()
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
deps = SupervisoryNodeDeps(
platform=platform,
user_name=user_name,
time=time_str,
available_templates=available_templates_str
)
self.logger.debug("SupervisoryNode 开始生成 (启用原生 Pydantic-AI 重试)")
prompt_message = message
if isinstance(payload, TerminationMessage):
prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}"
self.agent.retries = 3
result = await self.agent.run(prompt_message,
deps=deps)
return result.output
except Exception as e:
self.logger.exception(f"SupervisoryNode 模型生成或解析最终失败: {str(e)}")
return ForUser(context="系统当前负载过高或遇到复杂内部错误,请稍后再试。")
@@ -0,0 +1,40 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import Field
from pretor.utils.agent_model import ResponseModel, DepsModel
from pydantic import BaseModel
class SupervisoryNodeResponse(ResponseModel):
pass
class ForUser(SupervisoryNodeResponse):
context: str = Field(..., description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。")
class ForConsciousnessNode(SupervisoryNodeResponse):
workflow_template: str | None = Field(default=None, description="选择的工作流模板的名称,用于处理复杂任务。若无需模板则为 None。")
reasoning: str = Field(..., description="选择将任务移交意识节点并选用该模板的简短原因。")
class TerminationMessage(BaseModel):
platform: str
user_name: str
message: str
class SupervisoryNodeDeps(DepsModel):
platform: str
user_name: str
time: str
retry_count: int = 0
error_history: str = ""
available_templates: str = "默认工作流 (default_workflow)"
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+85
View File
@@ -0,0 +1,85 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import List, Optional, Union, Literal, Dict, Any
from pydantic import BaseModel, Field, model_validator
from pretor.utils.logger import get_logger
logger = get_logger('workflow')
NodeType = Literal[
"consciousness_node", "control_node", "supervisory_node", "skill_individual"
]
class EventInfo(BaseModel):
platform: str
user_name: str
class LogicGate(BaseModel):
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
if_pass: Literal["continue", "exit"] = Field(default="continue", description="成功后的动作")
class WorkStep(BaseModel):
step: int = Field(..., gt=0, description="步骤序号,严格自增")
name: str = Field(..., description="步骤名称")
node: NodeType = Field(..., description="负责执行的节点类型")
action: str = Field(..., description="执行的原子动作")
desc: str = Field(..., description="动作细节的自然语言描述,包含人工规范指导")
inputs: Optional[Union[str, List[str]]] = Field(default=None, description="前置依赖输出")
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
agent_id: Optional[str] = Field(default=None, description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替")
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
status: Literal["waiting", "running", "completed", "failed"] = Field(
default="waiting",
description="执行状态 (LLM建议保留默认值)"
)
class WorkflowStatus(BaseModel):
step: int = Field(default=1, gt=0, description="当前运行到的工作流步数")
status: Literal["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"] = Field(
default="waiting_llm_working",
description="当前系统调度状态"
)
class PretorWorkflow(BaseModel):
title: str = Field(..., description="工作流的标题")
work_link: List[WorkStep] = Field(..., description="工作链逻辑定义")
# ---------------- 以下为系统级管控字段,LLM 无需关心 ---------------- #
trace_id: str | None = Field(description="系统自动生成的追溯ID")
version: str = Field(default="v1.0", description="系统协议版本号")
command: Optional[str] = Field(default=None, description="触发此工作流的原始命令")
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
event_info: EventInfo | None = Field(default=None)
context_memory: Dict[str, Any] = Field(default_factory=dict)
@model_validator(mode='after')
def validate_workflow_integrity(self) -> 'PretorWorkflow':
steps = [s.step for s in self.work_link]
expected = list(range(1, len(steps) + 1))
if steps != expected:
raise ValueError(f"工作链步数不连续!期望 {expected},实际 {steps}")
max_step = len(steps)
for s in self.work_link:
if s.logic_gate and "jump_to_step_" in s.logic_gate.if_fail:
try:
target = int(s.logic_gate.if_fail.split("_")[-1])
if target > max_step or target < 1:
raise ValueError(f"Step {s.step} 的跳转目标 Step {target} 越界了!")
except ValueError as e:
if "越界" in str(e):
raise e
raise ValueError(f"LogicGate 格式错误: {s.logic_gate.if_fail}")
return self
+23
View File
@@ -0,0 +1,23 @@
# workflow文档
---
- workflow(工作流)是作为pretor中运行任务的基本单位,workflow_manager管理整个workflow模块,包括生成workflow_template(工作流模板),生成workflow对象,和保存整个workflow_template表。
- workflow_template是一个工作流模板,旨在由专业人士教导LLM如何编写工作流并进行任务,每个workflow_template都应该保存在 **pretor/workflow_pugin/** 文件夹下,保存格式为~_workflow_template.jsonjson格式为:
```json
{
"name": "",
"desc": "",
"work_link": [
{
"step": "",
"node": "",
"action": "",
"desc": "",
"input": [],
"output": [],
"logic_gate": {}
}
]
}
```
- workflow_template将由监管节点挑选交给意识节点,意识节点按照参考模板生成标准的workflow对象,转交给pipeline开始执行任务链。
+364
View File
@@ -0,0 +1,364 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.ray_hook import ray_actor_hook
import ray
import asyncio
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo
from typing import Optional, Dict, Union, Any, List
from pretor.utils.error import WorkflowError, WorkflowExit
from pretor.api.platform.event import PretorEvent
from pretor.core.individual.control_node.template import ForWorkflowInput as ControlForWorkflowInput, \
ForWorkflow as ControlForWorkflow
from pretor.core.individual.consciousness_node.template import (
ForWorkflowInput as ConsciousnessForWorkflowInput,
ForSupervisoryInput,
ForSupervisoryNode,
ForWorkflow as ConsciousnessForWorkflow,
ForWorkflowEngineInput,
ForWorkflowEngine
)
from pretor.core.individual.supervisory_node.template import TerminationMessage
import pathlib
def get_workflow_template(workflow_name: str) -> str:
workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json")
with open(workflow_template, "r", encoding="utf-8") as workflow_template_file:
workflow_template = workflow_template_file.read()
return workflow_template
class WorkflowEngine:
def __init__(self,
workflow: PretorWorkflow,
consciousness_node=None,
control_node=None,
supervisory_node=None):
from pretor.utils.logger import get_logger
self.logger = get_logger('workflow_runner')
self.workflow: PretorWorkflow = workflow
"""工作流:当前WorkflowEngine待执行的workflow"""
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
"""步骤表:将当前workflow的步骤序号和步骤内容存放"""
self.consciousness_node = consciousness_node
"""意识节点"""
self.control_node = control_node
"""控制节点"""
self.supervisory_node = supervisory_node
"""监督节点"""
self._gsm = ray_actor_hook("global_state_machine").global_state_machine
async def _push_sse(self, msg: str) -> None:
try:
await self._gsm.put_pending.remote(self.workflow.trace_id, msg)
except Exception:
pass
def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any:
"""
准备输入的方法
Args:
inputs: 待输入的名称
Returns:
"""
match inputs:
case None:
return None
case str(name):
return self.workflow.context_memory.get(name)
case list(names):
return {k: self.workflow.context_memory.get(k) for k in names}
async def run(self):
"""
run方法
处理并执行workflow的方法
"""
self.logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]")
await self._push_sse(f"[工作流启动] {self.workflow.title}")
max_step = len(self.workflow.work_link)
while 1 <= self.workflow.status.step <= max_step:
current_step_id = self.workflow.status.step
current_step = self._steps_by_id.get(current_step_id)
if not current_step:
self.logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。")
self.workflow.status.status = "failed"
await self._push_sse(f"[工作流失败] 找不到步骤 {current_step_id}")
break
self.logger.info(f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}")
current_step.status = "running"
await self._push_sse(f"[Step {current_step_id}] {current_step.name}: {current_step.desc}")
try:
step_input_data = self._prepare_inputs(current_step.inputs)
step_result, is_success = await self._dispatch_to_node(current_step, step_input_data)
if is_success:
if current_step.outputs:
self.workflow.context_memory[current_step.outputs] = step_result
self.logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'")
current_step.status = "completed"
await self._push_sse(f"[Step {current_step_id} 完成] {current_step.name}")
else:
self.logger.warning(f"Step {current_step_id} 执行遇到业务失败/驳回。")
current_step.status = "failed"
await self._push_sse(f"[Step {current_step_id} 失败] {current_step.name}")
self._handle_logic_gate(current_step, is_success)
except WorkflowExit:
self.logger.info("命中 if_pass='exit',工作流被主动要求结束。")
await self._push_sse("[工作流结束] 主动退出")
break
except WorkflowError as e:
self.logger.error(f"{e},终止工作流。")
self.workflow.status.status = "failed"
await self._push_sse(f"[工作流失败] {e}")
break
except Exception as e:
self.logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True)
current_step.status = "failed"
self.workflow.status.status = "failed"
await self._push_sse(f"[工作流异常] {e}")
break
self.logger.info(f"✅ 工作流 {self.workflow.title} 执行步骤结束。")
self.workflow.output = self.workflow.context_memory
await self._push_sse(f"[工作流完成] {self.workflow.title}")
await self._report_results()
async def _report_results(self):
"""
结果汇报函数
在工作流结束后执行
Returns:
"""
if self.workflow.status.status == "failed":
self.logger.warning("工作流执行失败,跳过正常汇报流程。")
return
try:
self.logger.info("开始生成工作流结束技术报告...")
report = ""
if self.consciousness_node:
supervisory_input = ForSupervisoryInput(
workflow=self.workflow,
original_command=self.workflow.command or "未知命令"
)
report_obj = await self.consciousness_node.working.remote(supervisory_input)
if isinstance(report_obj, ForSupervisoryNode):
report = report_obj.output
elif isinstance(report_obj, str):
report = report_obj
self.logger.debug(f"生成的报告摘要: {report[:100]}...")
else:
self.logger.warning("未提供 consciousness_node 句柄,跳过报告生成。")
if self.supervisory_node:
term_msg = TerminationMessage(
platform=self.workflow.event_info.platform,
user_name=self.workflow.event_info.user_name,
message=f"工作流执行完毕。系统报告:{report}"
)
user_response = await self.supervisory_node.working.remote(term_msg)
self.workflow.context_memory["_final_user_response"] = user_response
self.logger.info(f"Supervisory 最终回复:{user_response}")
else:
self.logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
except Exception:
self.logger.exception("生成工作流执行汇报时发生错误")
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
"""
分流器
调用当前step的执行对象
Args:
step: WorkStep对象,当前需要执行的step
input_data: 输入数据
Returns:
返回llm的输出和一个bool类型的判断
"""
self.logger.debug(f"正在向 {step.node} 节点发送动作 {step.action}...")
try:
if step.node == "control_node":
if not self.control_node:
raise WorkflowError("未提供 control_node 句柄!")
payload = ControlForWorkflowInput(workflow_step=step)
# 可选:如果 input_data 需要合并,可以扩展 ControlForWorkflowInput 或将其放在 context_memory
result_obj = await self.control_node.working.remote(payload)
if isinstance(result_obj, ControlForWorkflow):
return result_obj.output, True
return result_obj, True
elif step.node == "consciousness_node":
if not self.consciousness_node:
raise WorkflowError("未提供 consciousness_node 句柄!")
original_cmd = self.workflow.command or ""
payload = ConsciousnessForWorkflowInput(
workflow_step=step,
original_command=original_cmd
)
result_obj = await self.consciousness_node.working.remote(payload)
if isinstance(result_obj, ConsciousnessForWorkflow):
return result_obj.output, True
return result_obj, True
elif step.node == "skill_individual":
self.logger.info(f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}")
try:
from pretor.utils.ray_hook import ray_actor_hook
worker_cluster = ray_actor_hook("worker_cluster").worker_cluster
task_id = f"{self.workflow.trace_id}_step_{step.step}"
agent_id = step.agent_id or f"default_{step.node}"
task_event = {
"action": step.action,
"description": step.desc,
"input_data": input_data,
"context_memory": self.workflow.context_memory
}
result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event)
if result_response.get("success"):
return result_response.get("data"), True
else:
self.logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}")
return result_response.get("error"), False
except Exception as e:
self.logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}")
raise WorkflowError(f"WorkerCluster 调度异常: {e}")
else:
raise WorkflowError(f"未知的节点类型:{step.node}")
except Exception:
self.logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
return None, False
def _handle_logic_gate(self, step: WorkStep, is_success: bool):
"""
状态机,检测任务执行情况
Args:
step: WorkStep对象,当前执行的step
is_success: bool类型,当前步骤是否成功
"""
gate = step.logic_gate
if is_success:
if gate and gate.if_pass == "exit":
raise WorkflowExit()
self.workflow.status.step += 1
else:
if not gate or not gate.if_fail:
raise WorkflowError(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案")
match gate.if_fail.split("_"):
case ["jump", "to", "step", target] if target.isdigit():
target_step = int(target)
self.logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}")
self.workflow.status.step = target_step
case _:
raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}")
@ray.remote
class WorkflowRunningEngine:
def __init__(self, consciousness_node=None, control_node=None, supervisory_node=None):
from pretor.utils.logger import get_logger
self.logger = get_logger('workflow_runner')
self.runner_engine = {}
self.workflow_queue: asyncio.Queue[PretorEvent] = None
self.consciousness_node = consciousness_node
self.control_node = control_node
self.supervisory_node = supervisory_node
self.global_state_machine = None
async def run(self):
# Move actor hook to async start so we don't race during __init__ across cluster
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
self.workflow_queue = asyncio.Queue()
self.runner_engine = {
f"runner_{i}": asyncio.create_task(self.runner(i))
for i in range(10)
}
async def put_event(self, event: PretorEvent) -> None:
await self.workflow_queue.put(event)
async def runner(self, i: int) -> None:
"""
runner方法,从self.workflow_queue中不断取出任务并执行
Args:
i: runner序列号
"""
while True:
try:
event = await self.workflow_queue.get()
self.logger.info(f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。")
if not self.consciousness_node:
raise WorkflowError("未配置 consciousness_node,无法生成工作流")
workflow_template_name = event.context.get("workflow_template", "")
workflow_template = get_workflow_template(workflow_template_name) if workflow_template_name else None
available_skills = None
if self.global_state_machine:
try:
all_individuals = await self.global_state_machine.list_individuals.remote()
available_skills = []
for agent_id, config in all_individuals.items():
if config.get("agent_type") == "skill_individual" or config.get("type") == "skill_individual":
available_skills.append({
"agent_id": agent_id,
"name": config.get("agent_name", "Unknown"),
"description": config.get("description", "")
})
except Exception as e:
self.logger.warning(f"获取Skill Individual列表失败: {e}")
payload = ForWorkflowEngineInput(
original_command=event.message,
workflow_template=workflow_template,
available_skills=available_skills
)
result_obj = await self.consciousness_node.working.remote(payload)
if isinstance(result_obj, ForWorkflowEngine):
workflow = result_obj.workflow
workflow.trace_id = event.trace_id
workflow.command = event.message
workflow.event_info = EventInfo(platform=event.platform,
user_name=event.user_name,)
self.logger.info(
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}")
await self.global_state_machine.update_workflow.remote(event.trace_id, workflow)
workflow_engine = WorkflowEngine(workflow,
self.consciousness_node,
self.control_node,
self.supervisory_node)
await workflow_engine.run()
else:
self.logger.error(f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}")
except asyncio.CancelledError:
self.logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。")
raise
except Exception as e:
self.logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True)
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,39 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel, model_validator
from typing import Dict,List
class WorkflowTemplateStep(BaseModel):
step: int
node: str
action: str
desc: str
input: List[str]
output: List[str]
logic_gate: Dict[str, str]
class WorkflowTemplate(BaseModel):
name: str
desc: str
work_link: list[WorkflowTemplateStep]
@model_validator(mode='after')
def validate_steps(self) -> 'WorkflowTemplate':
steps = [s.step for s in self.work_link]
if len(steps) != len(set(steps)):
raise ValueError("Step numbers in work_link must be unique")
return self
@@ -0,0 +1,26 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pathlib import Path
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
class WorkflowTemplateGenerator:
@staticmethod
def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate:
output_dir = Path("pretor") / "workflow_template"
if not output_dir.exists():
output_dir.mkdir(parents=True)
output_file = output_dir / f"{workflow_template.name}_workflow_template.json"
with output_file.open("w", encoding="utf-8") as f:
f.write(workflow_template.model_dump_json(indent=4))
@@ -0,0 +1,56 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
from pathlib import Path
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.utils.logger import get_logger
logger = get_logger('workflow_template_manager')
class WorkflowManager:
def __init__(self):
self.workflow_template_generator = WorkflowTemplateGenerator()
self.workflow_templates_registry = {}
self.template_path = Path("pretor/workflow_template")
self._load_workflow_template()
def _load_workflow_template(self) -> None:
for workflow_template_file in self.template_path.glob("*_workflow_template.json"):
with workflow_template_file.open("r",encoding="utf-8") as f:
try:
workflow_template = json.load(f)
self.workflow_templates_registry[workflow_template.get("name")] = workflow_template.get("desc")
except json.decoder.JSONDecodeError:
logger.warning(f"{workflow_template_file}不是json文件或格式错误")
except KeyError:
logger.warning(f"{workflow_template_file}不符合workflow_template格式")
def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None:
try:
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
except Exception:
logger.exception("Failed to generate workflow template")
def add_workflow_template(self, template_name: str, workflow_template: WorkflowTemplate) -> None:
self.generate_workflow_template(workflow_template)
def get_all_workflow_templates(self) -> dict:
return self.workflow_templates_registry
def delete_workflow_template(self, template_name: str) -> None:
if template_name in self.workflow_templates_registry:
del self.workflow_templates_registry[template_name]
View File
+13
View File
@@ -0,0 +1,13 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
@@ -0,0 +1,2 @@
from .approval import ApprovalToolData, approval
__all__ = ["ApprovalToolData", "approval"]
@@ -0,0 +1,39 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.plugin.tool_plugin.base_tool import BaseToolData
from pretor.utils.ray_hook import ray_actor_hook
from typing import List, Literal, Dict
class ApprovalToolData(BaseToolData):
is_system: bool = True
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [
"control_node", "consciousness_node"]
config_args: Dict[str, str] = {}
async def approval(message: str, trace_id: str) -> str:
"""
当任务存在某些高风险操作或者计划需要让用户审批,发送请求给用户等待用户审批
Args:
message: 发送给用户的请求
trace_id:
Returns:
用户的审批结果
"""
actor_list = ray_actor_hook("global_state_machine")
await actor_list.global_state_machine.put_pending.remote(trace_id, message)
reply = await actor_list.global_state_machine.get_received.remote(trace_id)
return reply
@@ -0,0 +1,2 @@
{
}
+23
View File
@@ -0,0 +1,23 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
from typing import List, Literal, Dict
from pydantic import ConfigDict
class BaseToolData(BaseModel):
model_config = ConfigDict(extra="allow")
is_system: bool
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = []
config_args: Dict[str, str] = {}
@@ -0,0 +1,3 @@
from .file_reader import FileReaderData, file_reader
__all__ = ["FileReaderData", "file_reader"]
@@ -0,0 +1,43 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import RunContext
from pretor.plugin.tool_plugin.base_tool import BaseToolData
import os
class FileReaderData(BaseToolData):
is_system: bool = True
name: str = "file_reader"
description: str = "读取本地文件的内容"
def file_reader(ctx: RunContext, filepath: str) -> str:
"""读取本地文件内容的工具。
Args:
filepath: 目标文件的绝对路径或相对路径。
Returns:
如果文件存在并可读,返回文件内容;否则返回错误信息。
"""
if not os.path.exists(filepath):
return f"Error: 文件 {filepath} 不存在。"
if not os.path.isfile(filepath):
return f"Error: {filepath} 不是一个文件。"
try:
with open(filepath, 'r', encoding='utf-8') as f:
content = f.read()
return content
except Exception as e:
return f"Error: 读取文件失败,原因:{str(e)}"
+14
View File
@@ -0,0 +1,14 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+104
View File
@@ -0,0 +1,104 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import jwt
import os
from datetime import datetime, timedelta, timezone
from typing import Optional
from fastapi import HTTPException, status, Request
from pydantic import BaseModel, ValidationError
from pretor.core.database.table.user import User
from pwdlib import PasswordHash
class TokenData(BaseModel):
user_id: str
username: Optional[str] = None
exp: Optional[int] = None
SECRET_KEY = os.getenv("SECRET_KEY")
ALGORITHM = "HS256"
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24
password_hasher = PasswordHash.recommended()
class Accessor:
@staticmethod
def _decode_token(token: str) -> TokenData:
try:
payload = jwt.decode(
token,
SECRET_KEY,
algorithms=[ALGORITHM]
)
return TokenData(**payload)
except jwt.ExpiredSignatureError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 已过期",
)
except (jwt.InvalidTokenError, ValidationError):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="无效的认证凭证",
)
@staticmethod
def _create_access_token(data: dict) -> str:
to_encode = data.copy()
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
to_encode.update({"exp": int(expire.timestamp())})
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
@staticmethod
def verify_password(plain_password: str, hashed_password: str) -> bool:
return password_hasher.verify(plain_password, hashed_password)
@staticmethod
def get_current_user(request: Request) -> TokenData:
auth_header = request.headers.get("Authorization")
if not auth_header or not auth_header.startswith("Bearer "):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未提供认证头部",
)
token = auth_header.split(" ")[1]
return Accessor._decode_token(token)
@staticmethod
def login_hashed_password(user: User, password: str) -> str:
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在",
)
if not Accessor.verify_password(password, user.hashed_password):
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户名或密码错误",
)
token_payload = {
"user_id": str(user.user_id),
"username": user.user_name
}
return Accessor._create_access_token(data=token_payload)
@staticmethod
def hash_password(password: str) -> str:
if not password:
raise ValueError("密码不能为空")
if len(password) < 6:
raise ValueError("密码长度不能小于 6 位")
return password_hasher.hash(password)
+25
View File
@@ -0,0 +1,25 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic import BaseModel
class ResponseModel(BaseModel):
pass
class DepsModel(BaseModel):
pass
class InputModel(BaseModel):
pass
+25
View File
@@ -0,0 +1,25 @@
from rich.console import Console
from rich.text import Text
import yaml
def print_banner() -> None:
with open("config/config.yml","r") as config:
config = yaml.load(config, Loader=yaml.FullLoader)
version = config.get("version", "unknown")
pretor_banner = """
██████╗ ██████╗ ███████╗████████╗ ██████╗ ██████╗
██╔══██╗██╔══██╗██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗
██████╔╝██████╔╝█████╗ ██║ ██║ ██║██████╔╝
██╔═══╝ ██╔══██╗██╔══╝ ██║ ██║ ██║██╔══██╗
██║ ██║ ██║███████╗ ██║ ╚██████╔╝██║ ██║
╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝
"""
console = Console()
banner_colored = Text(pretor_banner, style="gold3 bold")
console.print(banner_colored)
console.print("=" * 40, style="dim") # dim=灰色,低调
console.print("🚀 Multi-Agent Orchestration Platform", style="blue")
console.print(f"📦 Version: {version}", style="green")
console.print("👤 Author: zhaoxi826", style="yellow")
console.print("📜 License: Apache 2.0", style="magenta")
console.print("🐙 github: https://github.com/zhaoxi826/pretor", style="yellow")
console.print("=" * 40, style="dim")
+53
View File
@@ -0,0 +1,53 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Annotated
from fastapi import Depends, HTTPException
from pretor.utils.access import Accessor, TokenData
from pretor.core.database.table.user import UserAuthority
from pretor.utils.ray_hook import ray_actor_hook
async def get_authority(user_id: str) -> UserAuthority:
from pretor.utils.error import UserNotExistError
postgres_database = ray_actor_hook("postgres_database").postgres_database
try:
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
return user_authority
except UserNotExistError:
raise HTTPException(
status_code=401,
detail="用户不存在或已被删除,请重新登录"
)
except Exception as e:
# Check if it's a RayTaskError wrapping UserNotExistError
if "UserNotExistError" in str(e):
raise HTTPException(
status_code=401,
detail="用户不存在或已被删除,请重新登录"
)
raise
class RoleChecker:
def __init__(self, **kwargs):
self.allowed_roles = kwargs.get("allowed_roles", )
async def __call__(self,
token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]):
user_authority = await get_authority(token_data.user_id)
if user_authority < self.allowed_roles:
raise HTTPException(
status_code=403,
detail={"message": f"User {token_data.user_id} does not have allowed roles"},
)
return token_data
+50
View File
@@ -0,0 +1,50 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class RetryableError(Exception):
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)"""
pass
class NonRetryableError(Exception):
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
pass
class DemandError(NonRetryableError):
pass
class ModelNotExistError(Exception):
pass
class UserError(Exception):
pass
class UserNotExistError(UserError):
pass
class UserPasswordError(UserError):
pass
class ProviderError(Exception):
pass
class ProviderNotExistError(ProviderError):
pass
class WorkflowError(Exception):
pass
class WorkflowExit(WorkflowError):
pass
+80
View File
@@ -0,0 +1,80 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib.util
import os
import sys
from typing import Callable, Dict, List
import pathlib
from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger
logger = get_logger('get_tool')
_tool_cache: Dict[str, Callable] = {}
def _get_tool_func(tool_name: str) -> Callable | None:
func = _tool_cache.get(tool_name, None)
if func:
return func
app_root = "/app"
tool_plugin_dir = os.path.join(app_root, "pretor", "plugin", "tool_plugin", tool_name)
if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir):
logger.error(f"Tool directory not found: {tool_plugin_dir}")
return None
init_file = os.path.join(tool_plugin_dir, "__init__.py")
if not os.path.exists(init_file):
logger.error(f"Tool init file not found: {init_file}")
return None
try:
module_name = f"pretor.plugin.tool_plugin.{tool_name}"
spec = importlib.util.spec_from_file_location(module_name, init_file)
if spec is None or spec.loader is None:
logger.error(f"Failed to create spec for {module_name}")
return None
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
func = getattr(module, tool_name, None)
if not callable(func):
logger.error(f"Tool function '{tool_name}' not found or not callable in {module_name}")
return None
_tool_cache[tool_name] = func
return func
except Exception as e:
logger.error(f"Failed to load module {module_name}: {e}")
return None
def del_tool_cache(tool_name: str) -> None:
if tool_name in _tool_cache:
del _tool_cache[tool_name]
def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]:
if not tool_names:
return []
tool_list = []
for tool_name in tool_names:
tool_func = _get_tool_func(tool_name)
if tool_func:
tool_list.append(tool_func)
return tool_list
+44
View File
@@ -0,0 +1,44 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from loguru import logger
from rich.logging import RichHandler
from loguru._logger import Logger
def setup_logger() -> Logger:
logger.remove()
def format_record(record):
# Format string for rich handler
actor = record["extra"].get("actor_name", "System")
trace_id = record["extra"].get("trace_id", "")
trace_str = f" | trace_id:({trace_id})" if trace_id else ""
return f"actor:({actor}){trace_str} : {record['message']}"
logger.configure(extra={"actor_name": "System", "trace_id": ""})
logger.add(
RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_level=False, show_path=False),
format=format_record,
level="DEBUG",
enqueue=True, # 异步记录
)
return logger
global_logger = setup_logger()
def get_logger(actor_name: str, trace_id: str = "") -> Logger:
return global_logger.bind(actor_name=actor_name, trace_id=trace_id)
+37
View File
@@ -0,0 +1,37 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Type, TypeVar
from pydantic import BaseModel
T = TypeVar("T", bound=Type[BaseModel])
def pickle(cls: T) -> T:
"""
类装饰器pickle
通过装饰继承了BaseModel的类,用pydantic的高效序列化替代python原生__reduce__魔术方法,实现ray在通讯时的高效序列化
Args:
cls: 继承了BaseModel类的类,需要被装饰的对象
Returns:
返回被重写了__reduce__魔术方法的cls类
"""
def __reduce__(self):
# 1. 序列化:触发 Pydantic-core (Rust) 的极速序列化
data = self.model_dump_json()
# 2. 反序列化:告诉 Pickle 重建时调用 cls.model_validate_json
return cls.model_validate_json, (data,)
cls.__reduce__ = __reduce__
return cls
+50
View File
@@ -0,0 +1,50 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
from functools import lru_cache
class ActorList:
def __init__(self):
super().__setattr__('dict', {})
def __setattr__(self, key, value):
self.dict[key] = value
def __getattr__(self, key):
if key in self.dict:
return self.dict[key]
raise AttributeError(f"ActorList 对象没有属性 '{key}'")
def __delattr__(self, key):
if key in self.dict:
del self.dict[key]
else:
raise AttributeError(f"ActorList对象没有属性 '{key}'")
@lru_cache(maxsize=128)
def _get_cached_actor_handle(actor_name: str):
"""缓存接口"""
return ray.get_actor(actor_name, namespace="pretor")
def clear_actor_cache():
"""清理接口"""
_get_cached_actor_handle.cache_clear()
def ray_actor_hook(*actor_names: str):
actor_list = ActorList()
for actor_name in actor_names:
handle = _get_cached_actor_handle(actor_name)
setattr(actor_list, actor_name, handle)
return actor_list
+31
View File
@@ -0,0 +1,31 @@
import asyncio
from functools import wraps
from pretor.utils.error import RetryableError
def retry_on_retryable_error(max_retries=3, base_delay=1):
def decorator(func):
if asyncio.iscoroutinefunction(func):
@wraps(func)
async def async_wrapper(*args, **kwargs):
for attempt in range(max_retries):
try:
return await func(*args, **kwargs)
except RetryableError:
if attempt == max_retries - 1:
raise
await asyncio.sleep(base_delay * (2 ** attempt))
return async_wrapper
else:
@wraps(func)
def sync_wrapper(*args, **kwargs):
import time
for attempt in range(max_retries):
try:
return func(*args, **kwargs)
except RetryableError:
if attempt == max_retries - 1:
raise
time.sleep(base_delay * (2 ** attempt))
return sync_wrapper
return decorator
+11
View File
@@ -0,0 +1,11 @@
from pretor.worker_individual.base_individual import BaseIndividual
from pretor.worker_individual.skill_individual import SkillIndividual
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
from pretor.worker_individual.special_individual import SpecialIndividual
__all__ = [
"BaseIndividual",
"SkillIndividual",
"OrdinaryIndividual",
"SpecialIndividual",
]
@@ -0,0 +1,76 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import Agent, RunContext
from pydantic import Field
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
from pretor.utils.ray_hook import ray_actor_hook
from pretor.utils.logger import get_logger
logger = get_logger('worker_individual')
class WorkerIndividualResponse(ResponseModel):
output: str = Field(..., description="Worker执行任务的输出结果")
class WorkerIndividualDeps(DepsModel):
task_event: dict
class WorkerIndividualInput(InputModel):
task_event: dict
class BaseIndividual:
"""
Worker Individual 的基类
"""
def __init__(self, agent_config: dict):
self.agent_config = agent_config
self.agent_id = agent_config.get("agent_id")
self.agent: Agent | None = None
async def _init_agent(self, agent_name: str, system_prompt: str):
from pretor.utils.get_tool import load_tools_from_list
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
tools_list = self.agent_config.get("tools", None)
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
self.agent = agent_factory.create_agent(
provider=provider,
model_id=model_id,
output_type=WorkerIndividualResponse,
system_prompt=system_prompt,
deps_type=WorkerIndividualDeps,
agent_name=agent_name,
tools=callables
)
@self.agent.system_prompt
async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]):
prompt = system_prompt + "\n\n"
prompt += (
f"=== 当前任务上下文 ===\n"
f"{ctx.deps.task_event}\n"
)
return prompt
async def run(self, task_event: dict) -> dict:
raise NotImplementedError("子类必须实现 run 方法")
+14
View File
@@ -0,0 +1,14 @@
worker_individual
---
**worker_individual**是pretor中的基础工作对象,主要分为三类:**skill_individual**,**ordinary_individual**和**special_individual**,庞大的**worker_individual**将负责具体的生产工作。
---
## worker_individual分类
### skill_individual(专家子个体)
**skill_individual(专家子个体)** 是拥有专业**skill**的agent,通常使用MoE(混合专家模型)或者大参数的专家模型来作为agent的模型。通过装配专业化的知识从而实现完成复杂任务。
### ordinary_individual(普通子个体)
**ordinary_individual(普通子个体)** 是普通的agent,通常使用小参数微调专家模型来作为agent的模型。通过专业化数据的微调,在一定程度上实现比大参数MoE模型在单一方面上的能力。
### special_individual(特殊子个体)
**special_individual(特殊子个体)** 是特殊的agent,这类agent一般不承担普通的生成任务,更多是实现一些特殊的任务,比如生成语音生成视频等。
@@ -0,0 +1,43 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
logger = get_logger('ordinary_individual')
class OrdinaryIndividual(BaseIndividual):
"""
普通子个体:普通的 agent。
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。")
await self._init_agent("ordinary_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
raise
@@ -0,0 +1,110 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
import os
import json
from pydantic_ai import Tool
import importlib.util
logger = get_logger('skill_individual')
class SkillIndividual(BaseIndividual):
"""
专家子个体:拥有专业 skill 的 agent。
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def _load_skill_tools(self):
"""动态加载已绑定的 skill 工具。"""
tools = []
bound_skill = self.agent_config.get("bound_skill", "")
# bound_skill can be string or dict {"skill_name": ["file1", "file2"]}
skill_mapper = {}
if isinstance(bound_skill, str) and bound_skill:
try:
skill_mapper = json.loads(bound_skill)
except json.JSONDecodeError:
pass
elif isinstance(bound_skill, dict):
skill_mapper = bound_skill
skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
for skill_name, _ in skill_mapper.items():
skill_path = os.path.join(skill_base_dir, skill_name)
metadata_path = os.path.join(skill_path, "metadata.json")
if not os.path.exists(metadata_path):
continue
try:
with open(metadata_path, 'r', encoding='utf-8') as f:
metadata = json.load(f)
except Exception as e:
logger.error(f"Failed to load metadata for skill {skill_name}: {e}")
continue
if "functions" in metadata:
for func_info in metadata["functions"]:
# Ensure path is absolute
script_path = func_info.get("file_path", "")
if not os.path.isabs(script_path):
script_path = os.path.join(skill_path, script_path)
if not os.path.exists(script_path):
logger.warning(f"Skill script not found: {script_path}")
continue
func_name = func_info.get("name")
try:
# Dynamically load the python module
spec = importlib.util.spec_from_file_location(func_name, script_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
func = getattr(module, func_name)
if callable(func):
# Convert to PydanticAI Tool
tool = Tool(func, name=func_name, description=func_info.get("docstring", ""))
tools.append(tool)
logger.info(f"Loaded skill tool: {func_name} from {skill_name}")
except Exception as e:
logger.error(f"Failed to load function {func_name} from {script_path}: {e}")
return tools
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt",
"你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。")
await self._init_agent("skill_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
tools = await self._load_skill_tools()
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps,
tools=tools if tools else None
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}")
raise
@@ -0,0 +1,43 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
from pretor.utils.logger import get_logger
logger = get_logger('special_individual')
class SpecialIndividual(BaseIndividual):
"""
特殊子个体:执行特殊任务的 agent,如生成语音、视频等。
"""
def __init__(self, agent_config: dict):
super().__init__(agent_config)
async def run(self, task_event: dict) -> dict:
if self.agent is None:
system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。")
await self._init_agent("special_individual", system_prompt)
deps = WorkerIndividualDeps(task_event=task_event)
self.agent.retries = 3
try:
result = await self.agent.run(
f"请执行以下任务:\n{task_event}",
deps=deps
)
return {"output": result.data.output}
except Exception as e:
logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
raise
+148
View File
@@ -0,0 +1,148 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
import time
import asyncio
from collections import OrderedDict
from ray.util.queue import Queue
from pretor.utils.ray_hook import ray_actor_hook
from pretor.worker_individual.base_individual import BaseIndividual
from pretor.worker_individual.skill_individual import SkillIndividual
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
from pretor.worker_individual.special_individual import SpecialIndividual
from pretor.utils.logger import get_logger
@ray.remote
class WorkerCluster:
"""
工作集群 Actor:管理和调度所有的 worker_individual
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
"""
def __init__(self, max_capacity: int = 200, num_runners: int = 10):
self.max_capacity = max_capacity
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
self.status = "running"
self.task_queue = None
self.results_futures = {}
self.runners = []
self.num_runners = num_runners
self.logger = get_logger('worker_cluster')
async def start(self):
if self.task_queue is None:
self.task_queue = Queue()
self.runners = [asyncio.create_task(self._runner(i)) for i in range(self.num_runners)]
self.logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。")
async def _recruit_worker(self, agent_id: str) -> BaseIndividual:
"""内部方法:招聘/唤醒一个具体的 Agent 对象"""
if agent_id in self._active_workers:
self._active_workers.move_to_end(agent_id)
return self._active_workers[agent_id]
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
agent_config = await global_state_machine.get_individual.remote( agent_id)
if not agent_config:
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
worker_type = agent_config.get("type", "ordinary")
if worker_type == "skill":
worker = SkillIndividual(agent_config)
elif worker_type == "special":
worker = SpecialIndividual(agent_config)
else:
worker = OrdinaryIndividual(agent_config)
self._active_workers[agent_id] = worker
if len(self._active_workers) > self.max_capacity:
evicted_id, _ = self._active_workers.popitem(last=False)
self.logger.info(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}")
return worker
async def _runner(self, runner_id: int):
while True:
try:
if self.task_queue is None:
await asyncio.sleep(0.1)
continue
task = await self.task_queue.get_async()
task_id = task.get("task_id")
agent_id = task.get("agent_id")
task_event = task.get("task_event")
self.logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}")
start_time = time.time()
try:
worker = await self._recruit_worker(agent_id)
result = await worker.run(task_event)
cost_time = time.time() - start_time
response = {
"success": True,
"agent_id": agent_id,
"data": result,
"metrics": {"cost_time_sec": round(cost_time, 2)}
}
except Exception as e:
self.logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}")
response = {
"success": False,
"agent_id": agent_id,
"error": str(e)
}
if task_id in self.results_futures:
future = self.results_futures[task_id]
if not future.done():
future.set_result(response)
except Exception as e:
self.logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}")
await asyncio.sleep(1)
async def submit_task(self, task_id: str, agent_id: str, task_event: dict):
if not self.runners:
await self.start()
future = asyncio.Future()
self.results_futures[task_id] = future
task = {
"task_id": task_id,
"agent_id": agent_id,
"task_event": task_event
}
await self.task_queue.put_async(task)
self.logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。")
try:
result = await future
return result
finally:
self.results_futures.pop(task_id, None)
def get_cluster_metrics(self):
return {
"active_worker_count": len(self._active_workers),
"max_capacity": self.max_capacity,
"cached_agent_ids": list(self._active_workers.keys()),
"queue_size": self.task_queue.size()
}
@@ -0,0 +1,82 @@
{
"name": "programme",
"desc": "一个示范型的编程工作流",
"work_link": [
{
"step": 1,
"node": "consciousness_node",
"action": "architect",
"desc": "【人类规范】分析用户需求,构建程序整体架构,定义需要拉起的子个体名称与数量。"
},
{
"step": 2,
"node": "control_node",
"action": "spawn_actors",
"desc": "【人类规范】根据架构要求,拉起对应的开发与测试工作组,并挂载 /workspace 目录。"
},
{
"step": 3,
"node": "composite_individual",
"action": "decompose",
"desc": "【人类规范】将整体架构拆解为可独立执行的原子任务包 (Task Packets)。",
"output": "task_packets"
},
{
"step": 4,
"node": "primary_individual",
"action": "execute_code",
"desc": "【人类规范】执行编码任务,必须确保所有代码写入指定的挂载目录。",
"input": "task_packets",
"output": "source_code"
},
{
"step": 5,
"node": "composite_individual",
"action": "audit",
"desc": "【人类规范】对产出的源码进行静态逻辑检查与 PEP8 代码规范审计。",
"input": "source_code",
"output": "audit_report"
},
{
"step": 6,
"node": "control_node",
"action": "resource_recycle",
"desc": "【安全规范】暂存当前编码子个体的状态,释放非必要显存,为测试环境腾出算力。",
"input": "audit_report"
},
{
"step": 7,
"node": "consciousness_node",
"action": "design_test",
"desc": "【人类规范】基于源码设计测试用例架构,覆盖边缘场景。",
"input": "source_code",
"output": "test_spec"
},
{
"step": 8,
"node": "primary_individual",
"action": "run_test",
"desc": "【人类规范】在独立的 Docker 沙箱中运行 test,并生成结构化的实验报告。",
"input": "test_spec",
"output": "test_report"
},
{
"step": 9,
"node": "consciousness_node",
"action": "analyze_report",
"desc": "【逻辑网关】研究测试报告。如果存在 Error 或 Fail,必须触发逻辑跳转,重写代码。",
"input": "test_report",
"logic_gate": {
"if_fail": "jump_to_step_4",
"if_pass": "continue"
}
},
{
"step": 10,
"node": "supervisory_node",
"action": "terminate_workflow",
"desc": "【系统规范】核对所有产出物,关闭工作流管道,向宿主机发送 .done 信号。",
"input": ["source_code", "test_report"]
}
]
}