chore: initial commit for Pretor v0.1.0-alpha
正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。 核心功能实现: 1. 建立基于 BaseIndividual 的动态插件加载机制。 2. 实现三类核心 worker_individual 子个体。 3. 集成 Ray 框架支持分布式集群调度。 4. 基于 PostgreSQL 的全量持久化存储方案。 5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,77 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.models.anthropic import AnthropicModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||
from pretor.core.global_state_machine.model_provider import Provider
|
||||
from pretor.utils.agent_model import ResponseModel, DepsModel
|
||||
from pretor.utils.error import ModelNotExistError
|
||||
|
||||
class AgentFactory:
|
||||
def __init__(self):
|
||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
|
||||
"claude": (AnthropicModel, AnthropicProvider),
|
||||
"deepseek": (OpenAIChatModel, OpenAIProvider),}
|
||||
|
||||
def create_agent(self,
|
||||
provider: Provider,
|
||||
model_id: str,
|
||||
output_type: ResponseModel,
|
||||
system_prompt: str,
|
||||
deps_type: DepsModel,
|
||||
agent_name: str,
|
||||
tools: list = None) -> Agent:
|
||||
"""
|
||||
create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象
|
||||
|
||||
Args:
|
||||
provider: Provider对象,从global_state_machine中获取
|
||||
model_id: 模型名
|
||||
output_type: 输出格式
|
||||
system_prompt: 系统提示词
|
||||
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
|
||||
agent_name: agent的名字
|
||||
tools: 工具列表
|
||||
|
||||
Returns:
|
||||
返回被实例化的pydantic-ai的Agent对象
|
||||
"""
|
||||
if model_id not in provider.provider_models:
|
||||
raise ModelNotExistError("模型不存在")
|
||||
if provider.provider_type not in self._models_mapping:
|
||||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
|
||||
match provider.provider_type:
|
||||
case "deepseek":
|
||||
agent = DeepSeekReasonerAgent(model=model,
|
||||
name=agent_name,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
retries=3,
|
||||
)
|
||||
case _:
|
||||
agent = Agent(model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools)
|
||||
return agent
|
||||
@@ -0,0 +1,150 @@
|
||||
import re
|
||||
import json
|
||||
from typing import Type, TypeVar, Any, Generic
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pydantic_ai.run import AgentRunResult
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
class AgentRunResultProxy:
|
||||
def __init__(self, original, parsed):
|
||||
self._original = original
|
||||
self._parsed = parsed
|
||||
def __getattr__(self, name):
|
||||
if name == 'data':
|
||||
return self._parsed
|
||||
if name == 'output':
|
||||
return self._parsed
|
||||
return getattr(self._original, name)
|
||||
|
||||
class DeepSeekReasonerAgent(Generic[T]):
|
||||
"""
|
||||
专为 DeepSeek-V4/R1 设计的适配器。
|
||||
将结构化输出降级为文本解析模式,并支持重试逻辑以确保系统兼容性。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
name,
|
||||
output_type: Any = str,
|
||||
system_prompt: str = "",
|
||||
deps_type: Type[Any] = None,
|
||||
tools: list = None,
|
||||
retries: int = 3,
|
||||
**kwargs
|
||||
):
|
||||
self.output_schema = output_type
|
||||
self.has_custom_output = output_type is not str and output_type is not None
|
||||
self.tools = tools or []
|
||||
self.retries = retries
|
||||
|
||||
format_instruction = ""
|
||||
if self.has_custom_output:
|
||||
try:
|
||||
from pydantic import TypeAdapter
|
||||
schema_dict = TypeAdapter(self.output_schema).json_schema()
|
||||
schema_str = json.dumps(schema_dict, ensure_ascii=False)
|
||||
format_instruction = (
|
||||
f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
|
||||
f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
|
||||
f"{schema_str}"
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
tool_instruction = ""
|
||||
if self.tools:
|
||||
tool_descs = []
|
||||
for t in self.tools:
|
||||
desc = getattr(t, '__name__', str(t))
|
||||
if hasattr(t, '__doc__') and t.__doc__:
|
||||
desc += f": {t.__doc__.strip()}"
|
||||
tool_descs.append(f"- {desc}")
|
||||
tool_instruction = (
|
||||
"\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
|
||||
"但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" +
|
||||
"\n".join(tool_descs)
|
||||
)
|
||||
|
||||
self.agent = Agent(
|
||||
model=model,
|
||||
name=name,
|
||||
output_type=str, # Force native agent to return str to disable function calling
|
||||
system_prompt=system_prompt + format_instruction + tool_instruction,
|
||||
deps_type=deps_type,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _parse_output(self, text: str) -> Any:
|
||||
if not self.has_custom_output:
|
||||
return text
|
||||
|
||||
match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
|
||||
json_str = match.group(1).strip() if match else text
|
||||
|
||||
if not json_str.startswith('{') and not json_str.startswith('['):
|
||||
start_obj = json_str.find('{')
|
||||
start_arr = json_str.find('[')
|
||||
start = -1
|
||||
end = -1
|
||||
if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
|
||||
start = start_obj
|
||||
end = json_str.rfind('}')
|
||||
elif start_arr != -1:
|
||||
start = start_arr
|
||||
end = json_str.rfind(']')
|
||||
|
||||
if start != -1 and end != -1 and end > start:
|
||||
json_str = json_str[start:end+1]
|
||||
|
||||
if not json_str:
|
||||
raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
|
||||
|
||||
try:
|
||||
from pydantic import TypeAdapter
|
||||
adapter = TypeAdapter(self.output_schema)
|
||||
return adapter.validate_json(json_str)
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}")
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(f"返回的不是合法的 JSON:{e}")
|
||||
|
||||
|
||||
def __getattr__(self, item):
|
||||
# Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent
|
||||
return getattr(self.agent, item)
|
||||
|
||||
async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
|
||||
# Custom retry loop
|
||||
current_history = message_history or []
|
||||
last_exception = None
|
||||
|
||||
for attempt in range(self.retries + 1):
|
||||
result = await self.agent.run(
|
||||
user_prompt,
|
||||
deps=deps,
|
||||
message_history=current_history,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
raw_text = result.data if hasattr(result, 'data') else getattr(result, 'output', str(result))
|
||||
|
||||
try:
|
||||
parsed_data = self._parse_output(raw_text)
|
||||
|
||||
# Proxy the result to inject the parsed data seamlessly
|
||||
return AgentRunResultProxy(result, parsed_data)
|
||||
|
||||
except ValueError as e:
|
||||
last_exception = e
|
||||
# Prepare retry prompt
|
||||
user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
|
||||
|
||||
# We need to maintain history manually so the model sees what it did wrong
|
||||
# Actually, pydantic-ai manages history inside the result. Let's use the all_messages from result
|
||||
if hasattr(result, 'all_messages'):
|
||||
current_history = result.all_messages()
|
||||
|
||||
raise ValueError(f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}")
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,185 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from typing import Union
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from pretor.utils.access import Accessor, TokenData
|
||||
from pretor.core.database.table.individual import AgentType
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, List, Dict
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
|
||||
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
||||
|
||||
class AgentRegister(BaseModel):
|
||||
provider_title: str
|
||||
model_id: str
|
||||
individual_name: str
|
||||
tools: Optional[List[str]] = None
|
||||
|
||||
class AgentLocalRegister(BaseModel):
|
||||
path: str
|
||||
individual_name: str
|
||||
tools: Optional[List[str]] = None
|
||||
|
||||
@agent_router.get("")
|
||||
async def get_system_nodes(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
configs = await postgres_database.get_all_system_node_configs.remote()
|
||||
return {"system_nodes": configs}
|
||||
|
||||
@agent_router.post("")
|
||||
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
|
||||
if isinstance(agent_register, AgentLocalRegister):
|
||||
pass
|
||||
|
||||
elif isinstance(agent_register, AgentRegister):
|
||||
try:
|
||||
# Persist configuration
|
||||
await postgres_database.upsert_system_node_config.remote(
|
||||
agent_register.individual_name,
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools
|
||||
)
|
||||
|
||||
# Load agent into state machine
|
||||
match agent_register.individual_name:
|
||||
case "supervisory_node":
|
||||
node = ray_actor_hook("supervisory_node").supervisory_node
|
||||
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
|
||||
case "consciousness_node":
|
||||
node = ray_actor_hook("consciousness_node").consciousness_node
|
||||
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
|
||||
case "control_node":
|
||||
node = ray_actor_hook("control_node").control_node
|
||||
await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools)
|
||||
case _:
|
||||
pass
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"加载节点失败: {str(e)}")
|
||||
return {"message": "创建成功"}
|
||||
|
||||
|
||||
class WorkerIndividualCreate(BaseModel):
|
||||
agent_name: str
|
||||
agent_type: AgentType
|
||||
description: str
|
||||
provider_title: str
|
||||
model_id: str
|
||||
system_prompt: str
|
||||
output_template: dict
|
||||
bound_skill: Dict[str, List[str]]
|
||||
workspace: List[str]
|
||||
tools: Optional[List[str]] = None
|
||||
|
||||
|
||||
class WorkerIndividualUpdate(BaseModel):
|
||||
agent_name: Optional[str] = None
|
||||
agent_type: Optional[AgentType] = None
|
||||
description: Optional[str] = None
|
||||
provider_title: Optional[str] = None
|
||||
model_id: Optional[str] = None
|
||||
system_prompt: Optional[str] = None
|
||||
output_template: Optional[dict] = None
|
||||
bound_skill: Optional[Dict[str, List[str]]] = None
|
||||
workspace: Optional[List[str]] = None
|
||||
tools: Optional[List[str]] = None
|
||||
|
||||
|
||||
@agent_router.post("/worker")
|
||||
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
data_dict = worker_data.model_dump()
|
||||
data_dict["owner_id"] = token_data.user_id
|
||||
worker = await postgres_database.add_worker_individual.remote( **data_dict)
|
||||
return {"message": "success", "agent_id": worker.agent_id}
|
||||
|
||||
|
||||
@agent_router.get("/worker")
|
||||
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id)
|
||||
return {"workers": workers}
|
||||
|
||||
|
||||
@agent_router.get("/worker/{agent_id}")
|
||||
async def get_worker_individual(agent_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
return worker
|
||||
|
||||
|
||||
@agent_router.put("/worker/{agent_id}")
|
||||
async def update_worker_individual(agent_id: str,
|
||||
worker_data: WorkerIndividualUpdate,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
|
||||
update_data = worker_data.model_dump(exclude_unset=True)
|
||||
updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data)
|
||||
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
try:
|
||||
await global_state_machine.remove_individual.remote(agent_id)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {"message": "success", "worker": updated_worker}
|
||||
|
||||
@agent_router.post("/worker/{agent_id}/reload")
|
||||
async def reload_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.remove_individual.remote(agent_id)
|
||||
|
||||
return {"message": "Worker will be reloaded on next use"}
|
||||
|
||||
|
||||
@agent_router.delete("/worker/{agent_id}")
|
||||
async def delete_worker_individual(agent_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
await postgres_database.delete_worker_individual.remote( agent_id=agent_id)
|
||||
return {"message": "success"}
|
||||
@@ -0,0 +1,88 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from pydantic import BaseModel
|
||||
from pretor.utils.access import Accessor, TokenData
|
||||
from fastapi.concurrency import run_in_threadpool
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
from pretor.utils.error import UserNotExistError
|
||||
|
||||
auth_router = APIRouter(prefix="/api/v1/auth", tags=["auth"])
|
||||
|
||||
class UserRegister(BaseModel):
|
||||
user_name: str
|
||||
password: str
|
||||
|
||||
@auth_router.post("/register")
|
||||
async def create_user(user_register: UserRegister):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
||||
user = await postgres_database.add_user.remote( user_register.user_name, hashed_password)
|
||||
return {"message": "success", "user_id": user.user_id}
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
user_name: str
|
||||
password: str
|
||||
|
||||
@auth_router.post("/login")
|
||||
async def login_user(user_login: UserLogin):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
user = await postgres_database.login_user.remote( user_login.user_name)
|
||||
if not user:
|
||||
raise UserNotExistError()
|
||||
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
||||
return {"message":"success", "token":token}
|
||||
|
||||
class ChangeAuthorityRequest(BaseModel):
|
||||
user_id: str
|
||||
new_authority: UserAuthority
|
||||
|
||||
@auth_router.put("/authority")
|
||||
async def change_authority(
|
||||
request: ChangeAuthorityRequest,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
|
||||
):
|
||||
"""
|
||||
Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR.
|
||||
"""
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority)
|
||||
return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority}
|
||||
|
||||
@auth_router.get("/list")
|
||||
async def get_user_list(
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
|
||||
):
|
||||
"""
|
||||
Get a list of all users. Only accessible by SUPER_ADMINISTRATOR.
|
||||
"""
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
users = await postgres_database.get_all_users.remote()
|
||||
return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]}
|
||||
|
||||
@auth_router.delete("/{user_id}")
|
||||
async def delete_user(
|
||||
user_id: str,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))
|
||||
):
|
||||
"""
|
||||
Delete a user. Only accessible by SUPER_ADMINISTRATOR.
|
||||
"""
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.delete_user_by_id.remote( user_id=user_id)
|
||||
return {"message": "success"}
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
cluster_router = APIRouter(prefix="/api/v1/cluster", tags=["cluster"])
|
||||
|
||||
# Monitor websocket API temporarily removed
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .frontend import client_router
|
||||
__all__ = ["client_router"]
|
||||
@@ -0,0 +1,36 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
from pydantic import BaseModel, Field, ConfigDict
|
||||
from ulid import ULID
|
||||
from typing import Any, Dict
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
import asyncio
|
||||
|
||||
class PretorEvent(BaseModel):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符")
|
||||
platform: str = Field(description="消息来源的平台")
|
||||
user_id: str = Field(description="用户id")
|
||||
user_name: str = Field(description="用户名")
|
||||
create_time: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc).isoformat()),
|
||||
description="事件创建时间")
|
||||
message: str = Field(description="用户发来的消息")
|
||||
attachment: Dict[str, str] | None = Field(default=None,description="附件")
|
||||
#--------------------------------------------------------------------------------------------------------------
|
||||
context: Dict[str, Any] = Field(default_factory=dict, description="事件上下文内容,可包含工作流模板等信息")
|
||||
workflow: PretorWorkflow | None = Field(default=None,description="工作流")
|
||||
pending_queue: asyncio.Queue[str] | None= Field(default=None,description="待处理队列")
|
||||
receive_queue: asyncio.Queue[str] | None = Field(default=None,description="待接收队列")
|
||||
@@ -0,0 +1,63 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
from pretor.utils.access import Accessor, TokenData
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
import os
|
||||
import shutil
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('frontend')
|
||||
client_router = APIRouter(prefix="/api/v1/adapter/client", tags=["client"])
|
||||
|
||||
class Message(BaseModel):
|
||||
message: str
|
||||
|
||||
@client_router.post("")
|
||||
async def create_message(message: Message,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
logger.info("收到消息,来源:客户端")
|
||||
logger.debug(f"消息内容:{message.message}")
|
||||
event = PretorEvent(platform="client",
|
||||
user_id=str(token_data.user_id),
|
||||
user_name=token_data.username,
|
||||
message=message.message)
|
||||
supervisory_node = ray_actor_hook("supervisory_node").supervisory_node
|
||||
message = await supervisory_node.working.remote(event)
|
||||
if message == "任务已创建":
|
||||
return {"message": event.trace_id}
|
||||
elif message == "未知相应类型":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="模型回复错误")
|
||||
else:
|
||||
return {"message": message}
|
||||
|
||||
@client_router.post("/upload")
|
||||
async def upload_file(file: UploadFile = File(...),
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
try:
|
||||
upload_dir = "uploads"
|
||||
os.makedirs(upload_dir, exist_ok=True)
|
||||
file_path = os.path.join(upload_dir, file.filename)
|
||||
with open(file_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
logger.info(f"用户 {token_data.username} 上传了文件: {file.filename}")
|
||||
return {"filename": file.filename, "message": f"File {file.filename} uploaded successfully"}
|
||||
except Exception as e:
|
||||
logger.error(f"文件上传失败: {e}")
|
||||
raise HTTPException(status_code=500, detail="文件上传失败")
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
from pretor.utils.access import TokenData, Accessor
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
from typing import Dict
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
|
||||
|
||||
class ProviderRegister(BaseModel):
|
||||
provider_type: Literal["openai", "claude", "deepseek"]
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
|
||||
@provider_router.post("")
|
||||
async def create_provider(provider_register: ProviderRegister,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None:
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type,
|
||||
provider_title=provider_register.provider_title,
|
||||
provider_url=provider_register.provider_url,
|
||||
provider_apikey=provider_register.provider_apikey,
|
||||
provider_owner=token_data.user_id)
|
||||
|
||||
|
||||
@provider_router.get("/list")
|
||||
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]:
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
|
||||
return {"provider_list": provider_list}
|
||||
|
||||
@provider_router.delete("/{provider_title}")
|
||||
async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict:
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.delete_provider.remote(provider_title=provider_title)
|
||||
return {"message": "success"}
|
||||
@@ -0,0 +1,89 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel
|
||||
import viceroy
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from pretor.utils.access import TokenData
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
|
||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||
|
||||
@resource_router.post("/workflow_template")
|
||||
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template)
|
||||
return {"message": "创建成功"}
|
||||
|
||||
@resource_router.get("/workflow_template")
|
||||
async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
templates = await global_state_machine.get_all_workflow_templates.remote()
|
||||
return {"templates": templates}
|
||||
|
||||
@resource_router.delete("/workflow_template/{template_name}")
|
||||
async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.delete_workflow_template.remote( template_name)
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
|
||||
class Skill(BaseModel):
|
||||
repo_url: str
|
||||
path: str | None
|
||||
|
||||
@resource_router.post("/skill")
|
||||
async def install_skill(skill: Skill,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# noinspection PyUnresolvedReferences
|
||||
import os
|
||||
skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
|
||||
os.makedirs(skill_output_dir, exist_ok=True)
|
||||
await viceroy.install_skill_async(url = skill.repo_url,
|
||||
path = skill.path,
|
||||
output = skill_output_dir)
|
||||
if skill.path:
|
||||
skill_name = skill.path.split("/")[-1]
|
||||
else:
|
||||
skill_name = skill.repo_url.split("/")[-1]
|
||||
await global_state_machine.add_skill.remote( skill_name)
|
||||
return {"message": "创建成功"}
|
||||
|
||||
@resource_router.get("/skill")
|
||||
async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
skills = await global_state_machine.get_skill_list.remote()
|
||||
return {"skills": skills}
|
||||
|
||||
@resource_router.delete("/skill/{skill_name}")
|
||||
async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# Note: this only removes it from the state machine manager.
|
||||
await global_state_machine.remove_skill.remote( skill_name)
|
||||
return {"message": "success"}
|
||||
|
||||
@resource_router.get("/tool")
|
||||
async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
tool_mapper = await global_state_machine.get_tool_mapper.remote()
|
||||
all_tool_names = set()
|
||||
for scope_tools in tool_mapper.values():
|
||||
all_tool_names.update(scope_tools.keys())
|
||||
return {"tools": list(all_tool_names)}
|
||||
@@ -0,0 +1,99 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Request, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
import asyncio
|
||||
|
||||
workflow_router = APIRouter(prefix="/api/v1/workflow", tags=["workflow"])
|
||||
|
||||
@workflow_router.get("/list")
|
||||
async def get_workflow_list():
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
events = await global_state_machine.list_events.remote()
|
||||
return events
|
||||
|
||||
|
||||
@workflow_router.get("/{trace_id}")
|
||||
async def get_workflow_detail(trace_id: str):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
event = await global_state_machine.get_event.remote(trace_id)
|
||||
if not event:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
|
||||
workflow = event.workflow
|
||||
if not workflow:
|
||||
return {
|
||||
"event_id": trace_id,
|
||||
"workflow_title": None,
|
||||
"status": "waiting",
|
||||
"user_name": event.user_name,
|
||||
"message": event.message,
|
||||
"create_time": event.create_time,
|
||||
"steps": [],
|
||||
}
|
||||
|
||||
steps = []
|
||||
for step in workflow.work_link:
|
||||
steps.append({
|
||||
"step": step.step,
|
||||
"name": step.name,
|
||||
"node": step.node,
|
||||
"action": step.action,
|
||||
"desc": step.desc,
|
||||
"status": step.status,
|
||||
"agent_id": step.agent_id,
|
||||
})
|
||||
return {
|
||||
"event_id": trace_id,
|
||||
"workflow_title": workflow.title,
|
||||
"status": workflow.status.status,
|
||||
"command": workflow.command,
|
||||
"current_step": workflow.status.step,
|
||||
"user_name": event.user_name,
|
||||
"message": event.message,
|
||||
"create_time": event.create_time,
|
||||
"steps": steps,
|
||||
}
|
||||
|
||||
@workflow_router.get("/sse/{trace_id}")
|
||||
async def get_workflow_sse(trace_id: str, request: Request):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
|
||||
async def event_generator():
|
||||
try:
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
|
||||
# You might also want to send the workflow state periodically or when updated
|
||||
# Here we just wait for pending messages and send them
|
||||
message = await global_state_machine.get_pending.remote(trace_id)
|
||||
# Ensure the message is formatted as SSE
|
||||
yield f"data: {message}\n\n"
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
return StreamingResponse(event_generator(), media_type="text/event-stream")
|
||||
|
||||
@workflow_router.post("/reply/{trace_id}")
|
||||
async def post_workflow_reply(trace_id: str, request: Request):
|
||||
data = await request.json()
|
||||
reply_msg = data.get("message", "")
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.put_received.remote(trace_id, reply_msg)
|
||||
return {"status": "ok"}
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,127 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
from fastapi import FastAPI, WebSocket, Request
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
|
||||
from pretor.api.platform.frontend import client_router
|
||||
from pretor.api.auth import auth_router
|
||||
from pretor.api.provider import provider_router
|
||||
from pretor.api.resource import resource_router
|
||||
from pretor.api.cluster import cluster_router
|
||||
from pretor.api.agent import agent_router
|
||||
from pretor.utils.error import (
|
||||
DemandError, ModelNotExistError, UserError,
|
||||
UserNotExistError, UserPasswordError, ProviderError,
|
||||
ProviderNotExistError, WorkflowError, WorkflowExit
|
||||
)
|
||||
|
||||
from ray import serve
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
app.include_router(client_router) # 客户端路径
|
||||
app.include_router(auth_router) # 用户路径
|
||||
app.include_router(provider_router) # 供应商路径
|
||||
app.include_router(resource_router) # 资源路径
|
||||
app.include_router(cluster_router) # 集群信息路径
|
||||
app.include_router(agent_router) # agent路径
|
||||
|
||||
@app.exception_handler(UserNotExistError)
|
||||
async def user_not_exist_handler(request: Request, exc: UserNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "用户不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(UserPasswordError)
|
||||
async def user_password_handler(request: Request, exc: UserPasswordError):
|
||||
return JSONResponse(status_code=401, content={"message": "密码错误"})
|
||||
|
||||
|
||||
@app.exception_handler(UserError)
|
||||
async def user_error_handler(request: Request, exc: UserError):
|
||||
return JSONResponse(status_code=400, content={"message": "用户相关错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderNotExistError)
|
||||
async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "服务提供商不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
return JSONResponse(status_code=400, content={"message": "服务提供商错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ModelNotExistError)
|
||||
async def model_not_exist_handler(request: Request, exc: ModelNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "模型不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(DemandError)
|
||||
async def demand_error_handler(request: Request, exc: DemandError):
|
||||
return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowExit)
|
||||
async def workflow_exit_handler(request: Request, exc: WorkflowExit):
|
||||
return JSONResponse(status_code=400, content={"message": "工作流已退出"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowError)
|
||||
async def workflow_error_handler(request: Request, exc: WorkflowError):
|
||||
return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
|
||||
|
||||
base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
||||
frontend_dir = os.path.join(base_dir, "frontend", "dist")
|
||||
|
||||
if os.path.exists(frontend_dir):
|
||||
app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets")
|
||||
|
||||
|
||||
@app.get("/favicon.svg", include_in_schema=False)
|
||||
async def serve_favicon():
|
||||
return FileResponse(os.path.join(frontend_dir, "favicon.svg"))
|
||||
|
||||
|
||||
@app.get("/icons.svg", include_in_schema=False)
|
||||
async def serve_icons():
|
||||
return FileResponse(os.path.join(frontend_dir, "icons.svg"))
|
||||
|
||||
|
||||
@app.get("/{full_path:path}", include_in_schema=False)
|
||||
async def serve_frontend(full_path: str):
|
||||
# 【重要安全修复】避免拦截不存在的 API 路由。如果是调用了不存在的 /api/ 接口,直接返回 404,不返回前端页面
|
||||
if full_path.startswith("api/"):
|
||||
return JSONResponse(status_code=404, content={"detail": "API endpoint not found"})
|
||||
|
||||
index_path = os.path.join(frontend_dir, "index.html")
|
||||
if os.path.exists(index_path):
|
||||
return FileResponse(index_path)
|
||||
return JSONResponse(status_code=404, content={"detail": "Frontend build not found"})
|
||||
else:
|
||||
import logging
|
||||
|
||||
logging.getLogger("pretor").warning(f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount.")
|
||||
|
||||
|
||||
@serve.deployment
|
||||
@serve.ingress(app)
|
||||
class PretorGateway:
|
||||
gateway: Dict[str, WebSocket]
|
||||
|
||||
def __init__(self):
|
||||
self.gateway = {}
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from pydantic import ValidationError
|
||||
from pretor.utils.error import UserNotExistError
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('database_exception')
|
||||
def database_exception(func):
|
||||
async def wrapper(*args, **kwargs):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except ValidationError as e:
|
||||
logger.error(f"对象校验失败:{e}")
|
||||
raise e
|
||||
except IntegrityError as e:
|
||||
logger.error(f"数据库完整性错误 (如重复记录): {e}")
|
||||
raise e
|
||||
except OperationalError as e:
|
||||
logger.error(f"数据库连接异常: {e}")
|
||||
raise e
|
||||
except UserNotExistError as e:
|
||||
logger.error(f"更改密码失败,用户不存在:{e}")
|
||||
except Exception as e:
|
||||
logger.exception(f"未预期的数据库错误: {e}")
|
||||
raise e
|
||||
return wrapper
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.database.table.individual import WorkerIndividual
|
||||
from sqlmodel import select
|
||||
from typing import List, Optional
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
|
||||
from ulid import ULID
|
||||
|
||||
class IndividualDatabase:
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@database_exception
|
||||
async def add_worker_individual(self, **kwargs) -> WorkerIndividual:
|
||||
async with self.async_session_maker() as session:
|
||||
agent_id = str(ULID())
|
||||
individual = WorkerIndividual(agent_id=agent_id, **kwargs)
|
||||
session.add(individual)
|
||||
await session.commit()
|
||||
await session.refresh(individual)
|
||||
return individual
|
||||
|
||||
@database_exception
|
||||
async def get_worker_individual(self, agent_id: str) -> Optional[WorkerIndividual]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
@database_exception
|
||||
async def get_worker_individual_list(self, owner_id: str) -> List[WorkerIndividual]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkerIndividual).where(WorkerIndividual.owner_id == owner_id)
|
||||
results = await session.execute(statement)
|
||||
return list(results.scalars().all())
|
||||
|
||||
@database_exception
|
||||
async def update_worker_individual(self, agent_id: str, **kwargs) -> Optional[WorkerIndividual]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
|
||||
results = await session.execute(statement)
|
||||
individual = results.scalar_one_or_none()
|
||||
if not individual:
|
||||
return None
|
||||
for key, value in kwargs.items():
|
||||
if value is not None:
|
||||
setattr(individual, key, value)
|
||||
session.add(individual)
|
||||
await session.commit()
|
||||
await session.refresh(individual)
|
||||
return individual
|
||||
|
||||
@database_exception
|
||||
async def delete_worker_individual(self, agent_id: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id)
|
||||
results = await session.execute(statement)
|
||||
individual = results.scalar_one_or_none()
|
||||
if not individual:
|
||||
return False
|
||||
session.delete(individual)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@database_exception
|
||||
async def get_all_worker_individual(self) -> List[WorkerIndividual]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkerIndividual)
|
||||
results = await session.execute(statement)
|
||||
return list(results.scalars().all())
|
||||
@@ -0,0 +1,68 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlmodel import SQLModel, Field, select
|
||||
from typing import Optional, List
|
||||
import json
|
||||
|
||||
class WorkflowRecord(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
workflow_id: str = Field(index=True)
|
||||
workflow_data_json: str
|
||||
|
||||
class MemoryRecord(SQLModel, table=True):
|
||||
id: Optional[int] = Field(default=None, primary_key=True)
|
||||
memory_text: str
|
||||
embedding: List[float] = Field(sa_column_kwargs={"type_": "VECTOR"}) # Requires pgvector extension setup in DB
|
||||
|
||||
class MemoryRAG:
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
async def save_workflow(self, workflow_id: str, workflow_data: dict):
|
||||
async with self.async_session_maker() as session:
|
||||
record = WorkflowRecord(
|
||||
workflow_id=workflow_id,
|
||||
workflow_data_json=json.dumps(workflow_data)
|
||||
)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
async def get_workflow(self, workflow_id: str):
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowRecord).where(WorkflowRecord.workflow_id == workflow_id)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record:
|
||||
return json.loads(record.workflow_data_json)
|
||||
return None
|
||||
|
||||
async def add_memory(self, memory_text: str, embedding: List[float]):
|
||||
async with self.async_session_maker() as session:
|
||||
record = MemoryRecord(memory_text=memory_text, embedding=embedding)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
async def retrieve_memory(self, query_embedding: List[float], limit: int = 5):
|
||||
# Requires pgvector specific operations; simplified retrieval simulation here
|
||||
async with self.async_session_maker() as session:
|
||||
# A true pgvector query would use an ORDER BY using `<->` operator
|
||||
# e.g. statement = select(MemoryRecord).order_by(MemoryRecord.embedding.l2_distance(query_embedding)).limit(limit)
|
||||
statement = select(MemoryRecord).limit(limit)
|
||||
results = await session.execute(statement)
|
||||
return results.all()
|
||||
@@ -0,0 +1,64 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List
|
||||
|
||||
from pretor.core.database.table.provider import Provider
|
||||
from sqlmodel import select
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
|
||||
class ProviderDatabase:
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@database_exception
|
||||
async def get_provider(self) -> List[Provider]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(Provider)
|
||||
results = await session.execute(statement)
|
||||
results = results.scalars().all()
|
||||
providers = [Provider(provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type) for provider in results]
|
||||
return providers
|
||||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
async with self.async_session_maker() as session:
|
||||
provider = Provider(**kwargs)
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
|
||||
@database_exception
|
||||
async def delete_provider(self, provider_id: str) -> None:
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(Provider, provider_id)
|
||||
if provider is not None:
|
||||
session.delete(provider)
|
||||
await session.commit()
|
||||
|
||||
@database_exception
|
||||
async def update_provider(self, provider_id: str, **kwargs) -> Provider:
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(Provider, provider_id)
|
||||
if provider is not None:
|
||||
for key, value in kwargs.items():
|
||||
setattr(provider, key, value)
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
await session.refresh(provider)
|
||||
return provider
|
||||
return None
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.database.table.system_node import SystemNodeConfig
|
||||
from sqlmodel import select
|
||||
from typing import List, Optional
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
|
||||
class SystemNodeDatabase:
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@database_exception
|
||||
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: Optional[List[str]] = None) -> SystemNodeConfig:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name)
|
||||
results = await session.execute(statement)
|
||||
config = results.scalar_one_or_none()
|
||||
if config:
|
||||
config.provider_title = provider_title
|
||||
config.model_id = model_id
|
||||
if tools is not None:
|
||||
config.tools = tools
|
||||
else:
|
||||
config = SystemNodeConfig(node_name=node_name, provider_title=provider_title, model_id=model_id, tools=tools)
|
||||
session.add(config)
|
||||
await session.commit()
|
||||
await session.refresh(config)
|
||||
return config
|
||||
|
||||
@database_exception
|
||||
async def get_all_system_node_configs(self) -> List[SystemNodeConfig]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(SystemNodeConfig)
|
||||
results = await session.execute(statement)
|
||||
return list(results.scalars().all())
|
||||
|
||||
@database_exception
|
||||
async def get_system_node_config(self, node_name: str) -> Optional[SystemNodeConfig]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
@@ -0,0 +1,135 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.database.table.user import User
|
||||
from sqlmodel import select
|
||||
from pretor.utils.error import UserNotExistError, UserPasswordError
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
from pretor.utils.access import Accessor
|
||||
|
||||
class AuthDatabase:
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@database_exception
|
||||
async def add_user(self, user_name: str, hashed_password: str) -> User:
|
||||
from ulid import ULID
|
||||
async with self.async_session_maker() as session:
|
||||
# Check if any users exist
|
||||
statement = select(User).limit(1)
|
||||
results = await session.execute(statement)
|
||||
existing_user = results.first()
|
||||
|
||||
authority = UserAuthority.USER
|
||||
if existing_user is None:
|
||||
authority = UserAuthority.SUPER_ADMINISTRATOR
|
||||
|
||||
user = User(
|
||||
user_id=str(ULID()),
|
||||
user_name=user_name,
|
||||
hashed_password=hashed_password,
|
||||
user_authority=authority
|
||||
)
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
@database_exception
|
||||
async def change_password(self, user_name, old_password, new_password) -> User:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(User).where(User.user_name == user_name)
|
||||
results = await session.execute(statement)
|
||||
user = results.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
if not Accessor.verify_password(old_password, user.hashed_password):
|
||||
raise UserPasswordError()
|
||||
user.hashed_password = new_password
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
|
||||
@database_exception
|
||||
async def delete_user(self, user_name: str) -> None:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(User).where(User.user_name == user_name)
|
||||
results = await session.execute(statement)
|
||||
user = results.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
@database_exception
|
||||
async def delete_user_by_id(self, user_id: str) -> None:
|
||||
async with self.async_session_maker() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
session.delete(user)
|
||||
await session.commit()
|
||||
|
||||
@database_exception
|
||||
async def login_user(self, user_name: str) -> str:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(User).where(User.user_name == user_name)
|
||||
results = await session.execute(statement)
|
||||
user = results.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
return user
|
||||
|
||||
@database_exception
|
||||
async def get_all_users(self) -> list[User]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(User)
|
||||
results = await session.execute(statement)
|
||||
users = results.scalars().all()
|
||||
return list(users)
|
||||
|
||||
@database_exception
|
||||
async def get_user_authority(self, user_id: str) -> UserAuthority:
|
||||
async with self.async_session_maker() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
return user.user_authority
|
||||
|
||||
@database_exception
|
||||
async def change_user_authority(self, user_id: str, new_authority: UserAuthority) -> User:
|
||||
"""
|
||||
Changes the authority level of a specific user.
|
||||
|
||||
Args:
|
||||
user_id: The ID of the user whose authority is to be changed.
|
||||
new_authority: The new authority level to assign to the user.
|
||||
|
||||
Returns:
|
||||
User: The updated user object.
|
||||
|
||||
Raises:
|
||||
UserNotExistError: If the specified user does not exist.
|
||||
"""
|
||||
async with self.async_session_maker() as session:
|
||||
user = await session.get(User, user_id)
|
||||
if user is None:
|
||||
raise UserNotExistError()
|
||||
user.user_authority = new_authority
|
||||
session.add(user)
|
||||
await session.commit()
|
||||
await session.refresh(user)
|
||||
return user
|
||||
@@ -0,0 +1,140 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
import ray
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel
|
||||
|
||||
from pretor.core.database.module.individual import IndividualDatabase
|
||||
from pretor.core.database.module.user import AuthDatabase
|
||||
from pretor.core.database.module.provider import ProviderDatabase
|
||||
from pretor.core.database.module.system_node import SystemNodeDatabase
|
||||
|
||||
@ray.remote
|
||||
class PostgresDatabase:
|
||||
def __init__(self):
|
||||
user = os.environ.get('POSTGRES_USER')
|
||||
password = os.environ.get('POSTGRES_PASSWORD')
|
||||
host = os.environ.get('POSTGRES_HOST')
|
||||
port = os.environ.get('POSTGRES_PORT')
|
||||
database = os.environ.get('POSTGRES_DB')
|
||||
database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
|
||||
self.async_engine = create_async_engine(database_url, echo=True)
|
||||
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
self._auth_database = AuthDatabase(self.async_session_maker)
|
||||
self._provider_database = ProviderDatabase(self.async_session_maker)
|
||||
self._individual_database = IndividualDatabase(self.async_session_maker)
|
||||
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
|
||||
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
async def init_db(self) -> None:
|
||||
try:
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
except Exception as e:
|
||||
# Provide a warning if the database is not accessible, allowing
|
||||
# the app to start up for development/UI tests without crashing immediately.
|
||||
print(f"Warning: Failed to initialize PostgreSQL database: {e}")
|
||||
finally:
|
||||
self.ready_event.set()
|
||||
|
||||
# Auth Database Methods
|
||||
async def add_user(self, user_name: str, hashed_password: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.add_user(user_name, hashed_password)
|
||||
|
||||
async def change_password(self, user_name, old_password, new_password):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.change_password(user_name, old_password, new_password)
|
||||
|
||||
async def delete_user(self, user_name: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.delete_user(user_name)
|
||||
|
||||
async def delete_user_by_id(self, user_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.delete_user_by_id(user_id)
|
||||
|
||||
async def login_user(self, user_name: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.login_user(user_name)
|
||||
|
||||
async def get_all_users(self):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.get_all_users()
|
||||
|
||||
async def get_user_authority(self, user_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.get_user_authority(user_id)
|
||||
|
||||
async def change_user_authority(self, user_id: str, new_authority):
|
||||
await self.ready_event.wait()
|
||||
return await self._auth_database.change_user_authority(user_id, new_authority)
|
||||
|
||||
# Provider Database Methods
|
||||
async def get_provider(self):
|
||||
await self.ready_event.wait()
|
||||
return await self._provider_database.get_provider()
|
||||
|
||||
async def add_provider_db(self, **kwargs):
|
||||
await self.ready_event.wait()
|
||||
return await self._provider_database.add_provider(**kwargs)
|
||||
|
||||
async def delete_provider_db(self, provider_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._provider_database.delete_provider(provider_id)
|
||||
|
||||
async def update_provider_db(self, provider_id: str, **kwargs):
|
||||
await self.ready_event.wait()
|
||||
return await self._provider_database.update_provider(provider_id, **kwargs)
|
||||
|
||||
# System Node Database Methods
|
||||
async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: list[str] = None):
|
||||
await self.ready_event.wait()
|
||||
return await self._system_node_database.upsert_system_node_config(node_name, provider_title, model_id, tools)
|
||||
|
||||
async def get_all_system_node_configs(self):
|
||||
await self.ready_event.wait()
|
||||
return await self._system_node_database.get_all_system_node_configs()
|
||||
|
||||
# Individual Database Methods
|
||||
async def add_worker_individual(self, **kwargs):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.add_worker_individual(**kwargs)
|
||||
|
||||
async def get_worker_individual(self, agent_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.get_worker_individual(agent_id)
|
||||
|
||||
async def get_worker_individual_list(self, owner_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.get_worker_individual_list(owner_id)
|
||||
|
||||
async def update_worker_individual(self, agent_id: str, **kwargs):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.update_worker_individual(agent_id, **kwargs)
|
||||
|
||||
async def delete_worker_individual(self, agent_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.delete_worker_individual(agent_id)
|
||||
|
||||
async def get_all_worker_individual(self):
|
||||
await self.ready_event.wait()
|
||||
return await self._individual_database.get_all_worker_individual()
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.database.table.user import User
|
||||
from pretor.core.database.table.provider import Provider
|
||||
from pretor.core.database.table.individual import WorkerIndividual
|
||||
__all__ = ["User", "Provider", "WorkerIndividual"]
|
||||
@@ -0,0 +1,38 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import Column, JSON
|
||||
from enum import Enum
|
||||
|
||||
class AgentType(str, Enum):
|
||||
SKILL_INDIVIDUAL = "skill_individual"
|
||||
ORDINARY_INDIVIDUAL = "ordinary_individual"
|
||||
SPECIAL_INDIVIDUAL = "special_individual"
|
||||
|
||||
class WorkerIndividual(SQLModel, table=True):
|
||||
__tablename__ = "worker_individual"
|
||||
agent_id: str = Field(primary_key=True)
|
||||
agent_name: str = Field(index=True)
|
||||
agent_type: AgentType
|
||||
description: str = Field(nullable=False)
|
||||
provider_title: str
|
||||
model_id: str
|
||||
system_prompt: Optional[str]
|
||||
output_template: Optional[dict] = Field(sa_column=Column(JSON),description="输出模板标识")
|
||||
bound_skill: Optional[str] = Field(sa_column=Column(JSON))
|
||||
workspace: Optional[List[str]] = Field(sa_column=Column(JSON))
|
||||
tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
|
||||
owner_id: str
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
from typing import List
|
||||
from sqlalchemy import Column, JSON
|
||||
from typing import Optional
|
||||
|
||||
class Provider(SQLModel, table=True):
|
||||
__tablename__ = "provider"
|
||||
provider_id: str = Field(primary_key=True)
|
||||
provider_title: str = Field(index=True)
|
||||
provider_type: str
|
||||
|
||||
provider_url: Optional[str]
|
||||
provider_apikey: Optional[str]
|
||||
|
||||
provider_models: List[str] = Field(sa_column=Column(JSON))
|
||||
|
||||
provider_owner: str
|
||||
is_active: bool = Field(default=True, description="该服务商节点是否在线/启用")
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import Column, JSON
|
||||
|
||||
class SystemNodeConfig(SQLModel, table=True):
|
||||
__tablename__ = "system_node_config"
|
||||
node_name: str = Field(primary_key=True)
|
||||
provider_title: str
|
||||
model_id: str
|
||||
tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None)
|
||||
@@ -0,0 +1,31 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
from enum import IntEnum
|
||||
|
||||
class UserAuthority(IntEnum):
|
||||
SUPER_ADMINISTRATOR = 100
|
||||
ADMINISTRATOR = 50
|
||||
USER = 20
|
||||
UNAUTHORIZED_USER = 10
|
||||
GUEST = 0
|
||||
|
||||
class User(SQLModel, table=True):
|
||||
__tablename__ = 'user'
|
||||
user_id: str = Field(primary_key=True)
|
||||
user_name: str = Field(index=True)
|
||||
hashed_password: str
|
||||
user_authority: UserAuthority = Field(default=UserAuthority.USER)
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,166 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from typing import Dict
|
||||
from pretor.core.database.postgres import PostgresDatabase
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
def __init__(self, postgres_database: PostgresDatabase):
|
||||
import sys
|
||||
print("GSM __init__ START", file=sys.stderr, flush=True)
|
||||
self.event_dict: Dict[str, PretorEvent] = {}
|
||||
print(" event_dict done", file=sys.stderr, flush=True)
|
||||
self._global_provider_manager = ProviderManager(postgres_database)
|
||||
print(" provider_manager done", file=sys.stderr, flush=True)
|
||||
self._global_tool_manager = GlobalToolManager()
|
||||
print(" tool_manager done", file=sys.stderr, flush=True)
|
||||
self._global_workflow_template_manager = WorkflowManager()
|
||||
print(" workflow_template_manager done", file=sys.stderr, flush=True)
|
||||
self._global_skill_manager = GlobalSkillManager()
|
||||
print(" skill_manager done", file=sys.stderr, flush=True)
|
||||
self._global_individual_manager = GlobalIndividualManager()
|
||||
print(" individual_manager done", file=sys.stderr, flush=True)
|
||||
self.postgres_database = postgres_database
|
||||
print("GSM __init__ DONE", file=sys.stderr, flush=True)
|
||||
|
||||
async def init_state_machine(self):
|
||||
await self._global_provider_manager.init_provider_register(self.postgres_database)
|
||||
await self._global_individual_manager.init_individual_register(self.postgres_database)
|
||||
|
||||
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||
return await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database
|
||||
)
|
||||
|
||||
# Provider Manager Methods
|
||||
def get_provider_list(self):
|
||||
return self._global_provider_manager.get_provider_list()
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
return self._global_provider_manager.get_provider(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str):
|
||||
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
|
||||
|
||||
# Tool Manager Methods
|
||||
def get_tool_mapper(self):
|
||||
return self._global_tool_manager.tool_mapper
|
||||
|
||||
def get_tool_list(self, agent_name: str):
|
||||
# get_tool_list didn't actually exist on tool_manager, let's implement it to return the tools
|
||||
# for a specific agent name (or scope)
|
||||
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
|
||||
# also include default tools
|
||||
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
|
||||
merged_tools = {**default_tools, **tools}
|
||||
return merged_tools
|
||||
|
||||
# Workflow Template Manager Methods
|
||||
def get_all_workflow_templates(self):
|
||||
return self._global_workflow_template_manager.get_all_workflow_templates()
|
||||
|
||||
def add_workflow_template(self, template_name: str, workflow_template):
|
||||
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
|
||||
|
||||
def delete_workflow_template(self, template_name: str):
|
||||
return self._global_workflow_template_manager.delete_workflow_template(template_name)
|
||||
|
||||
def generate_workflow_template(self, workflow_template):
|
||||
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
||||
# Skill Manager Methods
|
||||
def add_skill(self, skill_name: str):
|
||||
return self._global_skill_manager.add_skill(skill_name)
|
||||
|
||||
def get_skill_list(self):
|
||||
return self._global_skill_manager.get_skill_list()
|
||||
|
||||
def remove_skill(self, skill_name: str):
|
||||
return self._global_skill_manager.remove_skill(skill_name)
|
||||
|
||||
# Individual Manager Methods
|
||||
def add_individual(self, agent_id: str, config):
|
||||
return self._global_individual_manager.add_individual(agent_id, config)
|
||||
|
||||
def get_individual(self, agent_id: str):
|
||||
return self._global_individual_manager.get_individual(agent_id)
|
||||
|
||||
def remove_individual(self, agent_id: str):
|
||||
return self._global_individual_manager.remove_individual(agent_id)
|
||||
|
||||
def list_individuals(self):
|
||||
return self._global_individual_manager.list_individuals()
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
self.event_dict[event.trace_id] = event
|
||||
|
||||
def delete_event(self, trace_id: str) -> None:
|
||||
del self.event_dict[trace_id]
|
||||
|
||||
def get_event(self, trace_id: str) -> PretorEvent:
|
||||
return self.event_dict.get(trace_id, None)
|
||||
|
||||
def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None:
|
||||
self.event_dict[trace_id].attachment = attachment
|
||||
|
||||
def update_workflow(self, trace_id: str, workflow: PretorWorkflow) -> None:
|
||||
self.event_dict[trace_id].workflow = workflow
|
||||
|
||||
def get_workflow(self, trace_id: str) -> PretorWorkflow:
|
||||
return self.event_dict[trace_id].workflow
|
||||
|
||||
def list_events(self) -> list[dict]:
|
||||
result = []
|
||||
for trace_id, event in self.event_dict.items():
|
||||
workflow_title = event.workflow.title if event.workflow else None
|
||||
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None
|
||||
result.append({
|
||||
"event_id": trace_id,
|
||||
"workflow_title": workflow_title,
|
||||
"status": workflow_status,
|
||||
"user_name": event.user_name,
|
||||
"message": event.message,
|
||||
})
|
||||
return result
|
||||
|
||||
async def put_pending(self, trace_id, item) -> None:
|
||||
await self.event_dict[trace_id].pending_queue.put(item)
|
||||
|
||||
async def get_pending(self, trace_id) -> str:
|
||||
return await self.event_dict[trace_id].pending_queue.get()
|
||||
|
||||
async def put_received(self, trace_id, item) -> None:
|
||||
await self.event_dict[trace_id].receive_queue.put(item)
|
||||
|
||||
async def get_received(self, trace_id) -> str:
|
||||
return await self.event_dict[trace_id].receive_queue.get()
|
||||
@@ -0,0 +1,62 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Any
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('individual_manager')
|
||||
|
||||
class GlobalIndividualManager:
|
||||
def __init__(self):
|
||||
self._individuals: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
async def init_individual_register(self, postgres) -> None:
|
||||
try:
|
||||
try:
|
||||
individuals = await postgres.get_all_worker_individual.remote()
|
||||
for ind in individuals:
|
||||
agent_id = getattr(ind, 'agent_id', None)
|
||||
if agent_id:
|
||||
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind)
|
||||
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")
|
||||
except AttributeError:
|
||||
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
|
||||
except Exception as e:
|
||||
# 捕获因 Ray 调用目标方法不存在引发的异常
|
||||
if "has no attribute 'get_all_worker_individual'" in str(e):
|
||||
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
|
||||
else:
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
|
||||
|
||||
def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None:
|
||||
"""
|
||||
注册一个 worker individual
|
||||
config 可以包含 type, prompt, provider_title, model_id 等
|
||||
"""
|
||||
config["agent_id"] = agent_id
|
||||
self._individuals[agent_id] = config
|
||||
|
||||
def get_individual(self, agent_id: str) -> Dict[str, Any]:
|
||||
"""
|
||||
获取一个 worker individual 的配置
|
||||
"""
|
||||
return self._individuals.get(agent_id, None)
|
||||
|
||||
def remove_individual(self, agent_id: str) -> None:
|
||||
if agent_id in self._individuals:
|
||||
del self._individuals[agent_id]
|
||||
|
||||
def list_individuals(self) -> Dict[str, Dict[str, Any]]:
|
||||
return self._individuals
|
||||
@@ -0,0 +1,19 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
|
||||
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
|
||||
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
|
||||
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
class ProviderStatus(str, Enum):
|
||||
UP = "up"
|
||||
DOWN = "down"
|
||||
|
||||
class Provider(BaseModel):
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_models: List[str]
|
||||
provider_type: str
|
||||
provider_owner: str | None = None
|
||||
provider_status: ProviderStatus = ProviderStatus.UP
|
||||
|
||||
class ProviderArgs(BaseModel):
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_owner: str
|
||||
|
||||
class BaseProvider(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""
|
||||
创建一个供应商,传入provider_args参数,打包为一个Provider对象
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
Returns:
|
||||
返回一个Provider对象,由provider_manager管理
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
"""
|
||||
加载模型列表
|
||||
base_provider的字类应当按照供应商的api标准,向供应商的接口发送http请求从而或者供应商所提供的模型列表
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
Returns:
|
||||
返回一个列表,为http请求获取的模型列表
|
||||
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
"""
|
||||
包装Provider对象并返回
|
||||
将provider_args和_load_models获取的方法包装为provider对象
|
||||
|
||||
Args:
|
||||
provider_args: 参数包,包含以下几个参数
|
||||
provider_title: 供应商的别名
|
||||
provider_url: 供应商的url
|
||||
provider_apikey:供应商的apikey
|
||||
|
||||
provider_models: 模型列表,为该供应商包含的模型列表
|
||||
|
||||
Returns:
|
||||
返回一个Provider对象
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -0,0 +1,60 @@
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class ClaudeProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
|
||||
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
# Anthropic 官方需要 version 头
|
||||
headers = {
|
||||
"x-api-key": provider_args.provider_apikey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
# 如果是官方 API,通常使用 /v1/models (如果支持)
|
||||
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
|
||||
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
model_ids = [m["id"] for m in data.get("data", [])]
|
||||
return sorted(model_ids)
|
||||
else:
|
||||
# 如果官方列表接口不可用,fallback 到已知常用模型
|
||||
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="claude")
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class DeepseekProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
|
||||
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
model_ids = [m["id"] for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="deepseek")
|
||||
@@ -0,0 +1,59 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
model_ids = [m["id"] for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="openai")
|
||||
@@ -0,0 +1,86 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
|
||||
from typing import Dict, Type
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
模型供应商管理器 (ProviderManager)。
|
||||
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
|
||||
"""
|
||||
# --- 类属性显式标注 (IDE 友好) ---
|
||||
provider_mapper: Dict[str, Type[Provider]]
|
||||
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
|
||||
|
||||
provider_register: Dict[str, Provider]
|
||||
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
|
||||
def __init__(self, postgres):
|
||||
self.provider_mapper = {"openai": OpenAIProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"deepseek": DeepseekProvider}
|
||||
self.provider_register = {}
|
||||
|
||||
async def init_provider_register(self, postgres) -> None:
|
||||
providers = await postgres.get_provider.remote()
|
||||
for provider in providers:
|
||||
self.provider_register[provider.provider_title] = provider
|
||||
|
||||
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
|
||||
from pretor.core.global_state_machine.model_provider import ProviderArgs
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('provider_manager')
|
||||
import httpx
|
||||
|
||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner)
|
||||
try:
|
||||
import ulid
|
||||
provider_class = self.provider_mapper.get(provider_type, None)
|
||||
if provider_class is None:
|
||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||
return None
|
||||
provider: Provider = await provider_class.create_provider(provider_args)
|
||||
provider.provider_owner = provider_owner
|
||||
self.provider_register[provider_title] = provider
|
||||
await postgres_database.add_provider_db.remote(
|
||||
provider_id=str(ulid.ULID()),
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
|
||||
def get_provider_list(self):
|
||||
return self.provider_register
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
return self.provider_register.get(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str, postgres_database) -> None:
|
||||
if provider_title in self.provider_register:
|
||||
provider = self.provider_register[provider_title]
|
||||
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
|
||||
del self.provider_register[provider_title]
|
||||
@@ -0,0 +1,75 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Tuple, Dict
|
||||
from collections import defaultdict
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
class GlobalSkillManager:
|
||||
skill_mapper = Dict[str,Tuple[str]]
|
||||
"""skill的存储表"""
|
||||
|
||||
def __init__(self):
|
||||
self.skill_mapper = defaultdict(tuple)
|
||||
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
|
||||
return
|
||||
for item in skill_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json" # 拼接文件路径
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
# 提取并映射
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
def add_skill(self, skill_name: str) -> None:
|
||||
"""Add a skill to the manager by reading its skill.json from the path"""
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
item = skill_plugin_dir / skill_name
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json"
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
def get_skill_list(self) -> dict:
|
||||
"""Return all skills currently loaded."""
|
||||
return self.skill_mapper
|
||||
|
||||
def remove_skill(self, skill_name: str) -> None:
|
||||
"""Remove a skill from the manager mapping."""
|
||||
if skill_name in self.skill_mapper:
|
||||
del self.skill_mapper[skill_name]
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import importlib
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('tool_manager')
|
||||
|
||||
class GlobalToolManager:
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
||||
def __init__(self):
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
|
||||
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
|
||||
return
|
||||
|
||||
for item in tool_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith("__"):
|
||||
plugin_name = item.name
|
||||
module_name = f"pretor.plugin.tool_plugin.{plugin_name}"
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
|
||||
# It's a valid tool class
|
||||
action_scopes = obj.model_fields.get("action_scope").default
|
||||
|
||||
if not action_scopes:
|
||||
self.tool_mapper["default"][plugin_name] = obj
|
||||
else:
|
||||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .consciousness_node import ConsciousnessNode
|
||||
__all__ = ["ConsciousnessNode"]
|
||||
@@ -0,0 +1,179 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
import ray
|
||||
from typing import Union, overload
|
||||
from pretor.core.individual.consciousness_node.template import (ConsciousnessNodeDeps, ForSupervisoryNode, ForWorkflow,\
|
||||
ForWorkflowEngine, ForWorkflowInput, ForSupervisoryInput, ForWorkflowEngineInput)
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ConsciousnessNode:
|
||||
def __init__(self) -> None:
|
||||
from pretor.utils.logger import get_logger
|
||||
self.logger = get_logger('consciousness_node')
|
||||
self.agent: None | Agent = None
|
||||
|
||||
|
||||
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到ConsciousnessNode的属性内
|
||||
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
||||
Agent实例,
|
||||
并挂载到self.agent属性
|
||||
Args:
|
||||
global_state_machine: 全局状态机
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫Pretor,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
|
||||
"你是系统的'高级规划师'和'架构师',负责处理监控节点分配过来的复杂任务。\n"
|
||||
"你的主要工作场景包括:\n"
|
||||
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (PretorWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
|
||||
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
|
||||
"3. 总结报告 (Supervisory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForSupervisoryNode 格式。\n"
|
||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||
)
|
||||
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
||||
from pretor.utils.get_tool import load_tools_from_list
|
||||
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
output_type=output_type,
|
||||
system_prompt=system_prompt,
|
||||
deps_type=ConsciousnessNodeDeps,
|
||||
agent_name="consciousness_node",
|
||||
tools=callables)
|
||||
|
||||
@self.agent.system_prompt
|
||||
async def dynamic_prompt(ctx: RunContext[ConsciousnessNodeDeps]):
|
||||
prompt = system_prompt + "\n\n"
|
||||
prompt += (
|
||||
f"=== 当前任务上下文 ===\n"
|
||||
f"- 当前指令 (Command): {ctx.deps.command}\n"
|
||||
f"- 原始用户命令 (Original Command): {ctx.deps.original_command}\n"
|
||||
)
|
||||
if ctx.deps.workflow_template:
|
||||
prompt += f"- 选定工作流模板 (Workflow Template): {ctx.deps.workflow_template}\n"
|
||||
if ctx.deps.available_skills:
|
||||
prompt += "\n=== 当前可用 Skill Individual ===\n"
|
||||
prompt += "你可以直接将以下 Skill Individual 安排进工作流的步骤中(设置 node 为 skill_individual,并将 agent_id 设置为对应 Skill Individual 的真实 agent_id,不要用名称!),作为可调用的工具。\n"
|
||||
for skill in ctx.deps.available_skills:
|
||||
prompt += f"- 真实 agent_id: {skill.get('agent_id')}\n 名称: {skill['name']}\n 描述: {skill['description']}\n"
|
||||
|
||||
return prompt
|
||||
|
||||
async def working(self, payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]:
|
||||
try:
|
||||
result = await self._run(payload)
|
||||
if isinstance(result, (ForWorkflowEngine, ForWorkflow, ForSupervisoryNode)):
|
||||
return result
|
||||
else:
|
||||
self.logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
|
||||
return None
|
||||
except Exception:
|
||||
self.logger.exception("ConsciousnessNode在执行working时发生严重错误")
|
||||
return None
|
||||
|
||||
|
||||
@overload
|
||||
async def _run(self, payload: ForWorkflowEngineInput) -> ForWorkflowEngine:
|
||||
"""
|
||||
_run方法
|
||||
该分支应当在supervisory_node简单处理用户命令后,工作流创建前调用!
|
||||
Args:
|
||||
payload: 应当包含workflow_template和event对象
|
||||
|
||||
Returns:
|
||||
ForWorkflowEngine对象,将被放到全局状态机后丢入WorkflowEngine的异步队列
|
||||
"""
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def _run(self, payload: ForWorkflow) -> ForWorkflow:
|
||||
"""
|
||||
_run方法
|
||||
该分支应当在workflow运行时,由WorkflowEngine进行调用!
|
||||
Args:
|
||||
payload: 应当包含workflow中的WorkStep对象
|
||||
|
||||
Returns:
|
||||
ForWorkflow对象,作为ConsciousnessNode执行Workflow中的WorkStep的结果
|
||||
"""
|
||||
pass
|
||||
|
||||
@overload
|
||||
async def _run(self, payload: ForSupervisoryInput) -> ForSupervisoryNode:
|
||||
"""
|
||||
_run方法
|
||||
该分支应当在workflow运行完全结束后,由WorkflowEngine进行调用!
|
||||
Args:
|
||||
payload: 应当包含整个Workflow的情况
|
||||
|
||||
Returns:
|
||||
ForSupervisory对象,作为ConsciousnessNode对于全工作流的技术性总结,返回给SupervisoryNode
|
||||
"""
|
||||
pass
|
||||
|
||||
async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]:
|
||||
try:
|
||||
self.agent.retries = 3
|
||||
if isinstance(payload, ForWorkflowEngineInput):
|
||||
deps = ConsciousnessNodeDeps(
|
||||
original_command=payload.original_command,
|
||||
workflow_template=payload.workflow_template,
|
||||
command="拆解原始命令变成一个工作流",
|
||||
available_skills=payload.available_skills
|
||||
)
|
||||
self.logger.debug("ConsciousnessNode: 开始生成工作流 (原生重试开启)")
|
||||
prompt = "根据original_command制定严密的可执行workflow"
|
||||
if payload.workflow_template:
|
||||
prompt += ",可以学习并参考workflow_template的设计理念"
|
||||
result = await self.agent.run(prompt, deps=deps)
|
||||
return result.output
|
||||
|
||||
elif isinstance(payload, ForWorkflowInput):
|
||||
deps = ConsciousnessNodeDeps(
|
||||
original_command=payload.original_command,
|
||||
command="完成workflow step中分配给意识节点的特定任务或指导"
|
||||
)
|
||||
self.logger.debug("ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)")
|
||||
result = await self.agent.run(f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}",
|
||||
deps=deps)
|
||||
return result.output
|
||||
|
||||
elif isinstance(payload, ForSupervisoryInput):
|
||||
deps = ConsciousnessNodeDeps(
|
||||
original_command=payload.original_command,
|
||||
command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告"
|
||||
)
|
||||
self.logger.debug("ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)")
|
||||
result = await self.agent.run(f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}",
|
||||
deps=deps)
|
||||
return result.output
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ConsciousnessNode 模型生成最终失败: {str(e)}")
|
||||
raise RuntimeError(f"ConsciousnessNode 无法完成任务: {str(e)}") from e
|
||||
@@ -0,0 +1,67 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep
|
||||
from pretor.utils.agent_model import ResponseModel, DepsModel, InputModel
|
||||
from pydantic import Field
|
||||
|
||||
|
||||
#意识节点回复类
|
||||
class ConsciousnessNodeResponse(ResponseModel):
|
||||
"""Consciousness response model,是意识节点所有回复类型的父类"""
|
||||
pass
|
||||
|
||||
|
||||
class ForWorkflowEngine(ConsciousnessNodeResponse):
|
||||
"""生成workflow并放入WorkflowEngine"""
|
||||
workflow: PretorWorkflow = Field(..., description="生成好的符合规范的完整工作流对象。")
|
||||
reasoning: str = Field(..., description="生成此工作流的原因和思路简述。")
|
||||
|
||||
|
||||
class ForWorkflow(ConsciousnessNodeResponse):
|
||||
"""处理workflow中需要ConsciousnessNode的工作"""
|
||||
output: str = Field(..., description="对当前工作流步骤的具体处理结果或指导意见。")
|
||||
|
||||
|
||||
class ForSupervisoryNode(ConsciousnessNodeResponse):
|
||||
"""工作流完成后进行校验并返回给SupervisoryNode"""
|
||||
output: str = Field(..., description="为监控节点提供的全工作流执行情况的技术性总结报告。")
|
||||
|
||||
|
||||
class ConsciousnessNodeDeps(DepsModel):
|
||||
original_command: str
|
||||
workflow_template: str | None = None
|
||||
command: str
|
||||
available_skills: list[dict] | None = None
|
||||
|
||||
|
||||
class ConsciousnessNodeInput(InputModel):
|
||||
pass
|
||||
|
||||
|
||||
class ForWorkflowEngineInput(ConsciousnessNodeInput):
|
||||
workflow_template: str | None = None
|
||||
original_command: str
|
||||
available_skills: list[dict] | None = None
|
||||
|
||||
|
||||
class ForWorkflowInput(ConsciousnessNodeInput):
|
||||
workflow_step: WorkStep
|
||||
original_command: str
|
||||
|
||||
|
||||
class ForSupervisoryInput(ConsciousnessNodeInput):
|
||||
workflow: PretorWorkflow
|
||||
original_command: str
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .control_node import ControlNode
|
||||
__all__ = ["ControlNode"]
|
||||
@@ -0,0 +1,102 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from pretor.core.individual.control_node.template import ForWorkflow, ForWorkflowInput, ControlNodeDeps
|
||||
|
||||
|
||||
|
||||
@ray.remote
|
||||
class ControlNode:
|
||||
def __init__(self):
|
||||
from pretor.utils.logger import get_logger
|
||||
self.logger = get_logger('control_node')
|
||||
self.agent: Agent | None = None
|
||||
|
||||
|
||||
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到Control的属性内
|
||||
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的
|
||||
Agent实例,
|
||||
并挂载到self.agent属性
|
||||
Args:
|
||||
global_state_machine: 全局状态机
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫Pretor,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
|
||||
"你是系统的'执行者'和'车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
|
||||
"你的工作职责是:\n"
|
||||
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
|
||||
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
|
||||
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
|
||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||
)
|
||||
output_type = ForWorkflow
|
||||
from pretor.utils.get_tool import load_tools_from_list
|
||||
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
output_type=output_type,
|
||||
system_prompt=system_prompt,
|
||||
deps_type=ControlNodeDeps,
|
||||
agent_name="control_node",
|
||||
tools=callables)
|
||||
@self.agent.system_prompt
|
||||
async def dynamic_prompt(ctx: RunContext[ControlNodeDeps]):
|
||||
prompt = system_prompt + "\n\n"
|
||||
prompt += (
|
||||
f"=== 当前任务步骤上下文 ===\n"
|
||||
f"- 步骤名称 (Name): {ctx.deps.workflow_step.name}\n"
|
||||
f"- 步骤目标/描述 (Description): {ctx.deps.workflow_step.desc}\n"
|
||||
f"- 前置输入(input): {ctx.deps.workflow_step.inputs}\n"
|
||||
)
|
||||
return prompt
|
||||
|
||||
async def working(self, payload: ForWorkflowInput) -> str:
|
||||
try:
|
||||
result: ForWorkflow = await self._run(payload)
|
||||
return result
|
||||
except Exception:
|
||||
self.logger.exception("ControlNode在执行working时发生严重错误")
|
||||
return None
|
||||
|
||||
async def _run(self, payload: ForWorkflowInput) -> ForWorkflow:
|
||||
try:
|
||||
self.agent.retries = 3
|
||||
deps = ControlNodeDeps(
|
||||
workflow_step=payload.workflow_step
|
||||
)
|
||||
self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)")
|
||||
|
||||
result = await self.agent.run(
|
||||
f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}",
|
||||
deps=deps
|
||||
)
|
||||
return result.output
|
||||
except Exception as e:
|
||||
self.logger.exception(f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}")
|
||||
raise RuntimeError(f"ControlNode 执行步骤失败: {str(e)}") from e
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pydantic import Field
|
||||
from pretor.core.workflow.workflow import WorkStep
|
||||
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
|
||||
|
||||
class ControlNodeResponse(ResponseModel):
|
||||
"""控制节点回复的基类"""
|
||||
pass
|
||||
|
||||
|
||||
class ControlNodeInput(InputModel):
|
||||
pass
|
||||
|
||||
|
||||
class ControlNodeDeps(DepsModel):
|
||||
workflow_step: WorkStep
|
||||
# In the future, this can be dynamically populated with tools specific to the current task execution
|
||||
|
||||
|
||||
class ForWorkflow(ControlNodeResponse):
|
||||
output: str = Field(..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。")
|
||||
|
||||
|
||||
class ForWorkflowInput(ControlNodeInput):
|
||||
workflow_step: WorkStep
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,16 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .supervisory_node import SupervisoryNode
|
||||
__all__ = ["SupervisoryNode"]
|
||||
@@ -0,0 +1,192 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import datetime
|
||||
import ray
|
||||
from typing import Union, overload
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
from pretor.core.global_state_machine.model_provider import Provider
|
||||
from pretor.core.individual.supervisory_node.template import ForConsciousnessNode, ForUser, SupervisoryNodeDeps, TerminationMessage
|
||||
from pydantic_ai import RunContext, Agent
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
|
||||
@ray.remote
|
||||
class SupervisoryNode:
|
||||
def __init__(self) -> None:
|
||||
from pretor.utils.logger import get_logger
|
||||
self.logger = get_logger('supervisory_node')
|
||||
self.agent: None | Agent = None
|
||||
|
||||
|
||||
async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到SupervisoryNode的属性内
|
||||
该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的Agent实例,
|
||||
并挂载到self.agent属性
|
||||
Args:
|
||||
global_state_machine: 全局状态机
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫Pretor,是一个多智能体AI助手系统中的【监控节点 (Supervisory Node)】。\n"
|
||||
"你是系统的'前台接待'和'大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
|
||||
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
|
||||
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
|
||||
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
|
||||
" 并使用 ForConsciousnessNode 格式。若提供的【可用模板列表】中有合适的模板请选用,若都不匹配则 workflow_template 设为 null。\n"
|
||||
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
|
||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||
)
|
||||
output_type = Union[ForConsciousnessNode, ForUser]
|
||||
from pretor.utils.get_tool import load_tools_from_list
|
||||
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
output_type=output_type,
|
||||
system_prompt=system_prompt,
|
||||
deps_type=SupervisoryNodeDeps,
|
||||
agent_name="supervisory_node",
|
||||
tools=callables)
|
||||
|
||||
@self.agent.system_prompt
|
||||
async def dynamic_prompt(ctx: RunContext[SupervisoryNodeDeps]):
|
||||
prompt = system_prompt + "\n\n"
|
||||
prompt += (
|
||||
f"=== 当前上下文 ===\n"
|
||||
f"- 平台 (Platform): {ctx.deps.platform}\n"
|
||||
f"- 用户名 (User): {ctx.deps.user_name}\n"
|
||||
f"- 当前时间 (Time): {ctx.deps.time}\n"
|
||||
f"- 可用工作流模板 (Available Templates): {ctx.deps.available_templates}\n"
|
||||
)
|
||||
# 修改 system_prompt 变量
|
||||
prompt += (
|
||||
"\n\n注意:你必须调用且只能调用一个函数(工具)来输出结果。"
|
||||
"如果你想直接回复用户,请调用 ForUser;"
|
||||
"如果你想移交给工作流,请调用 ForConsciousnessNode(若没有合适的模板,workflow_template 填 null)。"
|
||||
"严禁返回纯文本,必须使用工具格式!"
|
||||
)
|
||||
if ctx.deps.error_history:
|
||||
prompt += (
|
||||
f"\n=== 错误重试指示 ===\n"
|
||||
f"警告:前一次尝试失败,错误信息如下:\n{ctx.deps.error_history}\n"
|
||||
f"请务必修正该错误并按照要求的 Pydantic 格式输出。"
|
||||
)
|
||||
return prompt
|
||||
|
||||
###工作函数
|
||||
async def working(self, payload: Union[PretorEvent, TerminationMessage]) -> str:
|
||||
"""
|
||||
working方法,是节点唯一的调用方法,对于_run函数的结果进行判断并实现最终回复
|
||||
Args:
|
||||
payload: 消息载荷,包含所有信息
|
||||
|
||||
Returns:
|
||||
str,监控节点对于用户的回复
|
||||
"""
|
||||
try:
|
||||
result = await self._run(payload)
|
||||
if isinstance(result, ForConsciousnessNode):
|
||||
self.logger.info(f"SupervisoryNode: 任务已分配给工作流引擎处理,选用模板 [{result.workflow_template}]")
|
||||
if isinstance(payload, PretorEvent):
|
||||
payload.context["workflow_template"] = result.workflow_template
|
||||
try:
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.add_event.remote(payload)
|
||||
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine
|
||||
await workflow_running_engine.put_event.remote(payload)
|
||||
except Exception as e:
|
||||
self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}")
|
||||
return "抱歉,任务提交失败,系统内部错误。"
|
||||
return f"任务已创建,准备创建工作流。原因:{result.reasoning}"
|
||||
elif isinstance(result, ForUser):
|
||||
self.logger.info("SupervisoryNode: 直接向用户返回简单回复。")
|
||||
return result.context
|
||||
else:
|
||||
self.logger.error(f"SupervisoryNode: 未知响应类型: {type(result)}")
|
||||
return "抱歉,系统内部遇到未知错误,无法正确处理您的请求。"
|
||||
except Exception:
|
||||
self.logger.exception("SupervisoryNode在处理请求时发生未捕获的严重错误")
|
||||
return "抱歉,监控节点处理请求时发生严重错误,请联系管理员。"
|
||||
|
||||
@overload
|
||||
async def _run(self, payload: PretorEvent) -> Union[ForConsciousnessNode, ForUser]:
|
||||
"""
|
||||
_run方法
|
||||
Args:
|
||||
payload: PretorEvent的实例,是用户输入时对于消息的封装
|
||||
|
||||
Returns:
|
||||
ForUser对象,监控节点对于用户进行的简单回答
|
||||
ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点,并且给选择好的工作流模板
|
||||
"""
|
||||
...
|
||||
|
||||
@overload
|
||||
async def _run(self, payload: TerminationMessage) -> ForUser:
|
||||
"""
|
||||
_run方法
|
||||
Args:
|
||||
payload: Termination的实例,是工作流结束后到达监控节点的最后结果
|
||||
|
||||
Returns:
|
||||
ForUser对象,工作流结束后给用户的返回
|
||||
"""
|
||||
...
|
||||
|
||||
async def _run(self, payload: Union[PretorEvent, TerminationMessage]) -> Union[ForConsciousnessNode, ForUser]:
|
||||
"""
|
||||
_run方法,将payload转化为对llm发送的消息并发送
|
||||
Args:
|
||||
payload: 消息载荷
|
||||
|
||||
Returns:
|
||||
ForConsciousnessNode对象,对意识节点发送的消息
|
||||
ForUser对象,对用户发送到消息
|
||||
"""
|
||||
platform = payload.platform
|
||||
user_name = payload.user_name
|
||||
message = payload.message
|
||||
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
try:
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote()
|
||||
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
||||
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
||||
deps = SupervisoryNodeDeps(
|
||||
platform=platform,
|
||||
user_name=user_name,
|
||||
time=time_str,
|
||||
available_templates=available_templates_str
|
||||
)
|
||||
self.logger.debug("SupervisoryNode 开始生成 (启用原生 Pydantic-AI 重试)")
|
||||
prompt_message = message
|
||||
if isinstance(payload, TerminationMessage):
|
||||
prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}"
|
||||
self.agent.retries = 3
|
||||
result = await self.agent.run(prompt_message,
|
||||
deps=deps)
|
||||
return result.output
|
||||
except Exception as e:
|
||||
self.logger.exception(f"SupervisoryNode 模型生成或解析最终失败: {str(e)}")
|
||||
return ForUser(context="系统当前负载过高或遇到复杂内部错误,请稍后再试。")
|
||||
@@ -0,0 +1,40 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import Field
|
||||
from pretor.utils.agent_model import ResponseModel, DepsModel
|
||||
from pydantic import BaseModel
|
||||
|
||||
class SupervisoryNodeResponse(ResponseModel):
|
||||
pass
|
||||
|
||||
class ForUser(SupervisoryNodeResponse):
|
||||
context: str = Field(..., description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。")
|
||||
|
||||
class ForConsciousnessNode(SupervisoryNodeResponse):
|
||||
workflow_template: str | None = Field(default=None, description="选择的工作流模板的名称,用于处理复杂任务。若无需模板则为 None。")
|
||||
reasoning: str = Field(..., description="选择将任务移交意识节点并选用该模板的简短原因。")
|
||||
|
||||
class TerminationMessage(BaseModel):
|
||||
platform: str
|
||||
user_name: str
|
||||
message: str
|
||||
|
||||
class SupervisoryNodeDeps(DepsModel):
|
||||
platform: str
|
||||
user_name: str
|
||||
time: str
|
||||
retry_count: int = 0
|
||||
error_history: str = ""
|
||||
available_templates: str = "默认工作流 (default_workflow)"
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional, Union, Literal, Dict, Any
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('workflow')
|
||||
NodeType = Literal[
|
||||
"consciousness_node", "control_node", "supervisory_node", "skill_individual"
|
||||
]
|
||||
|
||||
class EventInfo(BaseModel):
|
||||
platform: str
|
||||
user_name: str
|
||||
|
||||
class LogicGate(BaseModel):
|
||||
if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'")
|
||||
if_pass: Literal["continue", "exit"] = Field(default="continue", description="成功后的动作")
|
||||
|
||||
class WorkStep(BaseModel):
|
||||
step: int = Field(..., gt=0, description="步骤序号,严格自增")
|
||||
name: str = Field(..., description="步骤名称")
|
||||
node: NodeType = Field(..., description="负责执行的节点类型")
|
||||
action: str = Field(..., description="执行的原子动作")
|
||||
desc: str = Field(..., description="动作细节的自然语言描述,包含人工规范指导")
|
||||
inputs: Optional[Union[str, List[str]]] = Field(default=None, description="前置依赖输出")
|
||||
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
|
||||
agent_id: Optional[str] = Field(default=None, description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替")
|
||||
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
|
||||
status: Literal["waiting", "running", "completed", "failed"] = Field(
|
||||
default="waiting",
|
||||
description="执行状态 (LLM建议保留默认值)"
|
||||
)
|
||||
|
||||
|
||||
class WorkflowStatus(BaseModel):
|
||||
step: int = Field(default=1, gt=0, description="当前运行到的工作流步数")
|
||||
status: Literal["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"] = Field(
|
||||
default="waiting_llm_working",
|
||||
description="当前系统调度状态"
|
||||
)
|
||||
|
||||
class PretorWorkflow(BaseModel):
|
||||
title: str = Field(..., description="工作流的标题")
|
||||
work_link: List[WorkStep] = Field(..., description="工作链逻辑定义")
|
||||
# ---------------- 以下为系统级管控字段,LLM 无需关心 ---------------- #
|
||||
trace_id: str | None = Field(description="系统自动生成的追溯ID")
|
||||
version: str = Field(default="v1.0", description="系统协议版本号")
|
||||
command: Optional[str] = Field(default=None, description="触发此工作流的原始命令")
|
||||
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
|
||||
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
|
||||
event_info: EventInfo | None = Field(default=None)
|
||||
context_memory: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_workflow_integrity(self) -> 'PretorWorkflow':
|
||||
steps = [s.step for s in self.work_link]
|
||||
expected = list(range(1, len(steps) + 1))
|
||||
if steps != expected:
|
||||
raise ValueError(f"工作链步数不连续!期望 {expected},实际 {steps}")
|
||||
|
||||
max_step = len(steps)
|
||||
for s in self.work_link:
|
||||
if s.logic_gate and "jump_to_step_" in s.logic_gate.if_fail:
|
||||
try:
|
||||
target = int(s.logic_gate.if_fail.split("_")[-1])
|
||||
if target > max_step or target < 1:
|
||||
raise ValueError(f"Step {s.step} 的跳转目标 Step {target} 越界了!")
|
||||
except ValueError as e:
|
||||
if "越界" in str(e):
|
||||
raise e
|
||||
raise ValueError(f"LogicGate 格式错误: {s.logic_gate.if_fail}")
|
||||
return self
|
||||
@@ -0,0 +1,23 @@
|
||||
# workflow文档
|
||||
---
|
||||
- workflow(工作流)是作为pretor中运行任务的基本单位,workflow_manager管理整个workflow模块,包括生成workflow_template(工作流模板),生成workflow对象,和保存整个workflow_template表。
|
||||
- workflow_template是一个工作流模板,旨在由专业人士教导LLM如何编写工作流并进行任务,每个workflow_template都应该保存在 **pretor/workflow_pugin/** 文件夹下,保存格式为~_workflow_template.json,json格式为:
|
||||
|
||||
```json
|
||||
{
|
||||
"name": "",
|
||||
"desc": "",
|
||||
"work_link": [
|
||||
{
|
||||
"step": "",
|
||||
"node": "",
|
||||
"action": "",
|
||||
"desc": "",
|
||||
"input": [],
|
||||
"output": [],
|
||||
"logic_gate": {}
|
||||
}
|
||||
]
|
||||
}
|
||||
```
|
||||
- workflow_template将由监管节点挑选交给意识节点,意识节点按照参考模板生成标准的workflow对象,转交给pipeline开始执行任务链。
|
||||
@@ -0,0 +1,364 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
import ray
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo
|
||||
from typing import Optional, Dict, Union, Any, List
|
||||
from pretor.utils.error import WorkflowError, WorkflowExit
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.core.individual.control_node.template import ForWorkflowInput as ControlForWorkflowInput, \
|
||||
ForWorkflow as ControlForWorkflow
|
||||
from pretor.core.individual.consciousness_node.template import (
|
||||
ForWorkflowInput as ConsciousnessForWorkflowInput,
|
||||
ForSupervisoryInput,
|
||||
ForSupervisoryNode,
|
||||
ForWorkflow as ConsciousnessForWorkflow,
|
||||
ForWorkflowEngineInput,
|
||||
ForWorkflowEngine
|
||||
)
|
||||
from pretor.core.individual.supervisory_node.template import TerminationMessage
|
||||
import pathlib
|
||||
|
||||
|
||||
def get_workflow_template(workflow_name: str) -> str:
|
||||
workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json")
|
||||
with open(workflow_template, "r", encoding="utf-8") as workflow_template_file:
|
||||
workflow_template = workflow_template_file.read()
|
||||
return workflow_template
|
||||
|
||||
|
||||
class WorkflowEngine:
|
||||
def __init__(self,
|
||||
workflow: PretorWorkflow,
|
||||
consciousness_node=None,
|
||||
control_node=None,
|
||||
supervisory_node=None):
|
||||
from pretor.utils.logger import get_logger
|
||||
self.logger = get_logger('workflow_runner')
|
||||
self.workflow: PretorWorkflow = workflow
|
||||
"""工作流:当前WorkflowEngine待执行的workflow"""
|
||||
self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link}
|
||||
"""步骤表:将当前workflow的步骤序号和步骤内容存放"""
|
||||
self.consciousness_node = consciousness_node
|
||||
"""意识节点"""
|
||||
self.control_node = control_node
|
||||
"""控制节点"""
|
||||
self.supervisory_node = supervisory_node
|
||||
"""监督节点"""
|
||||
self._gsm = ray_actor_hook("global_state_machine").global_state_machine
|
||||
|
||||
async def _push_sse(self, msg: str) -> None:
|
||||
try:
|
||||
await self._gsm.put_pending.remote(self.workflow.trace_id, msg)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _prepare_inputs(self, inputs: Optional[Union[str, List[str]]]) -> Any:
|
||||
"""
|
||||
准备输入的方法
|
||||
Args:
|
||||
inputs: 待输入的名称
|
||||
|
||||
Returns:
|
||||
|
||||
"""
|
||||
match inputs:
|
||||
case None:
|
||||
return None
|
||||
case str(name):
|
||||
return self.workflow.context_memory.get(name)
|
||||
case list(names):
|
||||
return {k: self.workflow.context_memory.get(k) for k in names}
|
||||
|
||||
async def run(self):
|
||||
"""
|
||||
run方法
|
||||
处理并执行workflow的方法
|
||||
|
||||
"""
|
||||
self.logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]")
|
||||
await self._push_sse(f"[工作流启动] {self.workflow.title}")
|
||||
max_step = len(self.workflow.work_link)
|
||||
while 1 <= self.workflow.status.step <= max_step:
|
||||
current_step_id = self.workflow.status.step
|
||||
current_step = self._steps_by_id.get(current_step_id)
|
||||
if not current_step:
|
||||
self.logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。")
|
||||
self.workflow.status.status = "failed"
|
||||
await self._push_sse(f"[工作流失败] 找不到步骤 {current_step_id}")
|
||||
break
|
||||
self.logger.info(f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}")
|
||||
current_step.status = "running"
|
||||
await self._push_sse(f"[Step {current_step_id}] {current_step.name}: {current_step.desc}")
|
||||
try:
|
||||
step_input_data = self._prepare_inputs(current_step.inputs)
|
||||
step_result, is_success = await self._dispatch_to_node(current_step, step_input_data)
|
||||
if is_success:
|
||||
if current_step.outputs:
|
||||
self.workflow.context_memory[current_step.outputs] = step_result
|
||||
self.logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'")
|
||||
current_step.status = "completed"
|
||||
await self._push_sse(f"[Step {current_step_id} 完成] {current_step.name}")
|
||||
else:
|
||||
self.logger.warning(f"Step {current_step_id} 执行遇到业务失败/驳回。")
|
||||
current_step.status = "failed"
|
||||
await self._push_sse(f"[Step {current_step_id} 失败] {current_step.name}")
|
||||
self._handle_logic_gate(current_step, is_success)
|
||||
except WorkflowExit:
|
||||
self.logger.info("命中 if_pass='exit',工作流被主动要求结束。")
|
||||
await self._push_sse("[工作流结束] 主动退出")
|
||||
break
|
||||
except WorkflowError as e:
|
||||
self.logger.error(f"{e},终止工作流。")
|
||||
self.workflow.status.status = "failed"
|
||||
await self._push_sse(f"[工作流失败] {e}")
|
||||
break
|
||||
except Exception as e:
|
||||
self.logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True)
|
||||
current_step.status = "failed"
|
||||
self.workflow.status.status = "failed"
|
||||
await self._push_sse(f"[工作流异常] {e}")
|
||||
break
|
||||
self.logger.info(f"✅ 工作流 {self.workflow.title} 执行步骤结束。")
|
||||
self.workflow.output = self.workflow.context_memory
|
||||
await self._push_sse(f"[工作流完成] {self.workflow.title}")
|
||||
await self._report_results()
|
||||
|
||||
async def _report_results(self):
|
||||
"""
|
||||
结果汇报函数
|
||||
在工作流结束后执行
|
||||
Returns:
|
||||
|
||||
"""
|
||||
if self.workflow.status.status == "failed":
|
||||
self.logger.warning("工作流执行失败,跳过正常汇报流程。")
|
||||
return
|
||||
try:
|
||||
self.logger.info("开始生成工作流结束技术报告...")
|
||||
report = ""
|
||||
if self.consciousness_node:
|
||||
supervisory_input = ForSupervisoryInput(
|
||||
workflow=self.workflow,
|
||||
original_command=self.workflow.command or "未知命令"
|
||||
)
|
||||
report_obj = await self.consciousness_node.working.remote(supervisory_input)
|
||||
if isinstance(report_obj, ForSupervisoryNode):
|
||||
report = report_obj.output
|
||||
elif isinstance(report_obj, str):
|
||||
report = report_obj
|
||||
self.logger.debug(f"生成的报告摘要: {report[:100]}...")
|
||||
else:
|
||||
self.logger.warning("未提供 consciousness_node 句柄,跳过报告生成。")
|
||||
|
||||
if self.supervisory_node:
|
||||
term_msg = TerminationMessage(
|
||||
platform=self.workflow.event_info.platform,
|
||||
user_name=self.workflow.event_info.user_name,
|
||||
message=f"工作流执行完毕。系统报告:{report}"
|
||||
)
|
||||
user_response = await self.supervisory_node.working.remote(term_msg)
|
||||
self.workflow.context_memory["_final_user_response"] = user_response
|
||||
self.logger.info(f"Supervisory 最终回复:{user_response}")
|
||||
else:
|
||||
self.logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
||||
except Exception:
|
||||
self.logger.exception("生成工作流执行汇报时发生错误")
|
||||
|
||||
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
|
||||
"""
|
||||
分流器
|
||||
调用当前step的执行对象
|
||||
Args:
|
||||
step: WorkStep对象,当前需要执行的step
|
||||
input_data: 输入数据
|
||||
|
||||
Returns:
|
||||
返回llm的输出和一个bool类型的判断
|
||||
"""
|
||||
self.logger.debug(f"正在向 {step.node} 节点发送动作 {step.action}...")
|
||||
try:
|
||||
if step.node == "control_node":
|
||||
if not self.control_node:
|
||||
raise WorkflowError("未提供 control_node 句柄!")
|
||||
payload = ControlForWorkflowInput(workflow_step=step)
|
||||
# 可选:如果 input_data 需要合并,可以扩展 ControlForWorkflowInput 或将其放在 context_memory
|
||||
result_obj = await self.control_node.working.remote(payload)
|
||||
if isinstance(result_obj, ControlForWorkflow):
|
||||
return result_obj.output, True
|
||||
return result_obj, True
|
||||
|
||||
elif step.node == "consciousness_node":
|
||||
if not self.consciousness_node:
|
||||
raise WorkflowError("未提供 consciousness_node 句柄!")
|
||||
original_cmd = self.workflow.command or ""
|
||||
payload = ConsciousnessForWorkflowInput(
|
||||
workflow_step=step,
|
||||
original_command=original_cmd
|
||||
)
|
||||
result_obj = await self.consciousness_node.working.remote(payload)
|
||||
if isinstance(result_obj, ConsciousnessForWorkflow):
|
||||
return result_obj.output, True
|
||||
return result_obj, True
|
||||
|
||||
elif step.node == "skill_individual":
|
||||
self.logger.info(f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}。")
|
||||
try:
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
worker_cluster = ray_actor_hook("worker_cluster").worker_cluster
|
||||
task_id = f"{self.workflow.trace_id}_step_{step.step}"
|
||||
agent_id = step.agent_id or f"default_{step.node}"
|
||||
task_event = {
|
||||
"action": step.action,
|
||||
"description": step.desc,
|
||||
"input_data": input_data,
|
||||
"context_memory": self.workflow.context_memory
|
||||
}
|
||||
result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event)
|
||||
|
||||
if result_response.get("success"):
|
||||
return result_response.get("data"), True
|
||||
else:
|
||||
self.logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}")
|
||||
return result_response.get("error"), False
|
||||
|
||||
except Exception as e:
|
||||
self.logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}")
|
||||
raise WorkflowError(f"WorkerCluster 调度异常: {e}")
|
||||
else:
|
||||
raise WorkflowError(f"未知的节点类型:{step.node}")
|
||||
|
||||
except Exception:
|
||||
self.logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
|
||||
return None, False
|
||||
|
||||
def _handle_logic_gate(self, step: WorkStep, is_success: bool):
|
||||
"""
|
||||
状态机,检测任务执行情况
|
||||
Args:
|
||||
step: WorkStep对象,当前执行的step
|
||||
is_success: bool类型,当前步骤是否成功
|
||||
|
||||
|
||||
"""
|
||||
gate = step.logic_gate
|
||||
if is_success:
|
||||
if gate and gate.if_pass == "exit":
|
||||
raise WorkflowExit()
|
||||
self.workflow.status.step += 1
|
||||
else:
|
||||
if not gate or not gate.if_fail:
|
||||
raise WorkflowError(f"步骤 {step.step} 失败且未配置 if_fail 兜底方案")
|
||||
match gate.if_fail.split("_"):
|
||||
case ["jump", "to", "step", target] if target.isdigit():
|
||||
target_step = int(target)
|
||||
self.logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}")
|
||||
self.workflow.status.step = target_step
|
||||
case _:
|
||||
raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}")
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WorkflowRunningEngine:
|
||||
def __init__(self, consciousness_node=None, control_node=None, supervisory_node=None):
|
||||
from pretor.utils.logger import get_logger
|
||||
self.logger = get_logger('workflow_runner')
|
||||
self.runner_engine = {}
|
||||
self.workflow_queue: asyncio.Queue[PretorEvent] = None
|
||||
self.consciousness_node = consciousness_node
|
||||
self.control_node = control_node
|
||||
self.supervisory_node = supervisory_node
|
||||
self.global_state_machine = None
|
||||
|
||||
async def run(self):
|
||||
# Move actor hook to async start so we don't race during __init__ across cluster
|
||||
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
self.workflow_queue = asyncio.Queue()
|
||||
self.runner_engine = {
|
||||
f"runner_{i}": asyncio.create_task(self.runner(i))
|
||||
for i in range(10)
|
||||
}
|
||||
|
||||
async def put_event(self, event: PretorEvent) -> None:
|
||||
await self.workflow_queue.put(event)
|
||||
|
||||
async def runner(self, i: int) -> None:
|
||||
"""
|
||||
runner方法,从self.workflow_queue中不断取出任务并执行
|
||||
Args:
|
||||
i: runner序列号
|
||||
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
event = await self.workflow_queue.get()
|
||||
self.logger.info(f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。")
|
||||
|
||||
if not self.consciousness_node:
|
||||
raise WorkflowError("未配置 consciousness_node,无法生成工作流")
|
||||
|
||||
workflow_template_name = event.context.get("workflow_template", "")
|
||||
workflow_template = get_workflow_template(workflow_template_name) if workflow_template_name else None
|
||||
|
||||
available_skills = None
|
||||
if self.global_state_machine:
|
||||
try:
|
||||
all_individuals = await self.global_state_machine.list_individuals.remote()
|
||||
available_skills = []
|
||||
for agent_id, config in all_individuals.items():
|
||||
if config.get("agent_type") == "skill_individual" or config.get("type") == "skill_individual":
|
||||
available_skills.append({
|
||||
"agent_id": agent_id,
|
||||
"name": config.get("agent_name", "Unknown"),
|
||||
"description": config.get("description", "")
|
||||
})
|
||||
except Exception as e:
|
||||
self.logger.warning(f"获取Skill Individual列表失败: {e}")
|
||||
|
||||
payload = ForWorkflowEngineInput(
|
||||
original_command=event.message,
|
||||
workflow_template=workflow_template,
|
||||
available_skills=available_skills
|
||||
)
|
||||
|
||||
result_obj = await self.consciousness_node.working.remote(payload)
|
||||
|
||||
if isinstance(result_obj, ForWorkflowEngine):
|
||||
workflow = result_obj.workflow
|
||||
|
||||
workflow.trace_id = event.trace_id
|
||||
workflow.command = event.message
|
||||
workflow.event_info = EventInfo(platform=event.platform,
|
||||
user_name=event.user_name,)
|
||||
|
||||
self.logger.info(
|
||||
f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}")
|
||||
|
||||
await self.global_state_machine.update_workflow.remote(event.trace_id, workflow)
|
||||
|
||||
workflow_engine = WorkflowEngine(workflow,
|
||||
self.consciousness_node,
|
||||
self.control_node,
|
||||
self.supervisory_node)
|
||||
await workflow_engine.run()
|
||||
else:
|
||||
self.logger.error(f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}")
|
||||
|
||||
except asyncio.CancelledError:
|
||||
self.logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。")
|
||||
raise
|
||||
except Exception as e:
|
||||
self.logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True)
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel, model_validator
|
||||
from typing import Dict,List
|
||||
|
||||
class WorkflowTemplateStep(BaseModel):
|
||||
step: int
|
||||
node: str
|
||||
action: str
|
||||
desc: str
|
||||
input: List[str]
|
||||
output: List[str]
|
||||
logic_gate: Dict[str, str]
|
||||
|
||||
class WorkflowTemplate(BaseModel):
|
||||
name: str
|
||||
desc: str
|
||||
work_link: list[WorkflowTemplateStep]
|
||||
|
||||
@model_validator(mode='after')
|
||||
def validate_steps(self) -> 'WorkflowTemplate':
|
||||
steps = [s.step for s in self.work_link]
|
||||
if len(steps) != len(set(steps)):
|
||||
raise ValueError("Step numbers in work_link must be unique")
|
||||
return self
|
||||
|
||||
|
||||
@@ -0,0 +1,26 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pathlib import Path
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
|
||||
class WorkflowTemplateGenerator:
|
||||
@staticmethod
|
||||
def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate:
|
||||
output_dir = Path("pretor") / "workflow_template"
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
output_file = output_dir / f"{workflow_template.name}_workflow_template.json"
|
||||
with output_file.open("w", encoding="utf-8") as f:
|
||||
f.write(workflow_template.model_dump_json(indent=4))
|
||||
@@ -0,0 +1,56 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
|
||||
from pathlib import Path
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('workflow_template_manager')
|
||||
|
||||
class WorkflowManager:
|
||||
def __init__(self):
|
||||
self.workflow_template_generator = WorkflowTemplateGenerator()
|
||||
self.workflow_templates_registry = {}
|
||||
self.template_path = Path("pretor/workflow_template")
|
||||
self._load_workflow_template()
|
||||
|
||||
def _load_workflow_template(self) -> None:
|
||||
for workflow_template_file in self.template_path.glob("*_workflow_template.json"):
|
||||
with workflow_template_file.open("r",encoding="utf-8") as f:
|
||||
try:
|
||||
workflow_template = json.load(f)
|
||||
self.workflow_templates_registry[workflow_template.get("name")] = workflow_template.get("desc")
|
||||
except json.decoder.JSONDecodeError:
|
||||
logger.warning(f"{workflow_template_file}不是json文件或格式错误")
|
||||
except KeyError:
|
||||
logger.warning(f"{workflow_template_file}不符合workflow_template格式")
|
||||
|
||||
def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None:
|
||||
try:
|
||||
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
|
||||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||
except Exception:
|
||||
logger.exception("Failed to generate workflow template")
|
||||
|
||||
def add_workflow_template(self, template_name: str, workflow_template: WorkflowTemplate) -> None:
|
||||
self.generate_workflow_template(workflow_template)
|
||||
|
||||
def get_all_workflow_templates(self) -> dict:
|
||||
return self.workflow_templates_registry
|
||||
|
||||
def delete_workflow_template(self, template_name: str) -> None:
|
||||
if template_name in self.workflow_templates_registry:
|
||||
del self.workflow_templates_registry[template_name]
|
||||
@@ -0,0 +1,13 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
@@ -0,0 +1,2 @@
|
||||
from .approval import ApprovalToolData, approval
|
||||
__all__ = ["ApprovalToolData", "approval"]
|
||||
@@ -0,0 +1,39 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
class ApprovalToolData(BaseToolData):
|
||||
is_system: bool = True
|
||||
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [
|
||||
"control_node", "consciousness_node"]
|
||||
config_args: Dict[str, str] = {}
|
||||
|
||||
|
||||
async def approval(message: str, trace_id: str) -> str:
|
||||
"""
|
||||
当任务存在某些高风险操作或者计划需要让用户审批,发送请求给用户等待用户审批
|
||||
Args:
|
||||
message: 发送给用户的请求
|
||||
trace_id:
|
||||
|
||||
Returns:
|
||||
用户的审批结果
|
||||
"""
|
||||
actor_list = ray_actor_hook("global_state_machine")
|
||||
await actor_list.global_state_machine.put_pending.remote(trace_id, message)
|
||||
reply = await actor_list.global_state_machine.get_received.remote(trace_id)
|
||||
return reply
|
||||
@@ -0,0 +1,2 @@
|
||||
{
|
||||
}
|
||||
@@ -0,0 +1,23 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Literal, Dict
|
||||
from pydantic import ConfigDict
|
||||
|
||||
class BaseToolData(BaseModel):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
is_system: bool
|
||||
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = []
|
||||
config_args: Dict[str, str] = {}
|
||||
@@ -0,0 +1,3 @@
|
||||
from .file_reader import FileReaderData, file_reader
|
||||
|
||||
__all__ = ["FileReaderData", "file_reader"]
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic_ai import RunContext
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
import os
|
||||
|
||||
class FileReaderData(BaseToolData):
|
||||
is_system: bool = True
|
||||
name: str = "file_reader"
|
||||
description: str = "读取本地文件的内容"
|
||||
|
||||
def file_reader(ctx: RunContext, filepath: str) -> str:
|
||||
"""读取本地文件内容的工具。
|
||||
|
||||
Args:
|
||||
filepath: 目标文件的绝对路径或相对路径。
|
||||
|
||||
Returns:
|
||||
如果文件存在并可读,返回文件内容;否则返回错误信息。
|
||||
"""
|
||||
if not os.path.exists(filepath):
|
||||
return f"Error: 文件 {filepath} 不存在。"
|
||||
if not os.path.isfile(filepath):
|
||||
return f"Error: {filepath} 不是一个文件。"
|
||||
|
||||
try:
|
||||
with open(filepath, 'r', encoding='utf-8') as f:
|
||||
content = f.read()
|
||||
return content
|
||||
except Exception as e:
|
||||
return f"Error: 读取文件失败,原因:{str(e)}"
|
||||
@@ -0,0 +1,14 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import jwt
|
||||
import os
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional
|
||||
from fastapi import HTTPException, status, Request
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pretor.core.database.table.user import User
|
||||
from pwdlib import PasswordHash
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
user_id: str
|
||||
username: Optional[str] = None
|
||||
exp: Optional[int] = None
|
||||
|
||||
SECRET_KEY = os.getenv("SECRET_KEY")
|
||||
ALGORITHM = "HS256"
|
||||
ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24
|
||||
|
||||
password_hasher = PasswordHash.recommended()
|
||||
|
||||
|
||||
class Accessor:
|
||||
@staticmethod
|
||||
def _decode_token(token: str) -> TokenData:
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
SECRET_KEY,
|
||||
algorithms=[ALGORITHM]
|
||||
)
|
||||
return TokenData(**payload)
|
||||
except jwt.ExpiredSignatureError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 已过期",
|
||||
)
|
||||
except (jwt.InvalidTokenError, ValidationError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="无效的认证凭证",
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_access_token(data: dict) -> str:
|
||||
to_encode = data.copy()
|
||||
expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES)
|
||||
to_encode.update({"exp": int(expire.timestamp())})
|
||||
return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM)
|
||||
|
||||
@staticmethod
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
return password_hasher.verify(plain_password, hashed_password)
|
||||
|
||||
@staticmethod
|
||||
def get_current_user(request: Request) -> TokenData:
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未提供认证头部",
|
||||
)
|
||||
token = auth_header.split(" ")[1]
|
||||
return Accessor._decode_token(token)
|
||||
|
||||
@staticmethod
|
||||
def login_hashed_password(user: User, password: str) -> str:
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在",
|
||||
)
|
||||
if not Accessor.verify_password(password, user.hashed_password):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户名或密码错误",
|
||||
)
|
||||
token_payload = {
|
||||
"user_id": str(user.user_id),
|
||||
"username": user.user_name
|
||||
}
|
||||
return Accessor._create_access_token(data=token_payload)
|
||||
|
||||
@staticmethod
|
||||
def hash_password(password: str) -> str:
|
||||
if not password:
|
||||
raise ValueError("密码不能为空")
|
||||
if len(password) < 6:
|
||||
raise ValueError("密码长度不能小于 6 位")
|
||||
return password_hasher.hash(password)
|
||||
@@ -0,0 +1,25 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
class ResponseModel(BaseModel):
|
||||
pass
|
||||
|
||||
class DepsModel(BaseModel):
|
||||
pass
|
||||
|
||||
class InputModel(BaseModel):
|
||||
pass
|
||||
@@ -0,0 +1,25 @@
|
||||
from rich.console import Console
|
||||
from rich.text import Text
|
||||
import yaml
|
||||
def print_banner() -> None:
|
||||
with open("config/config.yml","r") as config:
|
||||
config = yaml.load(config, Loader=yaml.FullLoader)
|
||||
version = config.get("version", "unknown")
|
||||
pretor_banner = """
|
||||
██████╗ ██████╗ ███████╗████████╗ ██████╗ ██████╗
|
||||
██╔══██╗██╔══██╗██╔════╝╚══██╔══╝██╔═══██╗██╔══██╗
|
||||
██████╔╝██████╔╝█████╗ ██║ ██║ ██║██████╔╝
|
||||
██╔═══╝ ██╔══██╗██╔══╝ ██║ ██║ ██║██╔══██╗
|
||||
██║ ██║ ██║███████╗ ██║ ╚██████╔╝██║ ██║
|
||||
╚═╝ ╚═╝ ╚═╝╚══════╝ ╚═╝ ╚═════╝ ╚═╝ ╚═╝
|
||||
"""
|
||||
console = Console()
|
||||
banner_colored = Text(pretor_banner, style="gold3 bold")
|
||||
console.print(banner_colored)
|
||||
console.print("=" * 40, style="dim") # dim=灰色,低调
|
||||
console.print("🚀 Multi-Agent Orchestration Platform", style="blue")
|
||||
console.print(f"📦 Version: {version}", style="green")
|
||||
console.print("👤 Author: zhaoxi826", style="yellow")
|
||||
console.print("📜 License: Apache 2.0", style="magenta")
|
||||
console.print("🐙 github: https://github.com/zhaoxi826/pretor", style="yellow")
|
||||
console.print("=" * 40, style="dim")
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Annotated
|
||||
from fastapi import Depends, HTTPException
|
||||
from pretor.utils.access import Accessor, TokenData
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
async def get_authority(user_id: str) -> UserAuthority:
|
||||
from pretor.utils.error import UserNotExistError
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
try:
|
||||
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
|
||||
return user_authority
|
||||
except UserNotExistError:
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户不存在或已被删除,请重新登录"
|
||||
)
|
||||
except Exception as e:
|
||||
# Check if it's a RayTaskError wrapping UserNotExistError
|
||||
if "UserNotExistError" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="用户不存在或已被删除,请重新登录"
|
||||
)
|
||||
raise
|
||||
|
||||
class RoleChecker:
|
||||
def __init__(self, **kwargs):
|
||||
self.allowed_roles = kwargs.get("allowed_roles", )
|
||||
|
||||
async def __call__(self,
|
||||
token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]):
|
||||
user_authority = await get_authority(token_data.user_id)
|
||||
if user_authority < self.allowed_roles:
|
||||
raise HTTPException(
|
||||
status_code=403,
|
||||
detail={"message": f"User {token_data.user_id} does not have allowed roles"},
|
||||
)
|
||||
return token_data
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
class RetryableError(Exception):
|
||||
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)"""
|
||||
pass
|
||||
|
||||
class NonRetryableError(Exception):
|
||||
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
|
||||
pass
|
||||
|
||||
class DemandError(NonRetryableError):
|
||||
pass
|
||||
|
||||
class ModelNotExistError(Exception):
|
||||
pass
|
||||
|
||||
class UserError(Exception):
|
||||
pass
|
||||
|
||||
class UserNotExistError(UserError):
|
||||
pass
|
||||
|
||||
class UserPasswordError(UserError):
|
||||
pass
|
||||
|
||||
class ProviderError(Exception):
|
||||
pass
|
||||
|
||||
class ProviderNotExistError(ProviderError):
|
||||
pass
|
||||
|
||||
class WorkflowError(Exception):
|
||||
|
||||
pass
|
||||
|
||||
class WorkflowExit(WorkflowError):
|
||||
|
||||
pass
|
||||
@@ -0,0 +1,80 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import importlib.util
|
||||
import os
|
||||
import sys
|
||||
from typing import Callable, Dict, List
|
||||
import pathlib
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('get_tool')
|
||||
_tool_cache: Dict[str, Callable] = {}
|
||||
|
||||
|
||||
def _get_tool_func(tool_name: str) -> Callable | None:
|
||||
func = _tool_cache.get(tool_name, None)
|
||||
if func:
|
||||
return func
|
||||
|
||||
app_root = "/app"
|
||||
tool_plugin_dir = os.path.join(app_root, "pretor", "plugin", "tool_plugin", tool_name)
|
||||
|
||||
if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir):
|
||||
logger.error(f"Tool directory not found: {tool_plugin_dir}")
|
||||
return None
|
||||
|
||||
init_file = os.path.join(tool_plugin_dir, "__init__.py")
|
||||
if not os.path.exists(init_file):
|
||||
logger.error(f"Tool init file not found: {init_file}")
|
||||
return None
|
||||
|
||||
try:
|
||||
module_name = f"pretor.plugin.tool_plugin.{tool_name}"
|
||||
spec = importlib.util.spec_from_file_location(module_name, init_file)
|
||||
if spec is None or spec.loader is None:
|
||||
logger.error(f"Failed to create spec for {module_name}")
|
||||
return None
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
func = getattr(module, tool_name, None)
|
||||
|
||||
if not callable(func):
|
||||
logger.error(f"Tool function '{tool_name}' not found or not callable in {module_name}")
|
||||
return None
|
||||
_tool_cache[tool_name] = func
|
||||
return func
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load module {module_name}: {e}")
|
||||
return None
|
||||
|
||||
def del_tool_cache(tool_name: str) -> None:
|
||||
if tool_name in _tool_cache:
|
||||
del _tool_cache[tool_name]
|
||||
|
||||
def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]:
|
||||
if not tool_names:
|
||||
return []
|
||||
|
||||
tool_list = []
|
||||
for tool_name in tool_names:
|
||||
tool_func = _get_tool_func(tool_name)
|
||||
if tool_func:
|
||||
tool_list.append(tool_func)
|
||||
|
||||
return tool_list
|
||||
@@ -0,0 +1,44 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from loguru import logger
|
||||
from rich.logging import RichHandler
|
||||
from loguru._logger import Logger
|
||||
|
||||
def setup_logger() -> Logger:
|
||||
logger.remove()
|
||||
|
||||
def format_record(record):
|
||||
# Format string for rich handler
|
||||
actor = record["extra"].get("actor_name", "System")
|
||||
trace_id = record["extra"].get("trace_id", "")
|
||||
|
||||
trace_str = f" | trace_id:({trace_id})" if trace_id else ""
|
||||
return f"actor:({actor}){trace_str} : {record['message']}"
|
||||
|
||||
logger.configure(extra={"actor_name": "System", "trace_id": ""})
|
||||
|
||||
logger.add(
|
||||
RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_level=False, show_path=False),
|
||||
format=format_record,
|
||||
level="DEBUG",
|
||||
enqueue=True, # 异步记录
|
||||
)
|
||||
|
||||
return logger
|
||||
|
||||
global_logger = setup_logger()
|
||||
|
||||
def get_logger(actor_name: str, trace_id: str = "") -> Logger:
|
||||
return global_logger.bind(actor_name=actor_name, trace_id=trace_id)
|
||||
@@ -0,0 +1,37 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Type, TypeVar
|
||||
from pydantic import BaseModel
|
||||
|
||||
T = TypeVar("T", bound=Type[BaseModel])
|
||||
|
||||
def pickle(cls: T) -> T:
|
||||
"""
|
||||
类装饰器pickle
|
||||
通过装饰继承了BaseModel的类,用pydantic的高效序列化替代python原生__reduce__魔术方法,实现ray在通讯时的高效序列化
|
||||
Args:
|
||||
cls: 继承了BaseModel类的类,需要被装饰的对象
|
||||
|
||||
Returns:
|
||||
返回被重写了__reduce__魔术方法的cls类
|
||||
"""
|
||||
def __reduce__(self):
|
||||
# 1. 序列化:触发 Pydantic-core (Rust) 的极速序列化
|
||||
data = self.model_dump_json()
|
||||
# 2. 反序列化:告诉 Pickle 重建时调用 cls.model_validate_json
|
||||
return cls.model_validate_json, (data,)
|
||||
cls.__reduce__ = __reduce__
|
||||
return cls
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import ray
|
||||
from functools import lru_cache
|
||||
|
||||
class ActorList:
|
||||
def __init__(self):
|
||||
super().__setattr__('dict', {})
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
self.dict[key] = value
|
||||
|
||||
def __getattr__(self, key):
|
||||
if key in self.dict:
|
||||
return self.dict[key]
|
||||
raise AttributeError(f"ActorList 对象没有属性 '{key}'")
|
||||
|
||||
def __delattr__(self, key):
|
||||
if key in self.dict:
|
||||
del self.dict[key]
|
||||
else:
|
||||
raise AttributeError(f"ActorList对象没有属性 '{key}'")
|
||||
|
||||
@lru_cache(maxsize=128)
|
||||
def _get_cached_actor_handle(actor_name: str):
|
||||
"""缓存接口"""
|
||||
return ray.get_actor(actor_name, namespace="pretor")
|
||||
|
||||
def clear_actor_cache():
|
||||
"""清理接口"""
|
||||
_get_cached_actor_handle.cache_clear()
|
||||
|
||||
def ray_actor_hook(*actor_names: str):
|
||||
actor_list = ActorList()
|
||||
for actor_name in actor_names:
|
||||
handle = _get_cached_actor_handle(actor_name)
|
||||
setattr(actor_list, actor_name, handle)
|
||||
return actor_list
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
|
||||
import asyncio
|
||||
from functools import wraps
|
||||
from pretor.utils.error import RetryableError
|
||||
|
||||
def retry_on_retryable_error(max_retries=3, base_delay=1):
|
||||
def decorator(func):
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
@wraps(func)
|
||||
async def async_wrapper(*args, **kwargs):
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return await func(*args, **kwargs)
|
||||
except RetryableError:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
await asyncio.sleep(base_delay * (2 ** attempt))
|
||||
return async_wrapper
|
||||
else:
|
||||
@wraps(func)
|
||||
def sync_wrapper(*args, **kwargs):
|
||||
import time
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except RetryableError:
|
||||
if attempt == max_retries - 1:
|
||||
raise
|
||||
time.sleep(base_delay * (2 ** attempt))
|
||||
return sync_wrapper
|
||||
return decorator
|
||||
@@ -0,0 +1,11 @@
|
||||
from pretor.worker_individual.base_individual import BaseIndividual
|
||||
from pretor.worker_individual.skill_individual import SkillIndividual
|
||||
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
|
||||
from pretor.worker_individual.special_individual import SpecialIndividual
|
||||
|
||||
__all__ = [
|
||||
"BaseIndividual",
|
||||
"SkillIndividual",
|
||||
"OrdinaryIndividual",
|
||||
"SpecialIndividual",
|
||||
]
|
||||
@@ -0,0 +1,76 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pydantic import Field
|
||||
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('worker_individual')
|
||||
|
||||
class WorkerIndividualResponse(ResponseModel):
|
||||
output: str = Field(..., description="Worker执行任务的输出结果")
|
||||
|
||||
class WorkerIndividualDeps(DepsModel):
|
||||
task_event: dict
|
||||
|
||||
class WorkerIndividualInput(InputModel):
|
||||
task_event: dict
|
||||
|
||||
class BaseIndividual:
|
||||
"""
|
||||
Worker Individual 的基类
|
||||
"""
|
||||
|
||||
def __init__(self, agent_config: dict):
|
||||
self.agent_config = agent_config
|
||||
self.agent_id = agent_config.get("agent_id")
|
||||
self.agent: Agent | None = None
|
||||
|
||||
async def _init_agent(self, agent_name: str, system_prompt: str):
|
||||
from pretor.utils.get_tool import load_tools_from_list
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
provider_title = self.agent_config.get("provider_title", "openai") # default fallback
|
||||
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
|
||||
tools_list = self.agent_config.get("tools", None)
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote( provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
|
||||
self.agent = agent_factory.create_agent(
|
||||
provider=provider,
|
||||
model_id=model_id,
|
||||
output_type=WorkerIndividualResponse,
|
||||
system_prompt=system_prompt,
|
||||
deps_type=WorkerIndividualDeps,
|
||||
agent_name=agent_name,
|
||||
tools=callables
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]):
|
||||
prompt = system_prompt + "\n\n"
|
||||
prompt += (
|
||||
f"=== 当前任务上下文 ===\n"
|
||||
f"{ctx.deps.task_event}\n"
|
||||
)
|
||||
return prompt
|
||||
|
||||
async def run(self, task_event: dict) -> dict:
|
||||
raise NotImplementedError("子类必须实现 run 方法")
|
||||
@@ -0,0 +1,14 @@
|
||||
worker_individual
|
||||
---
|
||||
**worker_individual**是pretor中的基础工作对象,主要分为三类:**skill_individual**,**ordinary_individual**和**special_individual**,庞大的**worker_individual**将负责具体的生产工作。
|
||||
|
||||
---
|
||||
## worker_individual分类
|
||||
### skill_individual(专家子个体)
|
||||
**skill_individual(专家子个体)** 是拥有专业**skill**的agent,通常使用MoE(混合专家模型)或者大参数的专家模型来作为agent的模型。通过装配专业化的知识从而实现完成复杂任务。
|
||||
|
||||
### ordinary_individual(普通子个体)
|
||||
**ordinary_individual(普通子个体)** 是普通的agent,通常使用小参数微调专家模型来作为agent的模型。通过专业化数据的微调,在一定程度上实现比大参数MoE模型在单一方面上的能力。
|
||||
|
||||
### special_individual(特殊子个体)
|
||||
**special_individual(特殊子个体)** 是特殊的agent,这类agent一般不承担普通的生成任务,更多是实现一些特殊的任务,比如生成语音生成视频等。
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
|
||||
from pretor.utils.logger import get_logger
|
||||
|
||||
logger = get_logger('ordinary_individual')
|
||||
|
||||
class OrdinaryIndividual(BaseIndividual):
|
||||
"""
|
||||
普通子个体:普通的 agent。
|
||||
"""
|
||||
|
||||
def __init__(self, agent_config: dict):
|
||||
super().__init__(agent_config)
|
||||
|
||||
async def run(self, task_event: dict) -> dict:
|
||||
if self.agent is None:
|
||||
system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。")
|
||||
await self._init_agent("ordinary_individual", system_prompt)
|
||||
|
||||
deps = WorkerIndividualDeps(task_event=task_event)
|
||||
self.agent.retries = 3
|
||||
try:
|
||||
result = await self.agent.run(
|
||||
f"请执行以下任务:\n{task_event}",
|
||||
deps=deps
|
||||
)
|
||||
return {"output": result.data.output}
|
||||
except Exception as e:
|
||||
logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,110 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
|
||||
from pretor.utils.logger import get_logger
|
||||
import os
|
||||
import json
|
||||
from pydantic_ai import Tool
|
||||
import importlib.util
|
||||
|
||||
logger = get_logger('skill_individual')
|
||||
|
||||
class SkillIndividual(BaseIndividual):
|
||||
"""
|
||||
专家子个体:拥有专业 skill 的 agent。
|
||||
"""
|
||||
|
||||
def __init__(self, agent_config: dict):
|
||||
super().__init__(agent_config)
|
||||
|
||||
async def _load_skill_tools(self):
|
||||
"""动态加载已绑定的 skill 工具。"""
|
||||
tools = []
|
||||
bound_skill = self.agent_config.get("bound_skill", "")
|
||||
# bound_skill can be string or dict {"skill_name": ["file1", "file2"]}
|
||||
skill_mapper = {}
|
||||
if isinstance(bound_skill, str) and bound_skill:
|
||||
try:
|
||||
skill_mapper = json.loads(bound_skill)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif isinstance(bound_skill, dict):
|
||||
skill_mapper = bound_skill
|
||||
|
||||
skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill"))
|
||||
|
||||
for skill_name, _ in skill_mapper.items():
|
||||
skill_path = os.path.join(skill_base_dir, skill_name)
|
||||
metadata_path = os.path.join(skill_path, "metadata.json")
|
||||
if not os.path.exists(metadata_path):
|
||||
continue
|
||||
|
||||
try:
|
||||
with open(metadata_path, 'r', encoding='utf-8') as f:
|
||||
metadata = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load metadata for skill {skill_name}: {e}")
|
||||
continue
|
||||
|
||||
if "functions" in metadata:
|
||||
for func_info in metadata["functions"]:
|
||||
# Ensure path is absolute
|
||||
script_path = func_info.get("file_path", "")
|
||||
if not os.path.isabs(script_path):
|
||||
script_path = os.path.join(skill_path, script_path)
|
||||
|
||||
if not os.path.exists(script_path):
|
||||
logger.warning(f"Skill script not found: {script_path}")
|
||||
continue
|
||||
|
||||
func_name = func_info.get("name")
|
||||
try:
|
||||
# Dynamically load the python module
|
||||
spec = importlib.util.spec_from_file_location(func_name, script_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
|
||||
func = getattr(module, func_name)
|
||||
if callable(func):
|
||||
# Convert to PydanticAI Tool
|
||||
tool = Tool(func, name=func_name, description=func_info.get("docstring", ""))
|
||||
tools.append(tool)
|
||||
logger.info(f"Loaded skill tool: {func_name} from {skill_name}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load function {func_name} from {script_path}: {e}")
|
||||
|
||||
return tools
|
||||
|
||||
async def run(self, task_event: dict) -> dict:
|
||||
if self.agent is None:
|
||||
system_prompt = self.agent_config.get("prompt",
|
||||
"你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。")
|
||||
await self._init_agent("skill_individual", system_prompt)
|
||||
|
||||
deps = WorkerIndividualDeps(task_event=task_event)
|
||||
self.agent.retries = 3
|
||||
|
||||
tools = await self._load_skill_tools()
|
||||
|
||||
try:
|
||||
result = await self.agent.run(
|
||||
f"请执行以下任务:\n{task_event}",
|
||||
deps=deps,
|
||||
tools=tools if tools else None
|
||||
)
|
||||
return {"output": result.data.output}
|
||||
except Exception as e:
|
||||
logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,43 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps
|
||||
from pretor.utils.logger import get_logger
|
||||
|
||||
logger = get_logger('special_individual')
|
||||
|
||||
class SpecialIndividual(BaseIndividual):
|
||||
"""
|
||||
特殊子个体:执行特殊任务的 agent,如生成语音、视频等。
|
||||
"""
|
||||
|
||||
def __init__(self, agent_config: dict):
|
||||
super().__init__(agent_config)
|
||||
|
||||
async def run(self, task_event: dict) -> dict:
|
||||
if self.agent is None:
|
||||
system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。")
|
||||
await self._init_agent("special_individual", system_prompt)
|
||||
|
||||
deps = WorkerIndividualDeps(task_event=task_event)
|
||||
self.agent.retries = 3
|
||||
try:
|
||||
result = await self.agent.run(
|
||||
f"请执行以下任务:\n{task_event}",
|
||||
deps=deps
|
||||
)
|
||||
return {"output": result.data.output}
|
||||
except Exception as e:
|
||||
logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
|
||||
raise
|
||||
@@ -0,0 +1,148 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
import time
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from ray.util.queue import Queue
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from pretor.worker_individual.base_individual import BaseIndividual
|
||||
from pretor.worker_individual.skill_individual import SkillIndividual
|
||||
from pretor.worker_individual.ordinary_individual import OrdinaryIndividual
|
||||
from pretor.worker_individual.special_individual import SpecialIndividual
|
||||
|
||||
|
||||
from pretor.utils.logger import get_logger
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WorkerCluster:
|
||||
"""
|
||||
工作集群 Actor:管理和调度所有的 worker_individual
|
||||
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
|
||||
"""
|
||||
|
||||
def __init__(self, max_capacity: int = 200, num_runners: int = 10):
|
||||
self.max_capacity = max_capacity
|
||||
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
|
||||
self.status = "running"
|
||||
self.task_queue = None
|
||||
self.results_futures = {}
|
||||
self.runners = []
|
||||
self.num_runners = num_runners
|
||||
self.logger = get_logger('worker_cluster')
|
||||
|
||||
async def start(self):
|
||||
if self.task_queue is None:
|
||||
self.task_queue = Queue()
|
||||
self.runners = [asyncio.create_task(self._runner(i)) for i in range(self.num_runners)]
|
||||
self.logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。")
|
||||
|
||||
async def _recruit_worker(self, agent_id: str) -> BaseIndividual:
|
||||
"""内部方法:招聘/唤醒一个具体的 Agent 对象"""
|
||||
if agent_id in self._active_workers:
|
||||
self._active_workers.move_to_end(agent_id)
|
||||
return self._active_workers[agent_id]
|
||||
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
agent_config = await global_state_machine.get_individual.remote( agent_id)
|
||||
|
||||
if not agent_config:
|
||||
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
||||
|
||||
worker_type = agent_config.get("type", "ordinary")
|
||||
if worker_type == "skill":
|
||||
worker = SkillIndividual(agent_config)
|
||||
elif worker_type == "special":
|
||||
worker = SpecialIndividual(agent_config)
|
||||
else:
|
||||
worker = OrdinaryIndividual(agent_config)
|
||||
|
||||
self._active_workers[agent_id] = worker
|
||||
if len(self._active_workers) > self.max_capacity:
|
||||
evicted_id, _ = self._active_workers.popitem(last=False)
|
||||
self.logger.info(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}")
|
||||
|
||||
return worker
|
||||
|
||||
async def _runner(self, runner_id: int):
|
||||
while True:
|
||||
try:
|
||||
if self.task_queue is None:
|
||||
await asyncio.sleep(0.1)
|
||||
continue
|
||||
task = await self.task_queue.get_async()
|
||||
task_id = task.get("task_id")
|
||||
agent_id = task.get("agent_id")
|
||||
task_event = task.get("task_event")
|
||||
|
||||
self.logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
worker = await self._recruit_worker(agent_id)
|
||||
result = await worker.run(task_event)
|
||||
cost_time = time.time() - start_time
|
||||
|
||||
response = {
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"data": result,
|
||||
"metrics": {"cost_time_sec": round(cost_time, 2)}
|
||||
}
|
||||
except Exception as e:
|
||||
self.logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}")
|
||||
response = {
|
||||
"success": False,
|
||||
"agent_id": agent_id,
|
||||
"error": str(e)
|
||||
}
|
||||
if task_id in self.results_futures:
|
||||
future = self.results_futures[task_id]
|
||||
if not future.done():
|
||||
future.set_result(response)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def submit_task(self, task_id: str, agent_id: str, task_event: dict):
|
||||
if not self.runners:
|
||||
await self.start()
|
||||
|
||||
future = asyncio.Future()
|
||||
self.results_futures[task_id] = future
|
||||
|
||||
task = {
|
||||
"task_id": task_id,
|
||||
"agent_id": agent_id,
|
||||
"task_event": task_event
|
||||
}
|
||||
await self.task_queue.put_async(task)
|
||||
self.logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。")
|
||||
|
||||
try:
|
||||
result = await future
|
||||
return result
|
||||
finally:
|
||||
self.results_futures.pop(task_id, None)
|
||||
|
||||
def get_cluster_metrics(self):
|
||||
return {
|
||||
"active_worker_count": len(self._active_workers),
|
||||
"max_capacity": self.max_capacity,
|
||||
"cached_agent_ids": list(self._active_workers.keys()),
|
||||
"queue_size": self.task_queue.size()
|
||||
}
|
||||
@@ -0,0 +1,82 @@
|
||||
{
|
||||
"name": "programme",
|
||||
"desc": "一个示范型的编程工作流",
|
||||
"work_link": [
|
||||
{
|
||||
"step": 1,
|
||||
"node": "consciousness_node",
|
||||
"action": "architect",
|
||||
"desc": "【人类规范】分析用户需求,构建程序整体架构,定义需要拉起的子个体名称与数量。"
|
||||
},
|
||||
{
|
||||
"step": 2,
|
||||
"node": "control_node",
|
||||
"action": "spawn_actors",
|
||||
"desc": "【人类规范】根据架构要求,拉起对应的开发与测试工作组,并挂载 /workspace 目录。"
|
||||
},
|
||||
{
|
||||
"step": 3,
|
||||
"node": "composite_individual",
|
||||
"action": "decompose",
|
||||
"desc": "【人类规范】将整体架构拆解为可独立执行的原子任务包 (Task Packets)。",
|
||||
"output": "task_packets"
|
||||
},
|
||||
{
|
||||
"step": 4,
|
||||
"node": "primary_individual",
|
||||
"action": "execute_code",
|
||||
"desc": "【人类规范】执行编码任务,必须确保所有代码写入指定的挂载目录。",
|
||||
"input": "task_packets",
|
||||
"output": "source_code"
|
||||
},
|
||||
{
|
||||
"step": 5,
|
||||
"node": "composite_individual",
|
||||
"action": "audit",
|
||||
"desc": "【人类规范】对产出的源码进行静态逻辑检查与 PEP8 代码规范审计。",
|
||||
"input": "source_code",
|
||||
"output": "audit_report"
|
||||
},
|
||||
{
|
||||
"step": 6,
|
||||
"node": "control_node",
|
||||
"action": "resource_recycle",
|
||||
"desc": "【安全规范】暂存当前编码子个体的状态,释放非必要显存,为测试环境腾出算力。",
|
||||
"input": "audit_report"
|
||||
},
|
||||
{
|
||||
"step": 7,
|
||||
"node": "consciousness_node",
|
||||
"action": "design_test",
|
||||
"desc": "【人类规范】基于源码设计测试用例架构,覆盖边缘场景。",
|
||||
"input": "source_code",
|
||||
"output": "test_spec"
|
||||
},
|
||||
{
|
||||
"step": 8,
|
||||
"node": "primary_individual",
|
||||
"action": "run_test",
|
||||
"desc": "【人类规范】在独立的 Docker 沙箱中运行 test,并生成结构化的实验报告。",
|
||||
"input": "test_spec",
|
||||
"output": "test_report"
|
||||
},
|
||||
{
|
||||
"step": 9,
|
||||
"node": "consciousness_node",
|
||||
"action": "analyze_report",
|
||||
"desc": "【逻辑网关】研究测试报告。如果存在 Error 或 Fail,必须触发逻辑跳转,重写代码。",
|
||||
"input": "test_report",
|
||||
"logic_gate": {
|
||||
"if_fail": "jump_to_step_4",
|
||||
"if_pass": "continue"
|
||||
}
|
||||
},
|
||||
{
|
||||
"step": 10,
|
||||
"node": "supervisory_node",
|
||||
"action": "terminate_workflow",
|
||||
"desc": "【系统规范】核对所有产出物,关闭工作流管道,向宿主机发送 .done 信号。",
|
||||
"input": ["source_code", "test_report"]
|
||||
}
|
||||
]
|
||||
}
|
||||
Reference in New Issue
Block a user