diff --git a/frontend/src/components/Plugin/PluginLayout.tsx b/frontend/src/components/Plugin/PluginLayout.tsx index afcb8fe..223b7da 100644 --- a/frontend/src/components/Plugin/PluginLayout.tsx +++ b/frontend/src/components/Plugin/PluginLayout.tsx @@ -1,6 +1,5 @@ import { SkillSettings } from './SkillSettings'; import { ToolSettings } from './ToolSettings'; -import { WorkflowTemplateSettings } from './WorkflowTemplateSettings'; interface PluginLayoutProps { resourceTab: string; @@ -20,14 +19,6 @@ export function PluginLayout({ resourceTab, setResourceTab }: PluginLayoutProps) > Skills - setResourceTab('workflow_template')} - className={`py-4 text-sm font-medium border-b-2 transition-colors ${ - resourceTab === 'workflow_template' ? 'border-blue-600 text-blue-600' : 'border-transparent text-slate-500 hover:text-slate-800' - }`} - > - Workflow Templates - setResourceTab('tool')} className={`py-4 text-sm font-medium border-b-2 transition-colors ${ @@ -41,7 +32,6 @@ export function PluginLayout({ resourceTab, setResourceTab }: PluginLayoutProps) {/* Main Content */} {resourceTab === 'skill' && } - {resourceTab === 'workflow_template' && } {resourceTab === 'tool' && } diff --git a/frontend/src/components/Plugin/WorkflowTemplateSettings.tsx b/frontend/src/components/Plugin/WorkflowTemplateSettings.tsx deleted file mode 100644 index eb00a0a..0000000 --- a/frontend/src/components/Plugin/WorkflowTemplateSettings.tsx +++ /dev/null @@ -1,163 +0,0 @@ -import { useState, useEffect } from 'react'; -import apiClient from '../../api/client'; -import { FileCode, Trash2, Plus, LayoutTemplate } from 'lucide-react'; - -import type { WorkflowTemplate as ParsedWorkflowTemplate } from '../../types'; - -export function WorkflowTemplateSettings() { - const [templates, setTemplates] = useState>({}); - const [loading, setLoading] = useState(true); - const [templateJson, setTemplateJson] = useState('{\n "name": "my_template",\n "steps": [\n {\n "name": "step1",\n "actor": "actor_name"\n }\n ]\n}'); - const [creating, setCreating] = useState(false); - const [message, setMessage] = useState(''); - const [error, setError] = useState(''); - - const validateTemplate = (data: any): data is ParsedWorkflowTemplate => { - if (!data || typeof data !== 'object') return false; - if (typeof data.name !== 'string') return false; - if (!Array.isArray(data.steps)) return false; - for (const step of data.steps) { - if (typeof step.name !== 'string') return false; - if (typeof step.actor !== 'string') return false; - } - return true; - }; - - const fetchTemplates = async () => { - setLoading(true); - try { - const response = await apiClient.get('/api/v1/resource/workflow_template'); - setTemplates(response.data.templates || {}); - } catch (err) { - console.error('Failed to fetch templates:', err); - } finally { - setLoading(false); - } - }; - - useEffect(() => { - fetchTemplates(); - }, []); - - const handleCreate = async (e: React.FormEvent) => { - e.preventDefault(); - setCreating(true); - setMessage(''); - setError(''); - - try { - const parsedJson = JSON.parse(templateJson); - - if (!validateTemplate(parsedJson)) { - throw new Error('JSON structure does not match WorkflowTemplate schema (requires name and steps array with name and actor).'); - } - - await apiClient.post('/api/v1/resource/workflow_template', parsedJson); - setMessage('Workflow template created successfully'); - setTemplateJson('{\n "name": "my_template",\n "steps": []\n}'); - fetchTemplates(); - } catch (err: any) { - console.error(err); - if (err instanceof SyntaxError) { - setError('Invalid JSON format'); - } else { - setError(err.message || err.response?.data?.message || 'Failed to create workflow template'); - } - } finally { - setCreating(false); - } - }; - - const handleDelete = async (templateName: string) => { - if (!confirm(`Are you sure you want to delete ${templateName}?`)) return; - try { - await apiClient.delete(`/api/v1/resource/workflow_template/${templateName}`); - fetchTemplates(); - } catch (err: any) { - console.error('Failed to delete template:', err); - alert('Failed to delete template'); - } - }; - - return ( - - - Workflow Templates - Manage and create reusable workflow templates. - - - - - - - - - Create Template - Provide the JSON definition for a new workflow template. - - - - - - Template JSON Definition - setTemplateJson(e.target.value)} - className="w-full px-4 py-2 border border-slate-200 rounded-lg focus:outline-none focus:ring-2 focus:ring-blue-500 font-mono text-sm" - /> - - - {message && {message}} - {error && {error}} - - - - - {creating ? 'Creating...' : 'Create Template'} - - - - - - - - - Available Templates - - - {loading ? ( - Loading templates... - ) : Object.keys(templates).length === 0 ? ( - No workflow templates created yet. - ) : ( - - {Object.keys(templates).map((name) => ( - - - - - - {name} - - handleDelete(name)} - className="p-2 text-slate-400 hover:text-red-500 hover:bg-red-50 rounded-lg transition-colors" - title="Delete Template" - > - - - - ))} - - )} - - - - ); -} diff --git a/main.py b/main.py index 7f9f894..0b6a24d 100644 --- a/main.py +++ b/main.py @@ -1,14 +1,14 @@ import asyncio import ray -from pretor.worker_individual.worker_cluster import WorkerCluster +from pretor.worker_cluster import WorkerCluster from pretor.utils.banner import print_banner -from pretor.core.database.postgres import PostgresDatabase -from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine -from pretor.core.global_state_machine.global_workflow_manager import GlobalWorkflowManager -from pretor.core.individual.supervisory_node.supervisory_node import SupervisoryNode -from pretor.core.individual.consciousness_node.consciousness_node import ConsciousnessNode -from pretor.core.individual.control_node.control_node import ControlNode -from pretor.core.workflow.workflow_runner import WorkflowRunningEngine +from pretor.core.postgres_database import PostgresDatabase +from pretor.core.global_state_machine import GlobalStateMachine +from pretor.core.global_workflow_manager import GlobalWorkflowManager +from pretor.core.individual.supervisory_node import SupervisoryNode +from pretor.core.individual.consciousness_node import ConsciousnessNode +from pretor.core.individual.control_node import ControlNode +from pretor.core.workflow_running_engine import WorkflowRunningEngine from pretor.api import PretorGateway from ray import serve import os @@ -18,7 +18,10 @@ _secret_key = os.getenv("SECRET_KEY") if not _secret_key or _secret_key in {"secret", "114514"}: _secret_key = secrets.token_urlsafe(32) os.environ["SECRET_KEY"] = _secret_key - print("⚠️ 警告: 未提供有效的 SECRET_KEY 或使用了不安全的默认值,已生成并设置随机密钥。") + print( + "⚠️ 警告: 未提供有效的 SECRET_KEY 或使用了不安全的默认值,已生成并设置随机密钥。" + ) + async def start_system(): env_vars = { @@ -30,21 +33,20 @@ async def start_system(): "SECRET_KEY": os.getenv("SECRET_KEY"), } - ray.init(ignore_reinit_error=True, - namespace="pretor", - dashboard_host="0.0.0.0", - dashboard_port=8265, - runtime_env={"env_vars": env_vars}) - + ray.init( + ignore_reinit_error=True, + namespace="pretor", + dashboard_host="0.0.0.0", + dashboard_port=8265, + runtime_env={"env_vars": env_vars}, + ) # 2. 启动数据库组件 - postgres_database = PostgresDatabase.options(name='postgres_database').remote() + postgres_database = PostgresDatabase.options(name="postgres_database").remote() await postgres_database.init_db.remote() global_state_machine = GlobalStateMachine.options( - name='global_state_machine', - namespace='pretor', - lifetime='detached' + name="global_state_machine", namespace="pretor", lifetime="detached" ).remote(postgres_database) print("正在等待 GlobalStateMachine 初始化并加载注册表...") @@ -58,29 +60,29 @@ async def start_system(): return global_workflow_manager = GlobalWorkflowManager.options( - name='global_workflow_manager', - namespace='pretor', - lifetime='detached' + name="global_workflow_manager", namespace="pretor", lifetime="detached" ).remote() # 4. 启动核心节点 - supervisory_node = SupervisoryNode.options(name='supervisory_node').remote() - consciousness_node = ConsciousnessNode.options(name='consciousness_node').remote() - control_node = ControlNode.options(name='control_node').remote() + supervisory_node = SupervisoryNode.options(name="supervisory_node").remote() + consciousness_node = ConsciousnessNode.options(name="consciousness_node").remote() + control_node = ControlNode.options(name="control_node").remote() try: WorkerCluster.options( name="worker_cluster", - lifetime="detached" # 保证它在后台一直运行 + lifetime="detached", # 保证它在后台一直运行 ).remote() print("✅ WorkerCluster 已成功启动并注册!") except ValueError: print("WorkerCluster 已经存在。") # 5. 启动工作流运行引擎 - workflow_engine = WorkflowRunningEngine.options(name='workflow_running_engine').remote( + workflow_engine = WorkflowRunningEngine.options( + name="workflow_running_engine" + ).remote( consciousness_node=consciousness_node, control_node=control_node, - supervisory_node=supervisory_node + supervisory_node=supervisory_node, ) # 异步拉起 runner 协程群 workflow_engine.run.remote() @@ -110,5 +112,5 @@ def main(): print("系统已退出。") -if __name__ == '__main__': - main() \ No newline at end of file +if __name__ == "__main__": + main() diff --git a/pretor/adapter/model_adapter/agent_factory.py b/pretor/adapter/model_adapter/agent_factory.py index ca3c9d4..8b5004e 100644 --- a/pretor/adapter/model_adapter/agent_factory.py +++ b/pretor/adapter/model_adapter/agent_factory.py @@ -22,22 +22,28 @@ from pretor.core.global_state_machine.model_provider import Provider from pretor.utils.agent_model import ResponseModel, DepsModel from pretor.utils.error import ModelNotExistError + class AgentFactory: """AgentFactory 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 AgentFactory 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - def __init__(self): - self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), - "claude": (AnthropicModel, AnthropicProvider), - "deepseek": (OpenAIChatModel, OpenAIProvider),} + 这是一个领域数据模型或功能封装类,承载了 AgentFactory 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" - def create_agent(self, - provider: Provider, - model_id: str, - output_type: ResponseModel, - system_prompt: str, - deps_type: DepsModel, - agent_name: str, - tools: list = None) -> Agent: + def __init__(self): + self._models_mapping = { + "openai": (OpenAIChatModel, OpenAIProvider), + "claude": (AnthropicModel, AnthropicProvider), + "deepseek": (OpenAIChatModel, OpenAIProvider), + } + + def create_agent( + self, + provider: Provider, + model_id: str, + output_type: ResponseModel, + system_prompt: str, + deps_type: DepsModel, + agent_name: str, + tools: list = None, + ) -> Agent: """ create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象 @@ -58,22 +64,30 @@ class AgentFactory: if provider.provider_type not in self._models_mapping: raise ValueError(f"不支持的协议类型: {provider.provider_type}") model_class, provider_class = self._models_mapping[provider.provider_type] - model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url)) + model = model_class( + model_id, + provider=provider_class( + api_key=provider.provider_apikey, base_url=provider.provider_url + ), + ) match provider.provider_type: case "deepseek": - agent = DeepSeekReasonerAgent(model=model, - name=agent_name, - output_type=output_type, - deps_type=deps_type, - system_prompt=system_prompt, - tools=tools, - retries=3, - ) + agent = DeepSeekReasonerAgent( + model=model, + name=agent_name, + output_type=output_type, + deps_type=deps_type, + system_prompt=system_prompt, + tools=tools, + retries=3, + ) case _: - agent = Agent(model=model, - name=agent_name, - system_prompt=system_prompt, - output_type=output_type, - deps_type=deps_type, - tools=tools) - return agent \ No newline at end of file + agent = Agent( + model=model, + name=agent_name, + system_prompt=system_prompt, + output_type=output_type, + deps_type=deps_type, + tools=tools, + ) + return agent diff --git a/pretor/adapter/model_adapter/deepseek_reasoner.py b/pretor/adapter/model_adapter/deepseek_reasoner.py index 520f566..6039969 100644 --- a/pretor/adapter/model_adapter/deepseek_reasoner.py +++ b/pretor/adapter/model_adapter/deepseek_reasoner.py @@ -18,25 +18,29 @@ from typing import Type, TypeVar, Any, Generic from pydantic import BaseModel, ValidationError from pydantic_ai import Agent -T = TypeVar('T', bound=BaseModel) +T = TypeVar("T", bound=BaseModel) + class AgentRunResultProxy: """AgentRunResultProxy 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 AgentRunResultProxy 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 AgentRunResultProxy 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + def __init__(self, original, parsed): self._original = original self._parsed = parsed + def __getattr__(self, name): """检索并获取特定的 getattr 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - if name == 'data': + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + if name == "data": return self._parsed - if name == 'output': + if name == "output": return self._parsed return getattr(self._original, name) + class DeepSeekReasonerAgent(Generic[T]): """ 专为 DeepSeek-V4/R1 设计的适配器。 @@ -44,15 +48,15 @@ class DeepSeekReasonerAgent(Generic[T]): """ def __init__( - self, - model, - name, - output_type: Any = str, - system_prompt: str = "", - deps_type: Type[Any] = None, - tools: list = None, - retries: int = 3, - **kwargs + self, + model, + name, + output_type: Any = str, + system_prompt: str = "", + deps_type: Type[Any] = None, + tools: list = None, + retries: int = 3, + **kwargs, ): self.output_schema = output_type self.has_custom_output = output_type is not str and output_type is not None @@ -63,6 +67,7 @@ class DeepSeekReasonerAgent(Generic[T]): if self.has_custom_output: try: from pydantic import TypeAdapter + schema_dict = TypeAdapter(self.output_schema).json_schema() schema_str = json.dumps(schema_dict, ensure_ascii=False) format_instruction = ( @@ -77,14 +82,14 @@ class DeepSeekReasonerAgent(Generic[T]): if self.tools: tool_descs = [] for t in self.tools: - desc = getattr(t, '__name__', str(t)) - if hasattr(t, '__doc__') and t.__doc__: + desc = getattr(t, "__name__", str(t)) + if hasattr(t, "__doc__") and t.__doc__: desc += f": {t.__doc__.strip()}" tool_descs.append(f"- {desc}") tool_instruction = ( "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。" - "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" + - "\n".join(tool_descs) + "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" + + "\n".join(tool_descs) ) self.agent = Agent( @@ -93,40 +98,41 @@ class DeepSeekReasonerAgent(Generic[T]): output_type=str, # Force native agent to return str to disable function calling system_prompt=system_prompt + format_instruction + tool_instruction, deps_type=deps_type, - **kwargs + **kwargs, ) def _parse_output(self, text: str) -> Any: """执行与 parse output 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: text (str): 控制逻辑流向的具体字符串参数,指定了期望的 text 内容。 - Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if not self.has_custom_output: return text - match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL) + match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL) json_str = match.group(1).strip() if match else text - if not json_str.startswith('{') and not json_str.startswith('['): - start_obj = json_str.find('{') - start_arr = json_str.find('[') + if not json_str.startswith("{") and not json_str.startswith("["): + start_obj = json_str.find("{") + start_arr = json_str.find("[") start = -1 end = -1 if start_obj != -1 and (start_arr == -1 or start_obj < start_arr): start = start_obj - end = json_str.rfind('}') + end = json_str.rfind("}") elif start_arr != -1: start = start_arr - end = json_str.rfind(']') + end = json_str.rfind("]") if start != -1 and end != -1 and end > start: - json_str = json_str[start:end+1] + json_str = json_str[start : end + 1] if not json_str: raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。") try: from pydantic import TypeAdapter + adapter = TypeAdapter(self.output_schema) return adapter.validate_json(json_str) except ValidationError as e: @@ -134,33 +140,35 @@ class DeepSeekReasonerAgent(Generic[T]): except json.JSONDecodeError as e: raise ValueError(f"返回的不是合法的 JSON:{e}") - def __getattr__(self, item): # Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent """检索并获取特定的 getattr 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: item: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return getattr(self.agent, item) - async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any: + async def run( + self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs + ) -> Any: # Custom retry loop """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 user prompt 内容。 deps (Any): 参与 run 逻辑运算或数据构建的上下文依赖对象。 message_history (list): 批量操作所需的列表集合,囊括了需要统一处理的多个 message history 元素。 - Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Any): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" current_history = message_history or [] last_exception = None for attempt in range(self.retries + 1): result = await self.agent.run( - user_prompt, - deps=deps, - message_history=current_history, - **kwargs + user_prompt, deps=deps, message_history=current_history, **kwargs ) - raw_text = result.data if hasattr(result, 'data') else getattr(result, 'output', str(result)) + raw_text = ( + result.data + if hasattr(result, "data") + else getattr(result, "output", str(result)) + ) try: parsed_data = self._parse_output(raw_text) @@ -171,11 +179,15 @@ class DeepSeekReasonerAgent(Generic[T]): except ValueError as e: last_exception = e # Prepare retry prompt - user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。" + user_prompt = ( + f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。" + ) # We need to maintain history manually so the model sees what it did wrong # Actually, pydantic-ai manages history inside the result. Let's use the all_messages from result - if hasattr(result, 'all_messages'): + if hasattr(result, "all_messages"): current_history = result.all_messages() - raise ValueError(f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}") + raise ValueError( + f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}" + ) diff --git a/pretor/api/__init__.py b/pretor/api/__init__.py index 1b47ece..d8c7629 100644 --- a/pretor/api/__init__.py +++ b/pretor/api/__init__.py @@ -28,9 +28,15 @@ from .provider import provider_router from .resource import resource_router from .workflow import workflow_router from pretor.utils.error import ( - DemandError, ModelNotExistError, UserError, - UserNotExistError, UserPasswordError, ProviderError, - ProviderNotExistError, WorkflowError, WorkflowExit + DemandError, + ModelNotExistError, + UserError, + UserNotExistError, + UserPasswordError, + ProviderError, + ProviderNotExistError, + WorkflowError, + WorkflowExit, ) app = FastAPI() @@ -43,6 +49,7 @@ app.include_router(cluster_router) # 集群信息路径 app.include_router(agent_router) # agent路径 app.include_router(workflow_router) # workflow路径 + @app.exception_handler(UserNotExistError) async def user_not_exist_handler(request: Request, exc: UserNotExistError): return JSONResponse(status_code=404, content={"message": "用户不存在"}) @@ -87,37 +94,47 @@ async def workflow_exit_handler(request: Request, exc: WorkflowExit): async def workflow_error_handler(request: Request, exc: WorkflowError): return JSONResponse(status_code=500, content={"message": "工作流执行错误"}) -base_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))) + +base_dir = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +) frontend_dir = os.path.join(base_dir, "frontend", "dist") if os.path.exists(frontend_dir): - app.mount("/assets", StaticFiles(directory=os.path.join(frontend_dir, "assets")), name="assets") - + app.mount( + "/assets", + StaticFiles(directory=os.path.join(frontend_dir, "assets")), + name="assets", + ) @app.get("/favicon.svg", include_in_schema=False) async def serve_favicon(): return FileResponse(os.path.join(frontend_dir, "favicon.svg")) - @app.get("/icons.svg", include_in_schema=False) async def serve_icons(): return FileResponse(os.path.join(frontend_dir, "icons.svg")) - @app.get("/{full_path:path}", include_in_schema=False) async def serve_frontend(full_path: str): # 【重要安全修复】避免拦截不存在的 API 路由。如果是调用了不存在的 /api/ 接口,直接返回 404,不返回前端页面 if full_path.startswith("api/"): - return JSONResponse(status_code=404, content={"detail": "API endpoint not found"}) + return JSONResponse( + status_code=404, content={"detail": "API endpoint not found"} + ) index_path = os.path.join(frontend_dir, "index.html") if os.path.exists(index_path): return FileResponse(index_path) - return JSONResponse(status_code=404, content={"detail": "Frontend build not found"}) + return JSONResponse( + status_code=404, content={"detail": "Frontend build not found"} + ) else: import logging - logging.getLogger("pretor").warning(f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount.") + logging.getLogger("pretor").warning( + f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount." + ) @serve.deployment @@ -126,4 +143,4 @@ class PretorGateway: gateway: Dict[str, WebSocket] def __init__(self): - self.gateway = {} \ No newline at end of file + self.gateway = {} diff --git a/pretor/api/agent.py b/pretor/api/agent.py index 72fd164..d8f7122 100644 --- a/pretor/api/agent.py +++ b/pretor/api/agent.py @@ -26,38 +26,48 @@ from pretor.core.database.table.user import UserAuthority agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"]) + class AgentRegister(BaseModel): """AgentRegister 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 AgentRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 AgentRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + provider_title: str model_id: str individual_name: str tools: Optional[List[str]] = None + class AgentLocalRegister(BaseModel): """AgentLocalRegister 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 AgentLocalRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 AgentLocalRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + path: str individual_name: str tools: Optional[List[str]] = None + @agent_router.get("") -async def get_system_nodes(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def get_system_nodes( + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +): """处理针对 get system nodes 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: _ (TokenData): 参与 get system nodes 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database configs = await postgres_database.get_all_system_node_configs.remote() return {"system_nodes": configs} + @agent_router.post("") -async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def load_agent( + agent_register: Union[AgentRegister, AgentLocalRegister], + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +): """处理针对 load agent 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: agent_register (Union[AgentRegister, AgentLocalRegister]): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 load agent 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine postgres_database = ray_actor_hook("postgres_database").postgres_database @@ -71,20 +81,35 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], agent_register.individual_name, agent_register.provider_title, agent_register.model_id, - agent_register.tools + agent_register.tools, ) # Load agent into state machine match agent_register.individual_name: case "supervisory_node": node = ray_actor_hook("supervisory_node").supervisory_node - await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) + await node.create_agent.remote( + global_state_machine, + agent_register.provider_title, + agent_register.model_id, + agent_register.tools, + ) case "consciousness_node": node = ray_actor_hook("consciousness_node").consciousness_node - await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) + await node.create_agent.remote( + global_state_machine, + agent_register.provider_title, + agent_register.model_id, + agent_register.tools, + ) case "control_node": node = ray_actor_hook("control_node").control_node - await node.create_agent.remote(global_state_machine,agent_register.provider_title,agent_register.model_id, agent_register.tools) + await node.create_agent.remote( + global_state_machine, + agent_register.provider_title, + agent_register.model_id, + agent_register.tools, + ) case _: pass except Exception as e: @@ -94,7 +119,8 @@ async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], class WorkerIndividualCreate(BaseModel): """WorkerIndividualCreate 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + agent_name: str agent_type: AgentType description: str @@ -109,7 +135,8 @@ class WorkerIndividualCreate(BaseModel): class WorkerIndividualUpdate(BaseModel): """WorkerIndividualUpdate 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + agent_name: Optional[str] = None agent_type: Optional[AgentType] = None description: Optional[str] = None @@ -123,63 +150,78 @@ class WorkerIndividualUpdate(BaseModel): @agent_router.post("/worker") -async def create_worker_individual(worker_data: WorkerIndividualCreate, - token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def create_worker_individual( + worker_data: WorkerIndividualCreate, + token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +): """处理针对 create worker individual 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: worker_data (WorkerIndividualCreate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database data_dict = worker_data.model_dump() data_dict["owner_id"] = token_data.user_id - worker = await postgres_database.add_worker_individual.remote( **data_dict) + worker = await postgres_database.add_worker_individual.remote(**data_dict) return {"message": "success", "agent_id": worker.agent_id} @agent_router.get("/worker") -async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)): +async def get_worker_individual_list( + token_data: TokenData = Depends(Accessor.get_current_user), +): """处理针对 get worker individual list 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - workers = await postgres_database.get_worker_individual_list.remote( owner_id=token_data.user_id) + workers = await postgres_database.get_worker_individual_list.remote( + owner_id=token_data.user_id + ) return {"workers": workers} @agent_router.get("/worker/{agent_id}") -async def get_worker_individual(agent_id: str, - token_data: TokenData = Depends(Accessor.get_current_user)): +async def get_worker_individual( + agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user) +): """处理针对 get worker individual 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: - raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this agent" + ) return worker @agent_router.put("/worker/{agent_id}") -async def update_worker_individual(agent_id: str, - worker_data: WorkerIndividualUpdate, - token_data: TokenData = Depends(Accessor.get_current_user)): +async def update_worker_individual( + agent_id: str, + worker_data: WorkerIndividualUpdate, + token_data: TokenData = Depends(Accessor.get_current_user), +): """处理针对 update worker individual 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 worker_data (WorkerIndividualUpdate): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: - raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this agent" + ) update_data = worker_data.model_dump(exclude_unset=True) - updated_worker = await postgres_database.update_worker_individual.remote( agent_id=agent_id, **update_data) + updated_worker = await postgres_database.update_worker_individual.remote( + agent_id=agent_id, **update_data + ) global_state_machine = ray_actor_hook("global_state_machine").global_state_machine try: @@ -189,18 +231,23 @@ async def update_worker_individual(agent_id: str, return {"message": "success", "worker": updated_worker} + @agent_router.post("/worker/{agent_id}/reload") -async def reload_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)): +async def reload_worker_individual( + agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user) +): """处理针对 reload worker individual 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: - raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this agent" + ) global_state_machine = ray_actor_hook("global_state_machine").global_state_machine await global_state_machine.remove_individual.remote(agent_id) @@ -209,17 +256,20 @@ async def reload_worker_individual(agent_id: str, token_data: TokenData = Depend @agent_router.delete("/worker/{agent_id}") -async def delete_worker_individual(agent_id: str, - token_data: TokenData = Depends(Accessor.get_current_user)): +async def delete_worker_individual( + agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user) +): """处理针对 delete worker individual 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - worker = await postgres_database.get_worker_individual.remote( agent_id=agent_id) + worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: - raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") - await postgres_database.delete_worker_individual.remote( agent_id=agent_id) - return {"message": "success"} \ No newline at end of file + raise HTTPException( + status_code=403, detail="Forbidden: You do not own this agent" + ) + await postgres_database.delete_worker_individual.remote(agent_id=agent_id) + return {"message": "success"} diff --git a/pretor/api/auth.py b/pretor/api/auth.py index d084233..061538a 100644 --- a/pretor/api/auth.py +++ b/pretor/api/auth.py @@ -24,79 +24,113 @@ from pretor.utils.error import UserNotExistError auth_router = APIRouter(prefix="/api/v1/auth", tags=["auth"]) + class UserRegister(BaseModel): """UserRegister 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 UserRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 UserRegister 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + user_name: str password: str + @auth_router.post("/register") async def create_user(user_register: UserRegister): """处理针对 create user 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: user_register (UserRegister): 参与 create user 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password) - user = await postgres_database.add_user.remote( user_register.user_name, hashed_password) + hashed_password = await run_in_threadpool( + Accessor.hash_password, user_register.password + ) + user = await postgres_database.add_user.remote( + user_register.user_name, hashed_password + ) return {"message": "success", "user_id": user.user_id} + class UserLogin(BaseModel): """UserLogin 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 UserLogin 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 UserLogin 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + user_name: str password: str + @auth_router.post("/login") async def login_user(user_login: UserLogin): """处理针对 login user 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: user_login (UserLogin): 参与 login user 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" postgres_database = ray_actor_hook("postgres_database").postgres_database - user = await postgres_database.login_user.remote( user_login.user_name) + user = await postgres_database.login_user.remote(user_login.user_name) if not user: raise UserNotExistError() - token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password) - return {"message":"success", "token":token} + token = await run_in_threadpool( + Accessor.login_hashed_password, user, user_login.password + ) + return {"message": "success", "token": token} + class ChangeAuthorityRequest(BaseModel): """ChangeAuthorityRequest 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ChangeAuthorityRequest 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ChangeAuthorityRequest 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + user_id: str new_authority: UserAuthority + @auth_router.put("/authority") async def change_authority( request: ChangeAuthorityRequest, - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) + _: TokenData = Depends( + RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR) + ), ): """ Update a user's authority level. Only accessible by SUPER_ADMINISTRATOR. """ postgres_database = ray_actor_hook("postgres_database").postgres_database - user = await postgres_database.change_user_authority.remote( user_id=request.user_id, new_authority=request.new_authority) - return {"message": "success", "user_id": user.user_id, "new_authority": user.user_authority} + user = await postgres_database.change_user_authority.remote( + user_id=request.user_id, new_authority=request.new_authority + ) + return { + "message": "success", + "user_id": user.user_id, + "new_authority": user.user_authority, + } + @auth_router.get("/list") async def get_user_list( - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) + _: TokenData = Depends( + RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR) + ), ): """ Get a list of all users. Only accessible by SUPER_ADMINISTRATOR. """ postgres_database = ray_actor_hook("postgres_database").postgres_database users = await postgres_database.get_all_users.remote() - return {"users": [{"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} for u in users]} + return { + "users": [ + {"user_id": u.user_id, "user_name": u.user_name, "role": u.user_authority} + for u in users + ] + } + @auth_router.delete("/{user_id}") async def delete_user( user_id: str, - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)) + _: TokenData = Depends( + RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR) + ), ): """ Delete a user. Only accessible by SUPER_ADMINISTRATOR. """ postgres_database = ray_actor_hook("postgres_database").postgres_database - await postgres_database.delete_user_by_id.remote( user_id=user_id) - return {"message": "success"} \ No newline at end of file + await postgres_database.delete_user_by_id.remote(user_id=user_id) + return {"message": "success"} diff --git a/pretor/api/cluster.py b/pretor/api/cluster.py index 9165c47..c9cb7dc 100644 --- a/pretor/api/cluster.py +++ b/pretor/api/cluster.py @@ -16,4 +16,4 @@ from fastapi import APIRouter cluster_router = APIRouter(prefix="/api/v1/cluster", tags=["cluster"]) -# Monitor websocket API temporarily removed \ No newline at end of file +# Monitor websocket API temporarily removed diff --git a/pretor/api/platform/__init__.py b/pretor/api/platform/__init__.py index 9f99733..afdace8 100644 --- a/pretor/api/platform/__init__.py +++ b/pretor/api/platform/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .frontend import client_router + __all__ = ["client_router"] diff --git a/pretor/api/platform/event.py b/pretor/api/platform/event.py index 8b0a68e..5af3fab 100644 --- a/pretor/api/platform/event.py +++ b/pretor/api/platform/event.py @@ -19,20 +19,34 @@ from typing import Any, Dict from pretor.core.workflow.workflow import PretorWorkflow import asyncio + class PretorEvent(BaseModel): """PretorEvent 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 PretorEvent 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 PretorEvent 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + model_config = ConfigDict(arbitrary_types_allowed=True) - trace_id: str = Field(default_factory=lambda: str(ULID()), description="事件的唯一标识符") + trace_id: str = Field( + default_factory=lambda: str(ULID()), description="事件的唯一标识符" + ) platform: str = Field(description="消息来源的平台") user_id: str = Field(description="用户id") user_name: str = Field(description="用户名") - create_time: str = Field(default_factory=lambda: str(datetime.datetime.now(datetime.timezone.utc).isoformat()), - description="事件创建时间") + create_time: str = Field( + default_factory=lambda: str( + datetime.datetime.now(datetime.timezone.utc).isoformat() + ), + description="事件创建时间", + ) message: str = Field(description="用户发来的消息") - attachment: Dict[str, str] | None = Field(default=None,description="附件") - #-------------------------------------------------------------------------------------------------------------- - context: Dict[str, Any] = Field(default_factory=dict, description="事件上下文内容,可包含工作流模板等信息") - workflow: PretorWorkflow | None = Field(default=None,description="工作流") - pending_queue: asyncio.Queue[str] | None= Field(default=None,description="待处理队列") - receive_queue: asyncio.Queue[str] | None = Field(default=None,description="待接收队列") + attachment: Dict[str, str] | None = Field(default=None, description="附件") + # -------------------------------------------------------------------------------------------------------------- + context: Dict[str, Any] = Field( + default_factory=dict, description="事件上下文内容,可包含工作流模板等信息" + ) + workflow: PretorWorkflow | None = Field(default=None, description="工作流") + pending_queue: asyncio.Queue[str] | None = Field( + default=None, description="待处理队列" + ) + receive_queue: asyncio.Queue[str] | None = Field( + default=None, description="待接收队列" + ) diff --git a/pretor/api/platform/frontend.py b/pretor/api/platform/frontend.py index 329f82b..2280828 100644 --- a/pretor/api/platform/frontend.py +++ b/pretor/api/platform/frontend.py @@ -22,45 +22,54 @@ import anyio from pretor.utils.logger import get_logger -logger = get_logger('frontend') +logger = get_logger("frontend") client_router = APIRouter(prefix="/api/v1/adapter/client", tags=["client"]) + class Message(BaseModel): """Message 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 Message 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 Message 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + message: str + @client_router.post("") -async def create_message(message: Message, - token_data: TokenData = Depends(Accessor.get_current_user)): +async def create_message( + message: Message, token_data: TokenData = Depends(Accessor.get_current_user) +): """处理针对 create message 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: message (Message): 参与 create message 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" logger.info("收到消息,来源:客户端") logger.debug(f"消息内容:{message.message}") - event = PretorEvent(platform="client", - user_id=str(token_data.user_id), - user_name=token_data.username, - message=message.message) + event = PretorEvent( + platform="client", + user_id=str(token_data.user_id), + user_name=token_data.username, + message=message.message, + ) supervisory_node = ray_actor_hook("supervisory_node").supervisory_node message = await supervisory_node.working.remote(event) if message.startswith("任务已创建"): return {"message": f"{event.trace_id}\n\n{message}"} elif message == "未知相应类型": raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail="模型回复错误") + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="模型回复错误" + ) else: return {"message": message} + @client_router.post("/upload") -async def upload_file(file: UploadFile = File(...), - token_data: TokenData = Depends(Accessor.get_current_user)): +async def upload_file( + file: UploadFile = File(...), + token_data: TokenData = Depends(Accessor.get_current_user), +): """处理针对 upload file 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: file (UploadFile): 参与 upload file 逻辑运算或数据构建的上下文依赖对象。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" try: upload_dir = "uploads" os.makedirs(upload_dir, exist_ok=True) @@ -69,7 +78,10 @@ async def upload_file(file: UploadFile = File(...), while chunk := await file.read(64 * 1024): # 64KB chunks await buffer.write(chunk) logger.info(f"用户 {token_data.username} 上传了文件: {file.filename}") - return {"filename": file.filename, "message": f"File {file.filename} uploaded successfully"} + return { + "filename": file.filename, + "message": f"File {file.filename} uploaded successfully", + } except Exception as e: logger.error(f"文件上传失败: {e}") raise HTTPException(status_code=500, detail="文件上传失败") diff --git a/pretor/api/provider.py b/pretor/api/provider.py index 6ec4087..3c1e4cf 100644 --- a/pretor/api/provider.py +++ b/pretor/api/provider.py @@ -24,45 +24,62 @@ from pretor.utils.ray_hook import ray_actor_hook provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"]) + class ProviderRegister(BaseModel): """ProviderRegister 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + provider_type: Literal["openai", "claude", "deepseek"] provider_title: str provider_url: str provider_apikey: str + @provider_router.post("") -async def create_provider(provider_register: ProviderRegister, - token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None: +async def create_provider( + provider_register: ProviderRegister, + token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +) -> None: """处理针对 create provider 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: provider_register (ProviderRegister): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_register 实例。 token_data (TokenData): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (None): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: (None): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type, - provider_title=provider_register.provider_title, - provider_url=provider_register.provider_url, - provider_apikey=provider_register.provider_apikey, - provider_owner=token_data.user_id) + await global_state_machine.add_provider_wrap.remote( + provider_type=provider_register.provider_type, + provider_title=provider_register.provider_title, + provider_url=provider_register.provider_url, + provider_apikey=provider_register.provider_apikey, + provider_owner=token_data.user_id, + ) @provider_router.get("/list") -async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Dict[str, Provider]]: +async def get_provider_list( + _: TokenData = Depends(Accessor.get_current_user), +) -> Dict[str, Dict[str, Provider]]: """处理针对 get provider list 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: _ (TokenData): 参与 get provider list 逻辑运算或数据构建的上下文依赖对象。 - Returns: (Dict[str, Dict[str, Provider]]): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: (Dict[str, Dict[str, Provider]]): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote() + provider_list: Dict[ + str, Provider + ] = await global_state_machine.get_provider_list.remote() return {"provider_list": provider_list} + @provider_router.delete("/{provider_title}") -async def delete_provider(provider_title: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))) -> dict: +async def delete_provider( + provider_title: str, + _: TokenData = Depends( + RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR) + ), +) -> dict: """处理针对 delete provider 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 _ (TokenData): 参与 delete provider 逻辑运算或数据构建的上下文依赖对象。 - Returns: (dict): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: (dict): 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine await global_state_machine.delete_provider.remote(provider_title=provider_title) - return {"message": "success"} \ No newline at end of file + return {"message": "success"} diff --git a/pretor/api/resource.py b/pretor/api/resource.py index 08bd491..02eaf52 100644 --- a/pretor/api/resource.py +++ b/pretor/api/resource.py @@ -14,7 +14,6 @@ from pydantic import BaseModel import viceroy -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate from pretor.utils.ray_hook import ray_actor_hook from fastapi import APIRouter, Depends from pretor.utils.access import TokenData @@ -23,97 +22,83 @@ from pretor.core.database.table.user import UserAuthority resource_router = APIRouter(prefix="/api/v1/resource") -@resource_router.post("/workflow_template") -async def create_workflow_template(workflow_template: WorkflowTemplate, - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): - """处理针对 create workflow template 相关的 HTTP API 请求。 - 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 - Args: workflow_template (WorkflowTemplate): 参与 create workflow template 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 create workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - await global_state_machine.add_workflow_template.remote( workflow_template.name, workflow_template) - return {"message": "创建成功"} - -@resource_router.get("/workflow_template") -async def get_workflow_templates(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): - """处理针对 get workflow templates 相关的 HTTP API 请求。 - 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 - Args: _ (TokenData): 参与 get workflow templates 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - templates = await global_state_machine.get_all_workflow_templates.remote() - return {"templates": templates} - -@resource_router.delete("/workflow_template/{template_name}") -async def delete_workflow_template(template_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))): - """处理针对 delete workflow template 相关的 HTTP API 请求。 - 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 - Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 _ (TokenData): 参与 delete workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - await global_state_machine.delete_workflow_template.remote( template_name) - return {"message": "success"} - - class Skill(BaseModel): """Skill 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 Skill 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 Skill 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + repo_url: str path: str | None + @resource_router.post("/skill") -async def install_skill(skill: Skill, - _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def install_skill( + skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)) +): """处理针对 install skill 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: skill (Skill): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。 _ (TokenData): 参与 install skill 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine # noinspection PyUnresolvedReferences import os - skill_output_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")) + + skill_output_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "plugin", "skill") + ) os.makedirs(skill_output_dir, exist_ok=True) - await viceroy.install_skill_async(url = skill.repo_url, - path = skill.path, - output = skill_output_dir) + await viceroy.install_skill_async( + url=skill.repo_url, path=skill.path, output=skill_output_dir + ) if skill.path: skill_name = skill.path.split("/")[-1] else: skill_name = skill.repo_url.split("/")[-1] - await global_state_machine.add_skill.remote( skill_name) + await global_state_machine.add_skill.remote(skill_name) return {"message": "创建成功"} + @resource_router.get("/skill") -async def get_skills(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def get_skills( + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +): """处理针对 get skills 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: _ (TokenData): 参与 get skills 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine skills = await global_state_machine.get_skill_list.remote() return {"skills": skills} + @resource_router.delete("/skill/{skill_name}") -async def delete_skill(skill_name: str, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR))): +async def delete_skill( + skill_name: str, + _: TokenData = Depends( + RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR) + ), +): """处理针对 delete skill 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 _ (TokenData): 参与 delete skill 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine # Note: this only removes it from the state machine manager. - await global_state_machine.remove_skill.remote( skill_name) + await global_state_machine.remove_skill.remote(skill_name) return {"message": "success"} + @resource_router.get("/tool") -async def get_tools(_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): +async def get_tools( + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)), +): """处理针对 get tools 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: _ (TokenData): 参与 get tools 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" global_state_machine = ray_actor_hook("global_state_machine").global_state_machine tool_mapper = await global_state_machine.get_tool_mapper.remote() all_tool_names = set() for scope_tools in tool_mapper.values(): all_tool_names.update(scope_tools.keys()) - return {"tools": list(all_tool_names)} \ No newline at end of file + return {"tools": list(all_tool_names)} diff --git a/pretor/api/workflow.py b/pretor/api/workflow.py index 8e2012a..5a1eab9 100644 --- a/pretor/api/workflow.py +++ b/pretor/api/workflow.py @@ -20,12 +20,15 @@ import asyncio workflow_router = APIRouter(prefix="/api/v1/workflow", tags=["workflow"]) + @workflow_router.get("/list") async def get_workflow_list(): """处理针对 get workflow list 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager events = await global_workflow_manager.list_events.remote() return events @@ -35,8 +38,10 @@ async def get_workflow_detail(trace_id: str): """处理针对 get workflow detail 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager event = await global_workflow_manager.get_event.remote(trace_id) if not event: raise HTTPException(status_code=404, detail="Workflow not found") @@ -55,15 +60,17 @@ async def get_workflow_detail(trace_id: str): steps = [] for step in workflow.work_link: - steps.append({ - "step": step.step, - "name": step.name, - "node": step.node, - "action": step.action, - "desc": step.desc, - "status": step.status, - "agent_id": step.agent_id, - }) + steps.append( + { + "step": step.step, + "name": step.name, + "node": step.node, + "action": step.action, + "desc": step.desc, + "status": step.status, + "agent_id": step.agent_id, + } + ) return { "event_id": trace_id, "workflow_title": workflow.title, @@ -76,17 +83,20 @@ async def get_workflow_detail(trace_id: str): "steps": steps, } + @workflow_router.get("/sse/{trace_id}") async def get_workflow_sse(trace_id: str, request: Request): """处理针对 get workflow sse 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager async def event_generator(): """执行与 event generator 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 """ + 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。""" try: while True: if await request.is_disconnected(): @@ -102,15 +112,17 @@ async def get_workflow_sse(trace_id: str, request: Request): return StreamingResponse(event_generator(), media_type="text/event-stream") + @workflow_router.post("/reply/{trace_id}") async def post_workflow_reply(trace_id: str, request: Request): """处理针对 post workflow reply 相关的 HTTP API 请求。 该接口负责解析前端传入的载荷数据,调用底层核心业务逻辑进行处理,并组装标准化的 JSON 响应。 Args: trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 - Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。 """ + Returns: : 序列化后的标准网络响应模型(如包含业务状态码、成功标志及对应的数据载荷 Data)。""" data = await request.json() reply_msg = data.get("message", "") - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager await global_workflow_manager.put_received.remote(trace_id, reply_msg) return {"status": "ok"} - diff --git a/pretor/core/database/database_exception.py b/pretor/core/database/database_exception.py index 9fd8322..7d7a2c6 100644 --- a/pretor/core/database/database_exception.py +++ b/pretor/core/database/database_exception.py @@ -17,16 +17,20 @@ from pydantic import ValidationError from pretor.utils.error import UserNotExistError from pretor.utils.logger import get_logger -logger = get_logger('database_exception') + +logger = get_logger("database_exception") + + def database_exception(func): """执行与 database exception 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: func: 参与 database exception 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + async def wrapper(*args, **kwargs): """执行与 wrapper 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: return await func(*args, **kwargs) except ValidationError as e: @@ -43,4 +47,5 @@ def database_exception(func): except Exception as e: logger.exception(f"未预期的数据库错误: {e}") raise e - return wrapper \ No newline at end of file + + return wrapper diff --git a/pretor/core/database/module/event.py b/pretor/core/database/module/event.py index ced4f8d..a258288 100644 --- a/pretor/core/database/module/event.py +++ b/pretor/core/database/module/event.py @@ -3,6 +3,7 @@ from typing import List, Optional from pretor.core.database.table.event import EventRecord from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession + class EventDatabase: def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]): self.async_session_maker = async_session_maker diff --git a/pretor/core/database/module/individual.py b/pretor/core/database/module/individual.py index 1fcb03a..57288d8 100644 --- a/pretor/core/database/module/individual.py +++ b/pretor/core/database/module/individual.py @@ -19,9 +19,11 @@ from pretor.core.database.database_exception import database_exception from ulid import ULID + class IndividualDatabase: """IndividualDatabase 核心组件类。 - 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。 """ + 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。""" + def __init__(self, async_session_maker): self.async_session_maker = async_session_maker @@ -29,7 +31,7 @@ class IndividualDatabase: async def add_worker_individual(self, **kwargs) -> WorkerIndividual: """创建并持久化新的 worker individual 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Returns: (WorkerIndividual): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (WorkerIndividual): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: agent_id = str(ULID()) individual = WorkerIndividual(agent_id=agent_id, **kwargs) @@ -43,9 +45,11 @@ class IndividualDatabase: """检索并获取特定的 worker individual 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: - statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + statement = select(WorkerIndividual).where( + WorkerIndividual.agent_id == agent_id + ) results = await session.execute(statement) return results.scalar_one_or_none() @@ -54,20 +58,26 @@ class IndividualDatabase: """检索并获取特定的 worker individual list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 owner 实例。 - Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。""" async with self.async_session_maker() as session: - statement = select(WorkerIndividual).where(WorkerIndividual.owner_id == owner_id) + statement = select(WorkerIndividual).where( + WorkerIndividual.owner_id == owner_id + ) results = await session.execute(statement) return list(results.scalars().all()) @database_exception - async def update_worker_individual(self, agent_id: str, **kwargs) -> Optional[WorkerIndividual]: + async def update_worker_individual( + self, agent_id: str, **kwargs + ) -> Optional[WorkerIndividual]: """对现有的 worker individual 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Optional[WorkerIndividual]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: - statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + statement = select(WorkerIndividual).where( + WorkerIndividual.agent_id == agent_id + ) results = await session.execute(statement) individual = results.scalar_one_or_none() if not individual: @@ -85,9 +95,11 @@ class IndividualDatabase: """安全地移除或注销 worker individual。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。 """ + Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。""" async with self.async_session_maker() as session: - statement = select(WorkerIndividual).where(WorkerIndividual.agent_id == agent_id) + statement = select(WorkerIndividual).where( + WorkerIndividual.agent_id == agent_id + ) results = await session.execute(statement) individual = results.scalar_one_or_none() if not individual: @@ -100,8 +112,8 @@ class IndividualDatabase: async def get_all_worker_individual(self) -> List[WorkerIndividual]: """检索并获取特定的 all worker individual 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[WorkerIndividual]): 经过筛选、排序或分页处理后的实体对象列表集合。""" async with self.async_session_maker() as session: statement = select(WorkerIndividual) results = await session.execute(statement) - return list(results.scalars().all()) \ No newline at end of file + return list(results.scalars().all()) diff --git a/pretor/core/database/module/provider.py b/pretor/core/database/module/provider.py index 6bb3297..01331af 100644 --- a/pretor/core/database/module/provider.py +++ b/pretor/core/database/module/provider.py @@ -18,9 +18,11 @@ from pretor.core.database.table.provider import Provider from sqlmodel import select from pretor.core.database.database_exception import database_exception + class ProviderDatabase: """ProviderDatabase 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + def __init__(self, async_session_maker): self.async_session_maker = async_session_maker @@ -28,23 +30,28 @@ class ProviderDatabase: async def get_provider(self) -> List[Provider]: """检索并获取特定的 provider 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: (List[Provider]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[Provider]): 经过筛选、排序或分页处理后的实体对象列表集合。""" async with self.async_session_maker() as session: statement = select(Provider) results = await session.execute(statement) results = results.scalars().all() - providers = [Provider(provider_title=provider.provider_title, - provider_url=provider.provider_url, - provider_apikey=provider.provider_apikey, - provider_models=provider.provider_models, - provider_type=provider.provider_type) for provider in results] + providers = [ + Provider( + provider_title=provider.provider_title, + provider_url=provider.provider_url, + provider_apikey=provider.provider_apikey, + provider_models=provider.provider_models, + provider_type=provider.provider_type, + ) + for provider in results + ] return providers @database_exception async def add_provider(self, **kwargs) -> None: """创建并持久化新的 provider 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: provider = Provider(**kwargs) session.add(provider) @@ -55,7 +62,7 @@ class ProviderDatabase: """安全地移除或注销 provider。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: provider = await session.get(Provider, provider_id) if provider is not None: @@ -67,7 +74,7 @@ class ProviderDatabase: """对现有的 provider 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: provider = await session.get(Provider, provider_id) if provider is not None: @@ -77,4 +84,4 @@ class ProviderDatabase: await session.commit() await session.refresh(provider) return provider - return None \ No newline at end of file + return None diff --git a/pretor/core/database/module/system_node.py b/pretor/core/database/module/system_node.py index 4a7a872..50d43a0 100644 --- a/pretor/core/database/module/system_node.py +++ b/pretor/core/database/module/system_node.py @@ -17,20 +17,30 @@ from sqlmodel import select from typing import List, Optional from pretor.core.database.database_exception import database_exception + class SystemNodeDatabase: """SystemNodeDatabase 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + def __init__(self, async_session_maker): self.async_session_maker = async_session_maker @database_exception - async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: Optional[List[str]] = None) -> SystemNodeConfig: + async def upsert_system_node_config( + self, + node_name: str, + provider_title: str, + model_id: str, + tools: Optional[List[str]] = None, + ) -> SystemNodeConfig: """执行与 upsert system node config 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 model 实例。 tools (Optional[List[str]]): 控制逻辑流向的具体字符串参数,指定了期望的 tools 内容。 - Returns: (SystemNodeConfig): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (SystemNodeConfig): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: - statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name) + statement = select(SystemNodeConfig).where( + SystemNodeConfig.node_name == node_name + ) results = await session.execute(statement) config = results.scalar_one_or_none() if config: @@ -39,7 +49,12 @@ class SystemNodeDatabase: if tools is not None: config.tools = tools else: - config = SystemNodeConfig(node_name=node_name, provider_title=provider_title, model_id=model_id, tools=tools) + config = SystemNodeConfig( + node_name=node_name, + provider_title=provider_title, + model_id=model_id, + tools=tools, + ) session.add(config) await session.commit() await session.refresh(config) @@ -49,19 +64,23 @@ class SystemNodeDatabase: async def get_all_system_node_configs(self) -> List[SystemNodeConfig]: """检索并获取特定的 all system node configs 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: (List[SystemNodeConfig]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[SystemNodeConfig]): 经过筛选、排序或分页处理后的实体对象列表集合。""" async with self.async_session_maker() as session: statement = select(SystemNodeConfig) results = await session.execute(statement) return list(results.scalars().all()) @database_exception - async def get_system_node_config(self, node_name: str) -> Optional[SystemNodeConfig]: + async def get_system_node_config( + self, node_name: str + ) -> Optional[SystemNodeConfig]: """检索并获取特定的 system node config 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (Optional[SystemNodeConfig]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Optional[SystemNodeConfig]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: - statement = select(SystemNodeConfig).where(SystemNodeConfig.node_name == node_name) + statement = select(SystemNodeConfig).where( + SystemNodeConfig.node_name == node_name + ) results = await session.execute(statement) return results.scalar_one_or_none() diff --git a/pretor/core/database/module/user.py b/pretor/core/database/module/user.py index 06463ed..3f98c13 100644 --- a/pretor/core/database/module/user.py +++ b/pretor/core/database/module/user.py @@ -19,9 +19,11 @@ from pretor.core.database.database_exception import database_exception from pretor.core.database.table.user import UserAuthority from pretor.utils.access import Accessor + class AuthDatabase: """AuthDatabase 核心组件类。 - 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。 """ + 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。""" + def __init__(self, async_session_maker): self.async_session_maker = async_session_maker @@ -30,8 +32,9 @@ class AuthDatabase: """创建并持久化新的 user 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。 - Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" from ulid import ULID + async with self.async_session_maker() as session: # Check if any users exist statement = select(User).limit(1) @@ -46,7 +49,7 @@ class AuthDatabase: user_id=str(ULID()), user_name=user_name, hashed_password=hashed_password, - user_authority=authority + user_authority=authority, ) session.add(user) await session.commit() @@ -58,7 +61,7 @@ class AuthDatabase: """执行与 change password 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 - Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (User): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: statement = select(User).where(User.user_name == user_name) results = await session.execute(statement) @@ -78,7 +81,7 @@ class AuthDatabase: """安全地移除或注销 user。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: statement = select(User).where(User.user_name == user_name) results = await session.execute(statement) @@ -93,7 +96,7 @@ class AuthDatabase: """安全地移除或注销 user by id。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: user = await session.get(User, user_id) if user is None: @@ -106,7 +109,7 @@ class AuthDatabase: """执行与 login user 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ + Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。""" async with self.async_session_maker() as session: statement = select(User).where(User.user_name == user_name) results = await session.execute(statement) @@ -119,7 +122,7 @@ class AuthDatabase: async def get_all_users(self) -> list[User]: """检索并获取特定的 all users 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: (list[User]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (list[User]): 经过筛选、排序或分页处理后的实体对象列表集合。""" async with self.async_session_maker() as session: statement = select(User) results = await session.execute(statement) @@ -131,7 +134,7 @@ class AuthDatabase: """检索并获取特定的 user authority 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 - Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" async with self.async_session_maker() as session: user = await session.get(User, user_id) if user is None: @@ -139,7 +142,9 @@ class AuthDatabase: return user.user_authority @database_exception - async def change_user_authority(self, user_id: str, new_authority: UserAuthority) -> User: + async def change_user_authority( + self, user_id: str, new_authority: UserAuthority + ) -> User: """ Changes the authority level of a specific user. @@ -161,4 +166,4 @@ class AuthDatabase: session.add(user) await session.commit() await session.refresh(user) - return user \ No newline at end of file + return user diff --git a/pretor/core/database/table/__init__.py b/pretor/core/database/table/__init__.py index 09008d9..9f6eda0 100644 --- a/pretor/core/database/table/__init__.py +++ b/pretor/core/database/table/__init__.py @@ -15,4 +15,5 @@ from pretor.core.database.table.user import User from pretor.core.database.table.provider import Provider from pretor.core.database.table.individual import WorkerIndividual + __all__ = ["User", "Provider", "WorkerIndividual"] diff --git a/pretor/core/database/table/event.py b/pretor/core/database/table/event.py index 2c5fdd9..f91d369 100644 --- a/pretor/core/database/table/event.py +++ b/pretor/core/database/table/event.py @@ -1,5 +1,8 @@ from sqlmodel import SQLModel, Field + class EventRecord(SQLModel, table=True): - trace_id: str = Field(primary_key=True, description="The unique trace ID of the PretorEvent") + trace_id: str = Field( + primary_key=True, description="The unique trace ID of the PretorEvent" + ) event_data_json: str = Field(description="The JSON serialized PretorEvent data") diff --git a/pretor/core/database/table/individual.py b/pretor/core/database/table/individual.py index 10f97c9..6c6849d 100644 --- a/pretor/core/database/table/individual.py +++ b/pretor/core/database/table/individual.py @@ -17,16 +17,20 @@ from typing import List, Optional from sqlalchemy import Column, JSON from enum import Enum + class AgentType(str, Enum): """AgentType 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 AgentType 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 AgentType 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + SKILL_INDIVIDUAL = "skill_individual" ORDINARY_INDIVIDUAL = "ordinary_individual" SPECIAL_INDIVIDUAL = "special_individual" + class WorkerIndividual(SQLModel, table=True): """WorkerIndividual 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + __tablename__ = "worker_individual" agent_id: str = Field(primary_key=True) agent_name: str = Field(index=True) @@ -35,8 +39,10 @@ class WorkerIndividual(SQLModel, table=True): provider_title: str model_id: str system_prompt: Optional[str] - output_template: Optional[dict] = Field(sa_column=Column(JSON),description="输出模板标识") + output_template: Optional[dict] = Field( + sa_column=Column(JSON), description="输出模板标识" + ) bound_skill: Optional[str] = Field(sa_column=Column(JSON)) workspace: Optional[List[str]] = Field(sa_column=Column(JSON)) tools: Optional[List[str]] = Field(sa_column=Column(JSON), default=None) - owner_id: str \ No newline at end of file + owner_id: str diff --git a/pretor/core/database/table/provider.py b/pretor/core/database/table/provider.py index 9f9ccc6..ad235f7 100644 --- a/pretor/core/database/table/provider.py +++ b/pretor/core/database/table/provider.py @@ -17,9 +17,11 @@ from typing import List from sqlalchemy import Column, JSON from typing import Optional + class Provider(SQLModel, table=True): """Provider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + __tablename__ = "provider" provider_id: str = Field(primary_key=True) provider_title: str = Field(index=True) @@ -31,4 +33,4 @@ class Provider(SQLModel, table=True): provider_models: List[str] = Field(sa_column=Column(JSON)) provider_owner: str - is_active: bool = Field(default=True, description="该服务商节点是否在线/启用") \ No newline at end of file + is_active: bool = Field(default=True, description="该服务商节点是否在线/启用") diff --git a/pretor/core/database/table/system_node.py b/pretor/core/database/table/system_node.py index 45789df..59e9a0f 100644 --- a/pretor/core/database/table/system_node.py +++ b/pretor/core/database/table/system_node.py @@ -17,9 +17,11 @@ from sqlmodel import SQLModel, Field from typing import List, Optional from sqlalchemy import Column, JSON + class SystemNodeConfig(SQLModel, table=True): """SystemNodeConfig 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + __tablename__ = "system_node_config" node_name: str = Field(primary_key=True) provider_title: str diff --git a/pretor/core/database/table/user.py b/pretor/core/database/table/user.py index 3a23c57..782d330 100644 --- a/pretor/core/database/table/user.py +++ b/pretor/core/database/table/user.py @@ -15,21 +15,24 @@ from sqlmodel import SQLModel, Field from enum import IntEnum + class UserAuthority(IntEnum): """UserAuthority 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 UserAuthority 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 UserAuthority 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + SUPER_ADMINISTRATOR = 100 ADMINISTRATOR = 50 USER = 20 UNAUTHORIZED_USER = 10 GUEST = 0 + class User(SQLModel, table=True): """User 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 User 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - __tablename__ = 'user' + 这是一个领域数据模型或功能封装类,承载了 User 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + __tablename__ = "user" user_id: str = Field(primary_key=True) user_name: str = Field(index=True) hashed_password: str user_authority: UserAuthority = Field(default=UserAuthority.USER) - diff --git a/pretor/core/global_state_machine/__init__.py b/pretor/core/global_state_machine/__init__.py index 5fa7362..2840edb 100644 --- a/pretor/core/global_state_machine/__init__.py +++ b/pretor/core/global_state_machine/__init__.py @@ -1,14 +1,3 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. +from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine +__all__ = ["GlobalStateMachine"] diff --git a/pretor/core/global_state_machine/global_state_machine.py b/pretor/core/global_state_machine/global_state_machine.py index aa89df6..c9c1e56 100644 --- a/pretor/core/global_state_machine/global_state_machine.py +++ b/pretor/core/global_state_machine/global_state_machine.py @@ -15,8 +15,7 @@ import ray from pretor.core.global_state_machine.provider_manager import ProviderManager from pretor.core.global_state_machine.tool_manager import GlobalToolManager -from pretor.core.database.postgres import PostgresDatabase -from pretor.core.workflow.workflow_template_manager import WorkflowManager +from pretor.core.postgres_database import PostgresDatabase from pretor.core.global_state_machine.skill_manager import GlobalSkillManager from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager @@ -24,17 +23,17 @@ from pretor.core.global_state_machine.individual_manager import GlobalIndividual @ray.remote class GlobalStateMachine: """GlobalStateMachine 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + def __init__(self, postgres_database: PostgresDatabase): import sys + print("GSM __init__ START", file=sys.stderr, flush=True) print(" event_dict done", file=sys.stderr, flush=True) self._global_provider_manager = ProviderManager(postgres_database) print(" provider_manager done", file=sys.stderr, flush=True) self._global_tool_manager = GlobalToolManager() print(" tool_manager done", file=sys.stderr, flush=True) - self._global_workflow_template_manager = WorkflowManager() - print(" workflow_template_manager done", file=sys.stderr, flush=True) self._global_skill_manager = GlobalSkillManager() print(" skill_manager done", file=sys.stderr, flush=True) self._global_individual_manager = GlobalIndividualManager() @@ -44,50 +43,63 @@ class GlobalStateMachine: async def init_state_machine(self): """完成 state machine 模块的启动与依赖初始化。 - 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 """ - await self._global_provider_manager.init_provider_register(self.postgres_database) - await self._global_individual_manager.init_individual_register(self.postgres_database) + 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。""" + await self._global_provider_manager.init_provider_register( + self.postgres_database + ) + await self._global_individual_manager.init_individual_register( + self.postgres_database + ) - async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner): + async def add_provider_wrap( + self, + provider_type, + provider_title, + provider_url, + provider_apikey, + provider_owner, + ): """创建并持久化新的 provider wrap 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return await self._global_provider_manager.add_provider( provider_type=provider_type, provider_title=provider_title, provider_url=provider_url, provider_apikey=provider_apikey, provider_owner=provider_owner, - postgres_database=self.postgres_database + postgres_database=self.postgres_database, ) # Provider Manager Methods def get_provider_list(self): """检索并获取特定的 provider list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_provider_manager.get_provider_list() def get_provider(self, provider_title): """检索并获取特定的 provider 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_provider_manager.get_provider(provider_title) async def delete_provider(self, provider_title: str): """安全地移除或注销 provider。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database) + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + return await self._global_provider_manager.delete_provider( + provider_title, self.postgres_database + ) # Tool Manager Methods def get_tool_mapper(self): """检索并获取特定的 tool mapper 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_tool_manager.tool_mapper def get_tool_list(self, agent_name: str): @@ -96,60 +108,32 @@ class GlobalStateMachine: """检索并获取特定的 tool list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" tools = self._global_tool_manager.tool_mapper.get(agent_name, {}) # also include default tools default_tools = self._global_tool_manager.tool_mapper.get("default", {}) merged_tools = {**default_tools, **tools} return merged_tools - # Workflow Template Manager Methods - def get_all_workflow_templates(self): - """检索并获取特定的 all workflow templates 数据集合或实例对象。 - 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return self._global_workflow_template_manager.get_all_workflow_templates() - - def add_workflow_template(self, template_name: str, workflow_template): - """创建并持久化新的 workflow template 实体。 - 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 workflow_template: 参与 add workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template) - - def delete_workflow_template(self, template_name: str): - """安全地移除或注销 workflow template。 - 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 - Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return self._global_workflow_template_manager.delete_workflow_template(template_name) - - def generate_workflow_template(self, workflow_template): - """执行与 generate workflow template 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Args: workflow_template: 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return self._global_workflow_template_manager.generate_workflow_template(workflow_template) - # Skill Manager Methods def add_skill(self, skill_name: str): """创建并持久化新的 skill 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_skill_manager.add_skill(skill_name) def get_skill_list(self): """检索并获取特定的 skill list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_skill_manager.get_skill_list() def remove_skill(self, skill_name: str): """安全地移除或注销 skill。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_skill_manager.remove_skill(skill_name) # Individual Manager Methods @@ -157,26 +141,25 @@ class GlobalStateMachine: """创建并持久化新的 individual 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 config: 驱动该模块运行的核心配置字典或 Pydantic 数据模型,定义了重试策略、超时时间及模型参数等选项。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_individual_manager.add_individual(agent_id, config) def get_individual(self, agent_id: str): """检索并获取特定的 individual 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_individual_manager.get_individual(agent_id) def remove_individual(self, agent_id: str): """安全地移除或注销 individual。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_individual_manager.remove_individual(agent_id) def list_individuals(self): """执行与 list individuals 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self._global_individual_manager.list_individuals() - diff --git a/pretor/core/global_state_machine/individual_manager.py b/pretor/core/global_state_machine/individual_manager.py index 94f9a9f..2e52068 100644 --- a/pretor/core/global_state_machine/individual_manager.py +++ b/pretor/core/global_state_machine/individual_manager.py @@ -14,11 +14,14 @@ from typing import Dict, Any from pretor.utils.logger import get_logger -logger = get_logger('individual_manager') + +logger = get_logger("individual_manager") + class GlobalIndividualManager: """GlobalIndividualManager 核心组件类。 - 这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ + 这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。""" + def __init__(self): self._individuals: Dict[str, Dict[str, Any]] = {} @@ -26,21 +29,31 @@ class GlobalIndividualManager: """完成 individual register 模块的启动与依赖初始化。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 Args: postgres: 参与 init individual register 逻辑运算或数据构建的上下文依赖对象。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: try: individuals = await postgres.get_all_worker_individual.remote() for ind in individuals: - agent_id = getattr(ind, 'agent_id', None) + agent_id = getattr(ind, "agent_id", None) if agent_id: - self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind) - logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。") + self._individuals[agent_id] = ( + ind.model_dump() + if hasattr(ind, "model_dump") + else dict(ind) + ) + logger.info( + f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。" + ) except AttributeError: - logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。") + logger.warning( + "数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。" + ) except Exception as e: # 捕获因 Ray 调用目标方法不存在引发的异常 if "has no attribute 'get_all_worker_individual'" in str(e): - logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。") + logger.warning( + "数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。" + ) else: raise e except Exception as e: @@ -64,12 +77,12 @@ class GlobalIndividualManager: """安全地移除或注销 individual。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if agent_id in self._individuals: del self._individuals[agent_id] def list_individuals(self) -> Dict[str, Dict[str, Any]]: """执行与 list individuals 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ + Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。""" return self._individuals diff --git a/pretor/core/global_state_machine/model_provider/__init__.py b/pretor/core/global_state_machine/model_provider/__init__.py index 5cb81a3..48769ca 100644 --- a/pretor/core/global_state_machine/model_provider/__init__.py +++ b/pretor/core/global_state_machine/model_provider/__init__.py @@ -12,8 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs -from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider -from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider -from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider -__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"] +from pretor.core.global_state_machine.model_provider.base_provider import ( + Provider, + ProviderArgs, +) +from pretor.core.global_state_machine.model_provider.openai_provider import ( + OpenAIProvider, +) +from pretor.core.global_state_machine.model_provider.claude_provider import ( + ClaudeProvider, +) +from pretor.core.global_state_machine.model_provider.deepseek_provider import ( + DeepseekProvider, +) + +__all__ = [ + "Provider", + "ProviderArgs", + "OpenAIProvider", + "ClaudeProvider", + "DeepseekProvider", +] diff --git a/pretor/core/global_state_machine/model_provider/base_provider.py b/pretor/core/global_state_machine/model_provider/base_provider.py index 4383522..5355956 100644 --- a/pretor/core/global_state_machine/model_provider/base_provider.py +++ b/pretor/core/global_state_machine/model_provider/base_provider.py @@ -17,15 +17,19 @@ from pydantic import BaseModel from typing import List from enum import Enum + class ProviderStatus(str, Enum): """ProviderStatus 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + UP = "up" DOWN = "down" + class Provider(BaseModel): """Provider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + provider_title: str provider_url: str provider_apikey: str @@ -34,17 +38,21 @@ class Provider(BaseModel): provider_owner: str | None = None provider_status: ProviderStatus = ProviderStatus.UP + class ProviderArgs(BaseModel): """ProviderArgs 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + provider_title: str provider_url: str provider_apikey: str provider_owner: str + class BaseProvider(ABC): """BaseProvider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + @staticmethod @abstractmethod async def create_provider(provider_args: ProviderArgs) -> Provider: @@ -83,7 +91,9 @@ class BaseProvider(ABC): @staticmethod @abstractmethod - def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: + def _return_provider( + provider_args: ProviderArgs, provider_models: List[str] + ) -> Provider: """ 包装Provider对象并返回 将provider_args和_load_models获取的方法包装为provider对象 @@ -100,5 +110,3 @@ class BaseProvider(ABC): 返回一个Provider对象 """ pass - - diff --git a/pretor/core/global_state_machine/model_provider/claude_provider.py b/pretor/core/global_state_machine/model_provider/claude_provider.py index 77159e4..27d32c0 100644 --- a/pretor/core/global_state_machine/model_provider/claude_provider.py +++ b/pretor/core/global_state_machine/model_provider/claude_provider.py @@ -14,21 +14,29 @@ from pretor.utils.retry import retry_on_retryable_error -from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs +from pretor.core.global_state_machine.model_provider.base_provider import ( + BaseProvider, + Provider, + ProviderArgs, +) import httpx from typing import List + class ClaudeProvider(BaseProvider): """ClaudeProvider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + @staticmethod async def create_provider(provider_args: ProviderArgs) -> Provider: """创建并持久化新的 provider 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" provider_models: List[str] = await ClaudeProvider._load_models(provider_args) - provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models) + provider: Provider = ClaudeProvider._return_provider( + provider_args, provider_models + ) return provider @staticmethod @@ -38,11 +46,11 @@ class ClaudeProvider(BaseProvider): """执行与 load models 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。""" headers = { "x-api-key": provider_args.provider_apikey, "anthropic-version": "2023-06-01", - "Content-Type": "application/json" + "Content-Type": "application/json", } # 如果是官方 API,通常使用 /v1/models (如果支持) # 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型 @@ -57,19 +65,27 @@ class ClaudeProvider(BaseProvider): return sorted(model_ids) else: # 如果官方列表接口不可用,fallback 到已知常用模型 - return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"] + return [ + "claude-3-5-sonnet-20240620", + "claude-3-opus-20240229", + "claude-3-haiku-20240307", + ] except Exception as e: print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}") return [] @staticmethod - def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: + def _return_provider( + provider_args: ProviderArgs, provider_models: List[str] + ) -> Provider: """执行与 return provider 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return Provider(provider_title=provider_args.provider_title, - provider_apikey=provider_args.provider_apikey, - provider_url=provider_args.provider_url, - provider_models=provider_models, - provider_type="claude") \ No newline at end of file + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + return Provider( + provider_title=provider_args.provider_title, + provider_apikey=provider_args.provider_apikey, + provider_url=provider_args.provider_url, + provider_models=provider_models, + provider_type="claude", + ) diff --git a/pretor/core/global_state_machine/model_provider/deepseek_provider.py b/pretor/core/global_state_machine/model_provider/deepseek_provider.py index 3285809..876470f 100644 --- a/pretor/core/global_state_machine/model_provider/deepseek_provider.py +++ b/pretor/core/global_state_machine/model_provider/deepseek_provider.py @@ -13,21 +13,29 @@ # limitations under the License. from pretor.utils.retry import retry_on_retryable_error -from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs +from pretor.core.global_state_machine.model_provider.base_provider import ( + BaseProvider, + Provider, + ProviderArgs, +) import httpx from typing import List + class DeepseekProvider(BaseProvider): """DeepseekProvider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + @staticmethod async def create_provider(provider_args: ProviderArgs) -> Provider: """创建并持久化新的 provider 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" provider_models: List[str] = await DeepseekProvider._load_models(provider_args) - provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models) + provider: Provider = DeepseekProvider._return_provider( + provider_args, provider_models + ) return provider @staticmethod @@ -36,17 +44,23 @@ class DeepseekProvider(BaseProvider): """执行与 load models 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。""" headers = { "Authorization": f"Bearer {provider_args.provider_apikey}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models" + url = ( + f"{provider_args.provider_url}/models" + if "/v1" in provider_args.provider_url + else f"{provider_args.provider_url}/v1/models" + ) try: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get(url, headers=headers) if response.status_code != 200: - print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}") + print( + f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}" + ) return [] data = response.json() raw_models = data.get("data", []) @@ -54,20 +68,27 @@ class DeepseekProvider(BaseProvider): return sorted(model_ids) except httpx.RequestError as e: from pretor.utils.error import RetryableError + print(f"[{provider_args.provider_title}] 网络请求异常: {e}") - raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e + raise RetryableError( + f"[{provider_args.provider_title}] 网络请求异常: {e}" + ) from e except Exception as e: print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") return [] @staticmethod - def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: + def _return_provider( + provider_args: ProviderArgs, provider_models: List[str] + ) -> Provider: """执行与 return provider 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return Provider(provider_title=provider_args.provider_title, - provider_apikey=provider_args.provider_apikey, - provider_url=provider_args.provider_url, - provider_models=provider_models, - provider_type="deepseek") \ No newline at end of file + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + return Provider( + provider_title=provider_args.provider_title, + provider_apikey=provider_args.provider_apikey, + provider_url=provider_args.provider_url, + provider_models=provider_models, + provider_type="deepseek", + ) diff --git a/pretor/core/global_state_machine/model_provider/openai_provider.py b/pretor/core/global_state_machine/model_provider/openai_provider.py index a38a77f..fdf2e48 100644 --- a/pretor/core/global_state_machine/model_provider/openai_provider.py +++ b/pretor/core/global_state_machine/model_provider/openai_provider.py @@ -13,21 +13,29 @@ # limitations under the License. from pretor.utils.retry import retry_on_retryable_error -from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs +from pretor.core.global_state_machine.model_provider.base_provider import ( + BaseProvider, + Provider, + ProviderArgs, +) import httpx from typing import List + class OpenAIProvider(BaseProvider): """OpenAIProvider 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + @staticmethod async def create_provider(provider_args: ProviderArgs) -> Provider: """创建并持久化新的 provider 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" provider_models: List[str] = await OpenAIProvider._load_models(provider_args) - provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models) + provider: Provider = OpenAIProvider._return_provider( + provider_args, provider_models + ) return provider @staticmethod @@ -36,17 +44,23 @@ class OpenAIProvider(BaseProvider): """执行与 load models 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 - Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。""" headers = { "Authorization": f"Bearer {provider_args.provider_apikey}", - "Content-Type": "application/json" + "Content-Type": "application/json", } - url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models" + url = ( + f"{provider_args.provider_url}/models" + if "/v1" in provider_args.provider_url + else f"{provider_args.provider_url}/v1/models" + ) try: async with httpx.AsyncClient(timeout=10.0) as client: response = await client.get(url, headers=headers) if response.status_code != 200: - print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}") + print( + f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}" + ) return [] data = response.json() raw_models = data.get("data", []) @@ -54,20 +68,27 @@ class OpenAIProvider(BaseProvider): return sorted(model_ids) except httpx.RequestError as e: from pretor.utils.error import RetryableError + print(f"[{provider_args.provider_title}] 网络请求异常: {e}") - raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e + raise RetryableError( + f"[{provider_args.provider_title}] 网络请求异常: {e}" + ) from e except Exception as e: print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") return [] @staticmethod - def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: + def _return_provider( + provider_args: ProviderArgs, provider_models: List[str] + ) -> Provider: """执行与 return provider 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 - Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - return Provider(provider_title=provider_args.provider_title, - provider_apikey=provider_args.provider_apikey, - provider_url=provider_args.provider_url, - provider_models=provider_models, - provider_type="openai") \ No newline at end of file + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + return Provider( + provider_title=provider_args.provider_title, + provider_apikey=provider_args.provider_apikey, + provider_url=provider_args.provider_url, + provider_models=provider_models, + provider_type="openai", + ) diff --git a/pretor/core/global_state_machine/provider_manager.py b/pretor/core/global_state_machine/provider_manager.py index b262f14..513e73a 100644 --- a/pretor/core/global_state_machine/provider_manager.py +++ b/pretor/core/global_state_machine/provider_manager.py @@ -12,51 +12,73 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider +from pretor.core.global_state_machine.model_provider import ( + Provider, + OpenAIProvider, + ClaudeProvider, + DeepseekProvider, +) from typing import Dict, Type + class ProviderManager: """ 模型供应商管理器 (ProviderManager)。 负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。 """ + # --- 类属性显式标注 (IDE 友好) --- provider_mapper: Dict[str, Type[Provider]] """协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。""" provider_register: Dict[str, Provider] """供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。""" + def __init__(self, postgres): - self.provider_mapper = {"openai": OpenAIProvider, - "claude": ClaudeProvider, - "deepseek": DeepseekProvider} + self.provider_mapper = { + "openai": OpenAIProvider, + "claude": ClaudeProvider, + "deepseek": DeepseekProvider, + } self.provider_register = {} async def init_provider_register(self, postgres) -> None: """完成 provider register 模块的启动与依赖初始化。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 Args: postgres: 参与 init provider register 逻辑运算或数据构建的上下文依赖对象。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" providers = await postgres.get_provider.remote() for provider in providers: self.provider_register[provider.provider_title] = provider - async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None: + async def add_provider( + self, + provider_type, + provider_title, + provider_url, + provider_apikey, + provider_owner, + postgres_database, + ) -> None: """创建并持久化新的 provider 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" from pretor.core.global_state_machine.model_provider import ProviderArgs from pretor.utils.logger import get_logger - logger = get_logger('provider_manager') + + logger = get_logger("provider_manager") import httpx - provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, - provider_url=provider_url, - provider_apikey=provider_apikey, - provider_owner=provider_owner) + provider_args: ProviderArgs = ProviderArgs( + provider_title=provider_title, + provider_url=provider_url, + provider_apikey=provider_apikey, + provider_owner=provider_owner, + ) try: import ulid + provider_class = self.provider_mapper.get(provider_type, None) if provider_class is None: logger.warning(f"Provider type {provider_type} is not supported.") @@ -65,41 +87,49 @@ class ProviderManager: provider.provider_owner = provider_owner self.provider_register[provider_title] = provider await postgres_database.add_provider_db.remote( - provider_id=str(ulid.ULID()), - provider_title=provider.provider_title, - provider_url=provider.provider_url, - provider_apikey=provider.provider_apikey, - provider_models=provider.provider_models, - provider_type=provider.provider_type, - provider_owner=provider.provider_owner) + provider_id=str(ulid.ULID()), + provider_title=provider.provider_title, + provider_url=provider.provider_url, + provider_apikey=provider.provider_apikey, + provider_models=provider.provider_models, + provider_type=provider.provider_type, + provider_owner=provider.provider_owner, + ) logger.info(f"已添加适配器{provider_title}") except httpx.RequestError as e: from pretor.utils.error import RetryableError + logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}") - raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e + raise RetryableError( + f"[{provider_args.provider_title}] 网络请求异常: {e}" + ) from e except Exception as e: - logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") + logger.warning( + f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}" + ) def get_provider_list(self): """检索并获取特定的 provider list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self.provider_register def get_provider(self, provider_title): """检索并获取特定的 provider 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return self.provider_register.get(provider_title) async def delete_provider(self, provider_title: str, postgres_database) -> None: """安全地移除或注销 provider。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if provider_title in self.provider_register: provider = self.provider_register[provider_title] - await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id) - del self.provider_register[provider_title] \ No newline at end of file + await postgres_database.delete_provider_db.remote( + provider_id=provider.provider_id + ) + del self.provider_register[provider_title] diff --git a/pretor/core/global_state_machine/skill_manager.py b/pretor/core/global_state_machine/skill_manager.py index 85a556c..b5bbecc 100644 --- a/pretor/core/global_state_machine/skill_manager.py +++ b/pretor/core/global_state_machine/skill_manager.py @@ -17,22 +17,29 @@ from collections import defaultdict import pathlib import json + class GlobalSkillManager: """GlobalSkillManager 核心组件类。 - 这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ - skill_mapper = Dict[str,Tuple[str]] + 这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。""" + + skill_mapper = Dict[str, Tuple[str]] """skill的存储表""" def __init__(self): self.skill_mapper = defaultdict(tuple) import os - skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill"))) + + skill_plugin_dir = pathlib.Path( + os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill") + ) + ) if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir(): return for item in skill_plugin_dir.iterdir(): if item.is_dir() and not item.name.startswith((".", "__")): - json_path = item / "skill.json" # 拼接文件路径 + json_path = item / "skill.json" # 拼接文件路径 if json_path.exists(): try: with open(json_path, "r", encoding="utf-8") as f: @@ -42,7 +49,7 @@ class GlobalSkillManager: if name: self.skill_mapper[name] = ( skill.get("description", ""), - skill.get("instructions", "") + skill.get("instructions", ""), ) except (json.JSONDecodeError, OSError) as e: print(f"警告: 加载插件 {item.name} 失败: {e}") @@ -50,7 +57,12 @@ class GlobalSkillManager: def add_skill(self, skill_name: str) -> None: """Add a skill to the manager by reading its skill.json from the path""" import os - skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill"))) + + skill_plugin_dir = pathlib.Path( + os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill") + ) + ) item = skill_plugin_dir / skill_name if item.is_dir() and not item.name.startswith((".", "__")): json_path = item / "skill.json" @@ -62,7 +74,7 @@ class GlobalSkillManager: if name: self.skill_mapper[name] = ( skill.get("description", ""), - skill.get("instructions", "") + skill.get("instructions", ""), ) except (json.JSONDecodeError, OSError) as e: print(f"警告: 加载插件 {item.name} 失败: {e}") diff --git a/pretor/core/global_state_machine/tool_manager.py b/pretor/core/global_state_machine/tool_manager.py index 4c7b82e..8a3a9a7 100644 --- a/pretor/core/global_state_machine/tool_manager.py +++ b/pretor/core/global_state_machine/tool_manager.py @@ -19,17 +19,22 @@ from collections import defaultdict from pretor.plugin.tool_plugin.base_tool import BaseToolData from typing import Dict, Type from pretor.utils.logger import get_logger -logger = get_logger('tool_manager') + +logger = get_logger("tool_manager") + class GlobalToolManager: """GlobalToolManager 核心组件类。 - 这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ + 这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。""" + tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]] def __init__(self): self.tool_mapper = defaultdict(dict) - tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin" + tool_plugin_dir = ( + pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin" + ) if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir(): return @@ -51,4 +56,4 @@ class GlobalToolManager: for scope in action_scopes: self.tool_mapper[scope][plugin_name] = obj except Exception as e: - logger.warning(f"Failed to load tool plugin {plugin_name}: {e}") \ No newline at end of file + logger.warning(f"Failed to load tool plugin {plugin_name}: {e}") diff --git a/pretor/core/global_workflow_manager/__init__.py b/pretor/core/global_workflow_manager/__init__.py new file mode 100644 index 0000000..753739f --- /dev/null +++ b/pretor/core/global_workflow_manager/__init__.py @@ -0,0 +1,5 @@ +from pretor.core.global_workflow_manager.global_workflow_manager import ( + GlobalWorkflowManager, +) + +__all__ = ["GlobalWorkflowManager"] diff --git a/pretor/core/global_state_machine/global_workflow_manager.py b/pretor/core/global_workflow_manager/global_workflow_manager.py similarity index 77% rename from pretor/core/global_state_machine/global_workflow_manager.py rename to pretor/core/global_workflow_manager/global_workflow_manager.py index 38fb881..17f876c 100644 --- a/pretor/core/global_state_machine/global_workflow_manager.py +++ b/pretor/core/global_workflow_manager/global_workflow_manager.py @@ -6,6 +6,7 @@ from pretor.core.workflow.workflow import PretorWorkflow from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.logger import get_logger + @ray.remote class GlobalWorkflowManager: def __init__(self): @@ -31,7 +32,9 @@ class GlobalWorkflowManager: event_copy = event.model_copy() event_copy.pending_queue = None event_copy.receive_queue = None - self.event_object_refs[event.trace_id] = ray.put(event_copy.model_dump_json()) + self.event_object_refs[event.trace_id] = ray.put( + event_copy.model_dump_json() + ) except Exception as e: self.logger.error(f"Failed to load event {record.trace_id}: {e}") @@ -40,13 +43,22 @@ class GlobalWorkflowManager: # Trigger resumption of incomplete workflows workflow_running_engine = None for trace_id, event in self.event_dict.items(): - if event.workflow and event.workflow.status.status in ["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"]: + if event.workflow and event.workflow.status.status in [ + "waiting_llm_working", + "waiting_tool_working", + "llm_working", + "tool_working", + ]: self.logger.info(f"Resuming incomplete workflow {trace_id}") if not workflow_running_engine: try: - workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine + workflow_running_engine = ray_actor_hook( + "workflow_running_engine" + ).workflow_running_engine except AttributeError: - self.logger.warning("workflow_running_engine not found, cannot resume workflow") + self.logger.warning( + "workflow_running_engine not found, cannot resume workflow" + ) break await workflow_running_engine.resume_workflow.remote(event) @@ -64,12 +76,11 @@ class GlobalWorkflowManager: # Update cache self.event_object_refs[event.trace_id] = ray.put(event_json) - await self.postgres_database.upsert_event.remote( - event.trace_id, - event_json - ) + await self.postgres_database.upsert_event.remote(event.trace_id, event_json) except Exception as e: - self.logger.error(f"Failed to upsert event {event.trace_id} to database: {e}") + self.logger.error( + f"Failed to upsert event {event.trace_id} to database: {e}" + ) async def add_event(self, event: PretorEvent) -> None: event.pending_queue = asyncio.Queue() @@ -98,7 +109,9 @@ class GlobalWorkflowManager: event_json = ray.get(self.event_object_refs[trace_id]) return PretorEvent.model_validate_json(event_json) except Exception as e: - self.logger.warning(f"Failed to fetch event from cache for trace {trace_id}: {e}") + self.logger.warning( + f"Failed to fetch event from cache for trace {trace_id}: {e}" + ) # Fallback to database try: @@ -119,11 +132,15 @@ class GlobalWorkflowManager: return event except Exception as e: - self.logger.error(f"Failed to fetch event {trace_id} from database fallback: {e}") + self.logger.error( + f"Failed to fetch event {trace_id} from database fallback: {e}" + ) return None - async def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None: + async def update_attachment( + self, trace_id: str, attachment: Dict[str, str] + ) -> None: if trace_id in self.event_dict: self.event_dict[trace_id].attachment = attachment await self._upsert_event_to_db(self.event_dict[trace_id]) @@ -148,17 +165,25 @@ class GlobalWorkflowManager: try: event = PretorEvent.model_validate_json(record.event_data_json) workflow_title = event.workflow.title if event.workflow else None - workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None - result.append({ - "event_id": event.trace_id, - "workflow_title": workflow_title, - "status": workflow_status, - "user_name": event.user_name, - "message": event.message, - "create_time": event.create_time, - }) + workflow_status = ( + event.workflow.status.status + if event.workflow and event.workflow.status + else None + ) + result.append( + { + "event_id": event.trace_id, + "workflow_title": workflow_title, + "status": workflow_status, + "user_name": event.user_name, + "message": event.message, + "create_time": event.create_time, + } + ) # Best-effort cache population - self.event_object_refs[event.trace_id] = ray.put(record.event_data_json) + self.event_object_refs[event.trace_id] = ray.put( + record.event_data_json + ) except Exception: continue except Exception as e: @@ -173,7 +198,7 @@ class GlobalWorkflowManager: async def get_pending(self, trace_id) -> str: if trace_id in self.event_dict and self.event_dict[trace_id].pending_queue: return await self.event_dict[trace_id].pending_queue.get() - await asyncio.sleep(1) # Prevent CPU spinning if not found + await asyncio.sleep(1) # Prevent CPU spinning if not found return "" async def put_received(self, trace_id, item) -> None: @@ -183,5 +208,5 @@ class GlobalWorkflowManager: async def get_received(self, trace_id) -> str: if trace_id in self.event_dict and self.event_dict[trace_id].receive_queue: return await self.event_dict[trace_id].receive_queue.get() - await asyncio.sleep(1) # Prevent CPU spinning if not found + await asyncio.sleep(1) # Prevent CPU spinning if not found return "" diff --git a/pretor/core/individual/consciousness_node/__init__.py b/pretor/core/individual/consciousness_node/__init__.py index 500b303..bab0541 100644 --- a/pretor/core/individual/consciousness_node/__init__.py +++ b/pretor/core/individual/consciousness_node/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .consciousness_node import ConsciousnessNode + __all__ = ["ConsciousnessNode"] diff --git a/pretor/core/individual/consciousness_node/consciousness_node.py b/pretor/core/individual/consciousness_node/consciousness_node.py index 6ee07e0..476b3bb 100644 --- a/pretor/core/individual/consciousness_node/consciousness_node.py +++ b/pretor/core/individual/consciousness_node/consciousness_node.py @@ -15,8 +15,15 @@ import ray from typing import Union, overload -from pretor.core.individual.consciousness_node.template import (ConsciousnessNodeDeps, ForSupervisoryNode, ForWorkflow,\ - ForWorkflowEngine, ForWorkflowInput, ForSupervisoryInput, ForWorkflowEngineInput) +from pretor.core.individual.consciousness_node.template import ( + ConsciousnessNodeDeps, + ForSupervisoryNode, + ForWorkflow, + ForWorkflowEngine, + ForWorkflowInput, + ForSupervisoryInput, + ForWorkflowEngineInput, +) from pydantic_ai import Agent, RunContext from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.model_provider.base_provider import Provider @@ -26,14 +33,21 @@ from pretor.adapter.model_adapter.agent_factory import AgentFactory @ray.remote class ConsciousnessNode: """ConsciousnessNode 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + def __init__(self) -> None: from pretor.utils.logger import get_logger - self.logger = get_logger('consciousness_node') + + self.logger = get_logger("consciousness_node") self.agent: None | Agent = None - - async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: + async def create_agent( + self, + global_state_machine: GlobalStateMachine, + provider_title: str, + model_id: str, + tools_list: list[str] = None, + ) -> None: """ create_agent方法,将agent对象装配到ConsciousnessNode的属性内 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的 @@ -58,32 +72,35 @@ class ConsciousnessNode: ) output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine] from pretor.utils.get_tool import load_tools_from_list - provider: Provider = await global_state_machine.get_provider.remote( provider_title) + + provider: Provider = await global_state_machine.get_provider.remote( + provider_title + ) agent_factory = AgentFactory() callables = load_tools_from_list(tools_list) - self.agent = agent_factory.create_agent(provider=provider, - model_id=model_id, - output_type=output_type, - system_prompt=system_prompt, - deps_type=ConsciousnessNodeDeps, - agent_name="consciousness_node", - tools=callables) + self.agent = agent_factory.create_agent( + provider=provider, + model_id=model_id, + output_type=output_type, + system_prompt=system_prompt, + deps_type=ConsciousnessNodeDeps, + agent_name="consciousness_node", + tools=callables, + ) @self.agent.system_prompt async def dynamic_prompt(ctx: RunContext[ConsciousnessNodeDeps]): """执行与 dynamic prompt 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: ctx (RunContext[ConsciousnessNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" prompt = system_prompt + "\n\n" prompt += ( f"=== 当前任务上下文 ===\n" f"- 当前指令 (Command): {ctx.deps.command}\n" f"- 原始用户命令 (Original Command): {ctx.deps.original_command}\n" ) - if ctx.deps.workflow_template: - prompt += f"- 选定工作流模板 (Workflow Template): {ctx.deps.workflow_template}\n" if ctx.deps.available_skills: prompt += "\n=== 当前可用 Skill Individual ===\n" prompt += "你可以直接将以下 Skill Individual 安排进工作流的步骤中(设置 node 为 skill_individual,并将 agent_id 设置为对应 Skill Individual 的真实 agent_id,不要用名称!),作为可调用的工具。\n" @@ -92,30 +109,34 @@ class ConsciousnessNode: return prompt - async def working(self, payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]: + async def working( + self, + payload: Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput], + ) -> Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]: """执行与 working 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: payload (Union[ForWorkflowEngineInput, ForWorkflowInput, ForSupervisoryInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Union[ForWorkflowEngine, ForWorkflow, ForSupervisoryNode, None]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: result = await self._run(payload) if isinstance(result, (ForWorkflowEngine, ForWorkflow, ForSupervisoryNode)): return result else: - self.logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}") + self.logger.error( + f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}" + ) return None except Exception: self.logger.exception("ConsciousnessNode在执行working时发生严重错误") return None - @overload async def _run(self, payload: ForWorkflowEngineInput) -> ForWorkflowEngine: """ _run方法 该分支应当在supervisory_node简单处理用户命令后,工作流创建前调用! Args: - payload: 应当包含workflow_template和event对象 + payload: 应当包含原始命令和可用技能等信息 Returns: ForWorkflowEngine对象,将被放到全局状态机后丢入WorkflowEngine的异步队列 @@ -148,45 +169,53 @@ class ConsciousnessNode: """ pass - async def _run(self, payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]: + async def _run( + self, + payload: Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput], + ) -> Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]: """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: payload (Union[ForSupervisoryInput, ForWorkflowInput, ForWorkflowEngineInput]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: self.agent.retries = 3 if isinstance(payload, ForWorkflowEngineInput): deps = ConsciousnessNodeDeps( original_command=payload.original_command, - workflow_template=payload.workflow_template, - command="拆解原始命令变成一个工作流", - available_skills=payload.available_skills + command="自主分析并拆解原始命令,生成严密可执行的工作流", + available_skills=payload.available_skills, ) self.logger.debug("ConsciousnessNode: 开始生成工作流 (原生重试开启)") prompt = "根据original_command制定严密的可执行workflow" - if payload.workflow_template: - prompt += ",可以学习并参考workflow_template的设计理念" result = await self.agent.run(prompt, deps=deps) return result.output elif isinstance(payload, ForWorkflowInput): deps = ConsciousnessNodeDeps( original_command=payload.original_command, - command="完成workflow step中分配给意识节点的特定任务或指导" + command="完成workflow step中分配给意识节点的特定任务或指导", + ) + self.logger.debug( + "ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)" + ) + result = await self.agent.run( + f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}", + deps=deps, ) - self.logger.debug("ConsciousnessNode: 开始处理工作流节点任务 (原生重试开启)") - result = await self.agent.run(f"处理此工作流步骤信息:\n{payload.workflow_step.model_dump_json()}", - deps=deps) return result.output elif isinstance(payload, ForSupervisoryInput): deps = ConsciousnessNodeDeps( original_command=payload.original_command, - command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告" + command="对于工作流整体执行结果进行检查,并且生成一份专业的技术性总结报告", + ) + self.logger.debug( + "ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)" + ) + result = await self.agent.run( + f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}", + deps=deps, ) - self.logger.debug("ConsciousnessNode: 开始生成技术总结报告 (原生重试开启)") - result = await self.agent.run(f"基于以下工作流的执行记录,生成技术报告:\n{payload.workflow.model_dump_json()}", - deps=deps) return result.output except Exception as e: self.logger.exception(f"ConsciousnessNode 模型生成最终失败: {str(e)}") diff --git a/pretor/core/individual/consciousness_node/template.py b/pretor/core/individual/consciousness_node/template.py index 6141764..3b359ee 100644 --- a/pretor/core/individual/consciousness_node/template.py +++ b/pretor/core/individual/consciousness_node/template.py @@ -18,60 +18,71 @@ from pretor.utils.agent_model import ResponseModel, DepsModel, InputModel from pydantic import Field -#意识节点回复类 +# 意识节点回复类 class ConsciousnessNodeResponse(ResponseModel): """Consciousness response model,是意识节点所有回复类型的父类""" + pass class ForWorkflowEngine(ConsciousnessNodeResponse): """生成workflow并放入WorkflowEngine""" - workflow: PretorWorkflow = Field(..., description="生成好的符合规范的完整工作流对象。") + + workflow: PretorWorkflow = Field( + ..., description="生成好的符合规范的完整工作流对象。" + ) reasoning: str = Field(..., description="生成此工作流的原因和思路简述。") class ForWorkflow(ConsciousnessNodeResponse): """处理workflow中需要ConsciousnessNode的工作""" + output: str = Field(..., description="对当前工作流步骤的具体处理结果或指导意见。") class ForSupervisoryNode(ConsciousnessNodeResponse): """工作流完成后进行校验并返回给SupervisoryNode""" - output: str = Field(..., description="为监控节点提供的全工作流执行情况的技术性总结报告。") + + output: str = Field( + ..., description="为监控节点提供的全工作流执行情况的技术性总结报告。" + ) class ConsciousnessNodeDeps(DepsModel): """ConsciousnessNodeDeps 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + original_command: str - workflow_template: str | None = None command: str available_skills: list[dict] | None = None class ConsciousnessNodeInput(InputModel): """ConsciousnessNodeInput 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + pass class ForWorkflowEngineInput(ConsciousnessNodeInput): """ForWorkflowEngineInput 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForWorkflowEngineInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - workflow_template: str | None = None + 这是一个领域数据模型或功能封装类,承载了 ForWorkflowEngineInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + original_command: str available_skills: list[dict] | None = None class ForWorkflowInput(ConsciousnessNodeInput): """ForWorkflowInput 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + workflow_step: WorkStep original_command: str class ForSupervisoryInput(ConsciousnessNodeInput): """ForSupervisoryInput 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForSupervisoryInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ForSupervisoryInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + workflow: PretorWorkflow original_command: str diff --git a/pretor/core/individual/control_node/__init__.py b/pretor/core/individual/control_node/__init__.py index 152c5ff..9948e4e 100644 --- a/pretor/core/individual/control_node/__init__.py +++ b/pretor/core/individual/control_node/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .control_node import ControlNode + __all__ = ["ControlNode"] diff --git a/pretor/core/individual/control_node/control_node.py b/pretor/core/individual/control_node/control_node.py index 468befd..eaf9468 100644 --- a/pretor/core/individual/control_node/control_node.py +++ b/pretor/core/individual/control_node/control_node.py @@ -17,21 +17,31 @@ from pydantic_ai import Agent, RunContext from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.model_provider.base_provider import Provider from pretor.adapter.model_adapter.agent_factory import AgentFactory -from pretor.core.individual.control_node.template import ForWorkflow, ForWorkflowInput, ControlNodeDeps - +from pretor.core.individual.control_node.template import ( + ForWorkflow, + ForWorkflowInput, + ControlNodeDeps, +) @ray.remote class ControlNode: """ControlNode 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + def __init__(self): from pretor.utils.logger import get_logger - self.logger = get_logger('control_node') + + self.logger = get_logger("control_node") self.agent: Agent | None = None - - async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: + async def create_agent( + self, + global_state_machine: GlobalStateMachine, + provider_title: str, + model_id: str, + tools_list: list[str] = None, + ) -> None: """ create_agent方法,将agent对象装配到Control的属性内 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的 @@ -56,23 +66,29 @@ class ControlNode: ) output_type = ForWorkflow from pretor.utils.get_tool import load_tools_from_list - provider: Provider = await global_state_machine.get_provider.remote( provider_title) + + provider: Provider = await global_state_machine.get_provider.remote( + provider_title + ) agent_factory = AgentFactory() callables = load_tools_from_list(tools_list) - self.agent = agent_factory.create_agent(provider=provider, - model_id=model_id, - output_type=output_type, - system_prompt=system_prompt, - deps_type=ControlNodeDeps, - agent_name="control_node", - tools=callables) + self.agent = agent_factory.create_agent( + provider=provider, + model_id=model_id, + output_type=output_type, + system_prompt=system_prompt, + deps_type=ControlNodeDeps, + agent_name="control_node", + tools=callables, + ) + @self.agent.system_prompt async def dynamic_prompt(ctx: RunContext[ControlNodeDeps]): """执行与 dynamic prompt 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: ctx (RunContext[ControlNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" prompt = system_prompt + "\n\n" prompt += ( f"=== 当前任务步骤上下文 ===\n" @@ -86,7 +102,7 @@ class ControlNode: """执行与 working 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ + Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。""" try: result: ForWorkflow = await self._run(payload) return result @@ -98,19 +114,21 @@ class ControlNode: """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: payload (ForWorkflowInput): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (ForWorkflow): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (ForWorkflow): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: self.agent.retries = 3 - deps = ControlNodeDeps( - workflow_step=payload.workflow_step + deps = ControlNodeDeps(workflow_step=payload.workflow_step) + self.logger.debug( + f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)" ) - self.logger.debug(f"ControlNode: 开始执行工作流节点 [{payload.workflow_step.name}] (原生重试开启)") result = await self.agent.run( f"请根据提供的 workflow_step 上下文,执行此步骤并输出结果。\n详细指令或附加数据:{payload.workflow_step.model_dump_json()}", - deps=deps + deps=deps, ) return result.output except Exception as e: - self.logger.exception(f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}") + self.logger.exception( + f"ControlNode 在执行步骤 [{payload.workflow_step.name}] 时最终失败: {str(e)}" + ) raise RuntimeError(f"ControlNode 执行步骤失败: {str(e)}") from e diff --git a/pretor/core/individual/control_node/template.py b/pretor/core/individual/control_node/template.py index 0ccf3a4..55f88fe 100644 --- a/pretor/core/individual/control_node/template.py +++ b/pretor/core/individual/control_node/template.py @@ -17,31 +17,39 @@ from pydantic import Field from pretor.core.workflow.workflow import WorkStep from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel + class ControlNodeResponse(ResponseModel): """控制节点回复的基类""" + pass class ControlNodeInput(InputModel): """ControlNodeInput 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + pass class ControlNodeDeps(DepsModel): """ControlNodeDeps 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + workflow_step: WorkStep # In the future, this can be dynamically populated with tools specific to the current task execution class ForWorkflow(ControlNodeResponse): """ForWorkflow 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - output: str = Field(..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。") + 这是一个领域数据模型或功能封装类,承载了 ForWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + output: str = Field( + ..., description="控制节点执行特定工作流步骤的结果。包含执行细节和输出数据。" + ) class ForWorkflowInput(ControlNodeInput): """ForWorkflowInput 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ForWorkflowInput 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + workflow_step: WorkStep diff --git a/pretor/core/individual/supervisory_node/__init__.py b/pretor/core/individual/supervisory_node/__init__.py index 002ac27..49adab5 100644 --- a/pretor/core/individual/supervisory_node/__init__.py +++ b/pretor/core/individual/supervisory_node/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .supervisory_node import SupervisoryNode + __all__ = ["SupervisoryNode"] diff --git a/pretor/core/individual/supervisory_node/supervisory_node.py b/pretor/core/individual/supervisory_node/supervisory_node.py index 2f46c67..921ef61 100644 --- a/pretor/core/individual/supervisory_node/supervisory_node.py +++ b/pretor/core/individual/supervisory_node/supervisory_node.py @@ -19,7 +19,12 @@ from pretor.api.platform.event import PretorEvent from pretor.adapter.model_adapter.agent_factory import AgentFactory from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine from pretor.core.global_state_machine.model_provider import Provider -from pretor.core.individual.supervisory_node.template import ForConsciousnessNode, ForUser, SupervisoryNodeDeps, TerminationMessage +from pretor.core.individual.supervisory_node.template import ( + ForConsciousnessNode, + ForUser, + SupervisoryNodeDeps, + TerminationMessage, +) from pydantic_ai import RunContext, Agent from pretor.utils.ray_hook import ray_actor_hook @@ -27,14 +32,21 @@ from pretor.utils.ray_hook import ray_actor_hook @ray.remote class SupervisoryNode: """SupervisoryNode 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + def __init__(self) -> None: from pretor.utils.logger import get_logger - self.logger = get_logger('supervisory_node') + + self.logger = get_logger("supervisory_node") self.agent: None | Agent = None - - async def create_agent(self, global_state_machine: GlobalStateMachine, provider_title: str, model_id: str, tools_list: list[str] = None) -> None: + async def create_agent( + self, + global_state_machine: GlobalStateMachine, + provider_title: str, + model_id: str, + tools_list: list[str] = None, + ) -> None: """ create_agent方法,将agent对象装配到SupervisoryNode的属性内 该方法通过provider_title从global_state_machine中获取provider对象,然后从provider对象中取出供应商形象,装配为pydantic_ai的Agent实例, @@ -53,43 +65,47 @@ class SupervisoryNode: "你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n" "1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n" "2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务," - " 并使用 ForConsciousnessNode 格式。若提供的【可用模板列表】中有合适的模板请选用,若都不匹配则 workflow_template 设为 null。\n" + " 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n" "3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n" "请保持冷静、专业,并严格遵循上述路由规则。" ) output_type = Union[ForConsciousnessNode, ForUser] from pretor.utils.get_tool import load_tools_from_list - provider: Provider = await global_state_machine.get_provider.remote( provider_title) + + provider: Provider = await global_state_machine.get_provider.remote( + provider_title + ) agent_factory = AgentFactory() callables = load_tools_from_list(tools_list) - self.agent = agent_factory.create_agent(provider=provider, - model_id=model_id, - output_type=output_type, - system_prompt=system_prompt, - deps_type=SupervisoryNodeDeps, - agent_name="supervisory_node", - tools=callables) + self.agent = agent_factory.create_agent( + provider=provider, + model_id=model_id, + output_type=output_type, + system_prompt=system_prompt, + deps_type=SupervisoryNodeDeps, + agent_name="supervisory_node", + tools=callables, + ) @self.agent.system_prompt async def dynamic_prompt(ctx: RunContext[SupervisoryNodeDeps]): """执行与 dynamic prompt 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: ctx (RunContext[SupervisoryNodeDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" prompt = system_prompt + "\n\n" prompt += ( f"=== 当前上下文 ===\n" f"- 平台 (Platform): {ctx.deps.platform}\n" f"- 用户名 (User): {ctx.deps.user_name}\n" f"- 当前时间 (Time): {ctx.deps.time}\n" - f"- 可用工作流模板 (Available Templates): {ctx.deps.available_templates}\n" ) # 修改 system_prompt 变量 prompt += ( "\n\n注意:你必须调用且只能调用一个函数(工具)来输出结果。" "如果你想直接回复用户,请调用 ForUser;" - "如果你想移交给工作流,请调用 ForConsciousnessNode(若没有合适的模板,workflow_template 填 null)。" + "如果你想移交给工作流,请调用 ForConsciousnessNode。" "严禁返回纯文本,必须使用工具格式!" ) if ctx.deps.error_history: @@ -113,16 +129,21 @@ class SupervisoryNode: try: result = await self._run(payload) if isinstance(result, ForConsciousnessNode): - self.logger.info(f"SupervisoryNode: 任务已分配给工作流引擎处理,选用模板 [{result.workflow_template}]") + self.logger.info("SupervisoryNode: 任务已分配给工作流引擎处理") if isinstance(payload, PretorEvent): - payload.context["workflow_template"] = result.workflow_template try: - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager await global_workflow_manager.add_event.remote(payload) - workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine + workflow_running_engine = ray_actor_hook( + "workflow_running_engine" + ).workflow_running_engine await workflow_running_engine.put_event.remote(payload) except Exception as e: - self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}") + self.logger.error( + f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}" + ) return "抱歉,任务提交失败,系统内部错误。" return f"任务已创建,准备创建工作流。原因:{result.reasoning}" elif isinstance(result, ForUser): @@ -144,7 +165,7 @@ class SupervisoryNode: Returns: ForUser对象,监控节点对于用户进行的简单回答 - ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点,并且给选择好的工作流模板 + ForConsciousnessNode对象,监控节点将用户的请求判断为复杂任务,将PretorEvent传递给意识节点 """ ... @@ -160,7 +181,9 @@ class SupervisoryNode: """ ... - async def _run(self, payload: Union[PretorEvent, TerminationMessage]) -> Union[ForConsciousnessNode, ForUser]: + async def _run( + self, payload: Union[PretorEvent, TerminationMessage] + ) -> Union[ForConsciousnessNode, ForUser]: """ _run方法,将payload转化为对llm发送的消息并发送 Args: @@ -175,23 +198,15 @@ class SupervisoryNode: message = payload.message time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") try: - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - workflow_template_dict = await global_state_machine.get_all_workflow_templates.remote() - available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in - workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板" deps = SupervisoryNodeDeps( - platform=platform, - user_name=user_name, - time=time_str, - available_templates=available_templates_str + platform=platform, user_name=user_name, time=time_str ) self.logger.debug("SupervisoryNode 开始生成 (启用原生 Pydantic-AI 重试)") prompt_message = message if isinstance(payload, TerminationMessage): prompt_message = f"【工作流执行结束报告】\n请将以下技术报告转化为对用户的友好回复:\n{message}" self.agent.retries = 3 - result = await self.agent.run(prompt_message, - deps=deps) + result = await self.agent.run(prompt_message, deps=deps) return result.output except Exception as e: self.logger.exception(f"SupervisoryNode 模型生成或解析最终失败: {str(e)}") diff --git a/pretor/core/individual/supervisory_node/template.py b/pretor/core/individual/supervisory_node/template.py index e72c9da..8f7ab54 100644 --- a/pretor/core/individual/supervisory_node/template.py +++ b/pretor/core/individual/supervisory_node/template.py @@ -16,35 +16,46 @@ from pydantic import Field from pretor.utils.agent_model import ResponseModel, DepsModel from pydantic import BaseModel + class SupervisoryNodeResponse(ResponseModel): """SupervisoryNodeResponse 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + pass + class ForUser(SupervisoryNodeResponse): """ForUser 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ForUser 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - context: str = Field(..., description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。") + 这是一个领域数据模型或功能封装类,承载了 ForUser 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + context: str = Field( + ..., + description="对用户的回复,应当使用和蔼的语气进行回复。用于直接解答简单问题或返回最终报告。", + ) + class ForConsciousnessNode(SupervisoryNodeResponse): """ForConsciousnessNode 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ - workflow_template: str | None = Field(default=None, description="选择的工作流模板的名称,用于处理复杂任务。若无需模板则为 None。") - reasoning: str = Field(..., description="选择将任务移交意识节点并选用该模板的简短原因。") + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + + reasoning: str = Field(..., description="选择将任务移交意识节点的简短原因。") + class TerminationMessage(BaseModel): """TerminationMessage 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 TerminationMessage 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 TerminationMessage 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + platform: str user_name: str message: str + class SupervisoryNodeDeps(DepsModel): """SupervisoryNodeDeps 核心组件类。 - 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。 """ + 这是一个系统执行节点类,作为多智能体架构中的独立处理单元。它能够接收工作流上下文,根据内置的大模型策略进行意图理解和自主决策,从而驱动特定阶段的任务闭环。""" + platform: str user_name: str time: str retry_count: int = 0 error_history: str = "" - available_templates: str = "默认工作流 (default_workflow)" \ No newline at end of file diff --git a/pretor/core/postgres_database/__init__.py b/pretor/core/postgres_database/__init__.py new file mode 100644 index 0000000..f722554 --- /dev/null +++ b/pretor/core/postgres_database/__init__.py @@ -0,0 +1,3 @@ +from pretor.core.postgres_database.postgres import PostgresDatabase + +__all__ = ["PostgresDatabase"] diff --git a/pretor/core/database/postgres.py b/pretor/core/postgres_database/postgres.py similarity index 87% rename from pretor/core/database/postgres.py rename to pretor/core/postgres_database/postgres.py index d1b978a..baa962a 100644 --- a/pretor/core/database/postgres.py +++ b/pretor/core/postgres_database/postgres.py @@ -26,19 +26,25 @@ from pretor.core.database.module.user import AuthDatabase from pretor.core.database.module.provider import ProviderDatabase from pretor.core.database.module.system_node import SystemNodeDatabase + @ray.remote class PostgresDatabase: """PostgresDatabase 核心组件类。 - 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。 """ + 这是一个数据库操作层 (DAO/Repository) 封装类,专注于处理实体模型与关系型数据库表之间的映射。它将复杂的 SQL 查询、跨表 Join 和事务回滚逻辑进行了高级抽象,向上层服务暴露简洁的数据读写接口。""" + def __init__(self): - user = os.environ.get('POSTGRES_USER') - password = os.environ.get('POSTGRES_PASSWORD') - host = os.environ.get('POSTGRES_HOST') - port = os.environ.get('POSTGRES_PORT') - database = os.environ.get('POSTGRES_DB') - database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}" + user = os.environ.get("POSTGRES_USER") + password = os.environ.get("POSTGRES_PASSWORD") + host = os.environ.get("POSTGRES_HOST") + port = os.environ.get("POSTGRES_PORT") + database = os.environ.get("POSTGRES_DB") + database_url = ( + f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}" + ) self.async_engine = create_async_engine(database_url, echo=True) - self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False) + self.async_session_maker = sessionmaker( + self.async_engine, class_=AsyncSession, expire_on_commit=False + ) self._auth_database = AuthDatabase(self.async_session_maker) self._provider_database = ProviderDatabase(self.async_session_maker) @@ -51,7 +57,7 @@ class PostgresDatabase: async def init_db(self) -> None: """完成 db 模块的启动与依赖初始化。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: async with self.async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) @@ -67,7 +73,7 @@ class PostgresDatabase: """创建并持久化新的 user 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.add_user(user_name, hashed_password) @@ -75,15 +81,17 @@ class PostgresDatabase: """执行与 change password 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_name: 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 old_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 new_password: 参与 change password 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() - return await self._auth_database.change_password(user_name, old_password, new_password) + return await self._auth_database.change_password( + user_name, old_password, new_password + ) async def delete_user(self, user_name: str): """安全地移除或注销 user。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.delete_user(user_name) @@ -91,7 +99,7 @@ class PostgresDatabase: """安全地移除或注销 user by id。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.delete_user_by_id(user_id) @@ -99,14 +107,14 @@ class PostgresDatabase: """执行与 login user 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.login_user(user_name) async def get_all_users(self): """检索并获取特定的 all users 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.get_all_users() @@ -114,7 +122,7 @@ class PostgresDatabase: """检索并获取特定的 user authority 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.get_user_authority(user_id) @@ -122,7 +130,7 @@ class PostgresDatabase: """执行与 change user authority 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 new_authority: 参与 change user authority 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._auth_database.change_user_authority(user_id, new_authority) @@ -130,14 +138,14 @@ class PostgresDatabase: async def get_provider(self): """检索并获取特定的 provider 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._provider_database.get_provider() async def add_provider_db(self, **kwargs): """创建并持久化新的 provider db 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._provider_database.add_provider(**kwargs) @@ -145,7 +153,7 @@ class PostgresDatabase: """安全地移除或注销 provider db。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._provider_database.delete_provider(provider_id) @@ -153,23 +161,31 @@ class PostgresDatabase: """对现有的 provider db 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 Args: provider_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._provider_database.update_provider(provider_id, **kwargs) # System Node Database Methods - async def upsert_system_node_config(self, node_name: str, provider_title: str, model_id: str, tools: list[str] = None): + async def upsert_system_node_config( + self, + node_name: str, + provider_title: str, + model_id: str, + tools: list[str] = None, + ): """执行与 upsert system node config 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: node_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 model_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 model 实例。 tools (list[str]): 控制逻辑流向的具体字符串参数,指定了期望的 tools 内容。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() - return await self._system_node_database.upsert_system_node_config(node_name, provider_title, model_id, tools) + return await self._system_node_database.upsert_system_node_config( + node_name, provider_title, model_id, tools + ) async def get_all_system_node_configs(self): """检索并获取特定的 all system node configs 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._system_node_database.get_all_system_node_configs() @@ -177,7 +193,7 @@ class PostgresDatabase: async def add_worker_individual(self, **kwargs): """创建并持久化新的 worker individual 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._individual_database.add_worker_individual(**kwargs) @@ -185,7 +201,7 @@ class PostgresDatabase: """检索并获取特定的 worker individual 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._individual_database.get_worker_individual(agent_id) @@ -193,7 +209,7 @@ class PostgresDatabase: """检索并获取特定的 worker individual list 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: owner_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 owner 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._individual_database.get_worker_individual_list(owner_id) @@ -201,24 +217,27 @@ class PostgresDatabase: """对现有的 worker individual 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() - return await self._individual_database.update_worker_individual(agent_id, **kwargs) + return await self._individual_database.update_worker_individual( + agent_id, **kwargs + ) async def delete_worker_individual(self, agent_id: str): """安全地移除或注销 worker individual。 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._individual_database.delete_worker_individual(agent_id) async def get_all_worker_individual(self): """检索并获取特定的 all worker individual 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.ready_event.wait() return await self._individual_database.get_all_worker_individual() + # Event Database Methods async def upsert_event(self, trace_id: str, event_data_json: str): await self.ready_event.wait() diff --git a/pretor/core/workflow/workflow.py b/pretor/core/workflow/workflow.py index 2283430..273b390 100644 --- a/pretor/core/workflow/workflow.py +++ b/pretor/core/workflow/workflow.py @@ -16,69 +16,88 @@ from typing import List, Optional, Union, Literal, Dict, Any from pydantic import BaseModel, Field, model_validator from pretor.utils.logger import get_logger -logger = get_logger('workflow') + +logger = get_logger("workflow") NodeType = Literal[ "consciousness_node", "control_node", "supervisory_node", "skill_individual" ] + class EventInfo(BaseModel): """EventInfo 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 EventInfo 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 EventInfo 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + platform: str user_name: str + class LogicGate(BaseModel): """LogicGate 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 LogicGate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 LogicGate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + if_fail: str = Field(..., description="失败跳转目标,如 'jump_to_step_1'") - if_pass: Literal["continue", "exit"] = Field(default="continue", description="成功后的动作") + if_pass: Literal["continue", "exit"] = Field( + default="continue", description="成功后的动作" + ) + class WorkStep(BaseModel): """WorkStep 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 WorkStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + step: int = Field(..., gt=0, description="步骤序号,严格自增") name: str = Field(..., description="步骤名称") node: NodeType = Field(..., description="负责执行的节点类型") action: str = Field(..., description="执行的原子动作") desc: str = Field(..., description="动作细节的自然语言描述,包含人工规范指导") - inputs: Optional[Union[str, List[str]]] = Field(default=None, description="前置依赖输出") + inputs: Optional[Union[str, List[str]]] = Field( + default=None, description="前置依赖输出" + ) outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名") - agent_id: Optional[str] = Field(default=None, description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替") + agent_id: Optional[str] = Field( + default=None, + description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替", + ) logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制") status: Literal["waiting", "running", "completed", "failed"] = Field( - default="waiting", - description="执行状态 (LLM建议保留默认值)" + default="waiting", description="执行状态 (LLM建议保留默认值)" ) class WorkflowStatus(BaseModel): """WorkflowStatus 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowStatus 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 WorkflowStatus 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + step: int = Field(default=1, gt=0, description="当前运行到的工作流步数") - status: Literal["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"] = Field( - default="waiting_llm_working", - description="当前系统调度状态" - ) + status: Literal[ + "waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working" + ] = Field(default="waiting_llm_working", description="当前系统调度状态") + class PretorWorkflow(BaseModel): """PretorWorkflow 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 PretorWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 PretorWorkflow 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + title: str = Field(..., description="工作流的标题") work_link: List[WorkStep] = Field(..., description="工作链逻辑定义") # ---------------- 以下为系统级管控字段,LLM 无需关心 ---------------- # trace_id: str | None = Field(description="系统自动生成的追溯ID") version: str = Field(default="v1.0", description="系统协议版本号") command: Optional[str] = Field(default=None, description="触发此工作流的原始命令") - output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果") - status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象") + output: Dict[str, Any] = Field( + default_factory=dict, description="工作流最终产出结果" + ) + status: WorkflowStatus = Field( + default_factory=WorkflowStatus, description="运行时状态对象" + ) event_info: EventInfo | None = Field(default=None) context_memory: Dict[str, Any] = Field(default_factory=dict) - @model_validator(mode='after') - def validate_workflow_integrity(self) -> 'PretorWorkflow': + @model_validator(mode="after") + def validate_workflow_integrity(self) -> "PretorWorkflow": """执行与 validate workflow integrity 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: ('PretorWorkflow'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: ('PretorWorkflow'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" steps = [s.step for s in self.work_link] expected = list(range(1, len(steps) + 1)) if steps != expected: @@ -90,9 +109,11 @@ class PretorWorkflow(BaseModel): try: target = int(s.logic_gate.if_fail.split("_")[-1]) if target > max_step or target < 1: - raise ValueError(f"Step {s.step} 的跳转目标 Step {target} 越界了!") + raise ValueError( + f"Step {s.step} 的跳转目标 Step {target} 越界了!" + ) except ValueError as e: if "越界" in str(e): raise e raise ValueError(f"LogicGate 格式错误: {s.logic_gate.if_fail}") - return self \ No newline at end of file + return self diff --git a/pretor/core/workflow/workflow_template_generator/__init__.py b/pretor/core/workflow/workflow_template_generator/__init__.py deleted file mode 100644 index 5fa7362..0000000 --- a/pretor/core/workflow/workflow_template_generator/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - diff --git a/pretor/core/workflow/workflow_template_generator/workflow_template.py b/pretor/core/workflow/workflow_template_generator/workflow_template.py deleted file mode 100644 index f7f8e3c..0000000 --- a/pretor/core/workflow/workflow_template_generator/workflow_template.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pydantic import BaseModel, model_validator -from typing import Dict,List - -class WorkflowTemplateStep(BaseModel): - """WorkflowTemplateStep 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowTemplateStep 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - step: int - node: str - action: str - desc: str - input: List[str] - output: List[str] - logic_gate: Dict[str, str] - -class WorkflowTemplate(BaseModel): - """WorkflowTemplate 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowTemplate 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - name: str - desc: str - work_link: list[WorkflowTemplateStep] - - @model_validator(mode='after') - def validate_steps(self) -> 'WorkflowTemplate': - """执行与 validate steps 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: ('WorkflowTemplate'): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - steps = [s.step for s in self.work_link] - if len(steps) != len(set(steps)): - raise ValueError("Step numbers in work_link must be unique") - return self - - \ No newline at end of file diff --git a/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py b/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py deleted file mode 100644 index e249e5d..0000000 --- a/pretor/core/workflow/workflow_template_generator/workflow_template_generator.py +++ /dev/null @@ -1,32 +0,0 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from pathlib import Path -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate - -class WorkflowTemplateGenerator: - """WorkflowTemplateGenerator 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowTemplateGenerator 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - @staticmethod - def generate_workflow_template(workflow_template: WorkflowTemplate) -> WorkflowTemplate: - """执行与 generate workflow template 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Args: workflow_template (WorkflowTemplate): 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: (WorkflowTemplate): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - output_dir = Path("pretor") / "workflow_template" - if not output_dir.exists(): - output_dir.mkdir(parents=True) - output_file = output_dir / f"{workflow_template.name}_workflow_template.json" - with output_file.open("w", encoding="utf-8") as f: - f.write(workflow_template.model_dump_json(indent=4)) diff --git a/pretor/core/workflow/workflow_template_manager.py b/pretor/core/workflow/workflow_template_manager.py deleted file mode 100644 index 3ad0d8e..0000000 --- a/pretor/core/workflow/workflow_template_manager.py +++ /dev/null @@ -1,76 +0,0 @@ -# Copyright 2026 zhaoxi826 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import json -from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator -from pathlib import Path -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate - -from pretor.utils.logger import get_logger -logger = get_logger('workflow_template_manager') - -class WorkflowManager: - """WorkflowManager 核心组件类。 - 这是一个管理器类,职责集中在维护整个系统内有关 Workflow 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """ - def __init__(self): - self.workflow_template_generator = WorkflowTemplateGenerator() - self.workflow_templates_registry = {} - self.template_path = Path("pretor/workflow_template") - self._load_workflow_template() - - def _load_workflow_template(self) -> None: - """执行与 load workflow template 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - for workflow_template_file in self.template_path.glob("*_workflow_template.json"): - with workflow_template_file.open("r",encoding="utf-8") as f: - try: - workflow_template = json.load(f) - self.workflow_templates_registry[workflow_template.get("name")] = workflow_template.get("desc") - except json.decoder.JSONDecodeError: - logger.warning(f"{workflow_template_file}不是json文件或格式错误") - except KeyError: - logger.warning(f"{workflow_template_file}不符合workflow_template格式") - - def generate_workflow_template(self, workflow_template: WorkflowTemplate) -> None: - """执行与 generate workflow template 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Args: workflow_template (WorkflowTemplate): 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - try: - workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template) - self.workflow_templates_registry[workflow_template.name] = workflow_template.desc - except Exception: - logger.exception("Failed to generate workflow template") - - def add_workflow_template(self, template_name: str, workflow_template: WorkflowTemplate) -> None: - """创建并持久化新的 workflow template 实体。 - 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 - Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 workflow_template (WorkflowTemplate): 参与 add workflow template 逻辑运算或数据构建的上下文依赖对象。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - self.generate_workflow_template(workflow_template) - - def get_all_workflow_templates(self) -> dict: - """检索并获取特定的 all workflow templates 数据集合或实例对象。 - 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ - return self.workflow_templates_registry - - def delete_workflow_template(self, template_name: str) -> None: - """安全地移除或注销 workflow template。 - 执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。 - Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - if template_name in self.workflow_templates_registry: - del self.workflow_templates_registry[template_name] \ No newline at end of file diff --git a/pretor/core/workflow_running_engine/__init__.py b/pretor/core/workflow_running_engine/__init__.py new file mode 100644 index 0000000..537dee0 --- /dev/null +++ b/pretor/core/workflow_running_engine/__init__.py @@ -0,0 +1,3 @@ +from pretor.core.workflow_running_engine.workflow_runner import WorkflowRunningEngine + +__all__ = ["WorkflowRunningEngine"] diff --git a/pretor/core/workflow/workflow_runner.py b/pretor/core/workflow_running_engine/workflow_runner.py similarity index 70% rename from pretor/core/workflow/workflow_runner.py rename to pretor/core/workflow_running_engine/workflow_runner.py index 19edd61..0140f02 100644 --- a/pretor/core/workflow/workflow_runner.py +++ b/pretor/core/workflow_running_engine/workflow_runner.py @@ -19,44 +19,40 @@ from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, EventInfo from typing import Optional, Dict, Union, Any, List from pretor.utils.error import WorkflowError, WorkflowExit from pretor.api.platform.event import PretorEvent -from pretor.core.individual.control_node.template import ForWorkflowInput as ControlForWorkflowInput, \ - ForWorkflow as ControlForWorkflow +from pretor.core.individual.control_node.template import ( + ForWorkflowInput as ControlForWorkflowInput, + ForWorkflow as ControlForWorkflow, +) from pretor.core.individual.consciousness_node.template import ( ForWorkflowInput as ConsciousnessForWorkflowInput, ForSupervisoryInput, ForSupervisoryNode, ForWorkflow as ConsciousnessForWorkflow, ForWorkflowEngineInput, - ForWorkflowEngine + ForWorkflowEngine, ) from pretor.core.individual.supervisory_node.template import TerminationMessage -import pathlib - - -def get_workflow_template(workflow_name: str) -> str: - """检索并获取特定的 workflow template 数据集合或实例对象。 - 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Args: workflow_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ - workflow_template = pathlib.Path(__file__).parent.parent.parent / "workflow_template" / (workflow_name + "_workflow_template.json") - with open(workflow_template, "r", encoding="utf-8") as workflow_template_file: - workflow_template = workflow_template_file.read() - return workflow_template class WorkflowEngine: """WorkflowEngine 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowEngine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - def __init__(self, - workflow: PretorWorkflow, - consciousness_node=None, - control_node=None, - supervisory_node=None): + 这是一个领域数据模型或功能封装类,承载了 WorkflowEngine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + def __init__( + self, + workflow: PretorWorkflow, + consciousness_node=None, + control_node=None, + supervisory_node=None, + ): from pretor.utils.logger import get_logger - self.logger = get_logger('workflow_runner') + + self.logger = get_logger("workflow_runner") self.workflow: PretorWorkflow = workflow """工作流:当前WorkflowEngine待执行的workflow""" - self._steps_by_id: Dict[int, WorkStep] = {step.step: step for step in self.workflow.work_link} + self._steps_by_id: Dict[int, WorkStep] = { + step.step: step for step in self.workflow.work_link + } """步骤表:将当前workflow的步骤序号和步骤内容存放""" self.consciousness_node = consciousness_node """意识节点""" @@ -70,7 +66,7 @@ class WorkflowEngine: """执行与 push sse 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: msg (str): 控制逻辑流向的具体字符串参数,指定了期望的 msg 内容。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: await self._gwm.put_pending.remote(self.workflow.trace_id, msg) except Exception: @@ -95,37 +91,55 @@ class WorkflowEngine: async def run(self): """ - run方法 - 处理并执行workflow的方法 + run方法 + 处理并执行workflow的方法 """ - self.logger.info(f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]") + self.logger.info( + f"🚀 工作流引擎启动: {self.workflow.title} [Trace ID: {self.workflow.trace_id}]" + ) await self._push_sse(f"[工作流启动] {self.workflow.title}") max_step = len(self.workflow.work_link) while 1 <= self.workflow.status.step <= max_step: current_step_id = self.workflow.status.step current_step = self._steps_by_id.get(current_step_id) if not current_step: - self.logger.error(f"严重错误:找不到步骤 {current_step_id},工作流强制终止。") + self.logger.error( + f"严重错误:找不到步骤 {current_step_id},工作流强制终止。" + ) self.workflow.status.status = "failed" await self._push_sse(f"[工作流失败] 找不到步骤 {current_step_id}") break - self.logger.info(f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}") + self.logger.info( + f"▶️ 开始执行 Step {current_step_id}: [{current_step.node}] -> {current_step.action}" + ) current_step.status = "running" - await self._push_sse(f"[Step {current_step_id}] {current_step.name}: {current_step.desc}") + await self._push_sse( + f"[Step {current_step_id}] {current_step.name}: {current_step.desc}" + ) try: step_input_data = self._prepare_inputs(current_step.inputs) - step_result, is_success = await self._dispatch_to_node(current_step, step_input_data) + step_result, is_success = await self._dispatch_to_node( + current_step, step_input_data + ) if is_success: if current_step.outputs: self.workflow.context_memory[current_step.outputs] = step_result - self.logger.debug(f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'") + self.logger.debug( + f"Step {current_step_id} 产出已保存至变量: '{current_step.outputs}'" + ) current_step.status = "completed" - await self._push_sse(f"[Step {current_step_id} 完成] {current_step.name}") + await self._push_sse( + f"[Step {current_step_id} 完成] {current_step.name}" + ) else: - self.logger.warning(f"Step {current_step_id} 执行遇到业务失败/驳回。") + self.logger.warning( + f"Step {current_step_id} 执行遇到业务失败/驳回。" + ) current_step.status = "failed" - await self._push_sse(f"[Step {current_step_id} 失败] {current_step.name}") + await self._push_sse( + f"[Step {current_step_id} 失败] {current_step.name}" + ) self._handle_logic_gate(current_step, is_success) except WorkflowExit: self.logger.info("命中 if_pass='exit',工作流被主动要求结束。") @@ -137,7 +151,10 @@ class WorkflowEngine: await self._push_sse(f"[工作流失败] {e}") break except Exception as e: - self.logger.error(f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", exc_info=True) + self.logger.error( + f"❌ Step {current_step_id} 发生系统级未捕获异常: {e}", + exc_info=True, + ) current_step.status = "failed" self.workflow.status.status = "failed" await self._push_sse(f"[工作流异常] {e}") @@ -163,9 +180,11 @@ class WorkflowEngine: if self.consciousness_node: supervisory_input = ForSupervisoryInput( workflow=self.workflow, - original_command=self.workflow.command or "未知命令" + original_command=self.workflow.command or "未知命令", + ) + report_obj = await self.consciousness_node.working.remote( + supervisory_input ) - report_obj = await self.consciousness_node.working.remote(supervisory_input) if isinstance(report_obj, ForSupervisoryNode): report = report_obj.output elif isinstance(report_obj, str): @@ -178,7 +197,7 @@ class WorkflowEngine: term_msg = TerminationMessage( platform=self.workflow.event_info.platform, user_name=self.workflow.event_info.user_name, - message=f"工作流执行完毕。系统报告:{report}" + message=f"工作流执行完毕。系统报告:{report}", ) user_response = await self.supervisory_node.working.remote(term_msg) self.workflow.context_memory["_final_user_response"] = user_response @@ -188,7 +207,9 @@ class WorkflowEngine: except Exception: self.logger.exception("生成工作流执行汇报时发生错误") - async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]: + async def _dispatch_to_node( + self, step: WorkStep, input_data: Any + ) -> tuple[Any, bool]: """ 分流器 调用当前step的执行对象 @@ -216,8 +237,7 @@ class WorkflowEngine: raise WorkflowError("未提供 consciousness_node 句柄!") original_cmd = self.workflow.command or "" payload = ConsciousnessForWorkflowInput( - workflow_step=step, - original_command=original_cmd + workflow_step=step, original_command=original_cmd ) result_obj = await self.consciousness_node.working.remote(payload) if isinstance(result_obj, ConsciousnessForWorkflow): @@ -225,9 +245,12 @@ class WorkflowEngine: return result_obj, True elif step.node == "skill_individual": - self.logger.info(f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}。") + self.logger.info( + f"正在通过 WorkerCluster 调度 skill_individual 执行 {step.action}。" + ) try: from pretor.utils.ray_hook import ray_actor_hook + worker_cluster = ray_actor_hook("worker_cluster").worker_cluster task_id = f"{self.workflow.trace_id}_step_{step.step}" agent_id = step.agent_id or f"default_{step.node}" @@ -235,18 +258,24 @@ class WorkflowEngine: "action": step.action, "description": step.desc, "input_data": input_data, - "context_memory": self.workflow.context_memory + "context_memory": self.workflow.context_memory, } - result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event) + result_response = await worker_cluster.submit_task.remote( + task_id, agent_id, task_event + ) if result_response.get("success"): return result_response.get("data"), True else: - self.logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}") + self.logger.error( + f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}" + ) return result_response.get("error"), False except Exception as e: - self.logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}") + self.logger.exception( + f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}" + ) raise WorkflowError(f"WorkerCluster 调度异常: {e}") else: raise WorkflowError(f"未知的节点类型:{step.node}") @@ -275,7 +304,9 @@ class WorkflowEngine: match gate.if_fail.split("_"): case ["jump", "to", "step", target] if target.isdigit(): target_step = int(target) - self.logger.warning(f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}") + self.logger.warning( + f"触发逻辑门分支!从 Step {step.step} 跳转至 Step {target_step}" + ) self.workflow.status.step = target_step case _: raise WorkflowError(f"未知的 if_fail 格式: {gate.if_fail}") @@ -284,10 +315,14 @@ class WorkflowEngine: @ray.remote class WorkflowRunningEngine: """WorkflowRunningEngine 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowRunningEngine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - def __init__(self, consciousness_node=None, control_node=None, supervisory_node=None): + 这是一个领域数据模型或功能封装类,承载了 WorkflowRunningEngine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + def __init__( + self, consciousness_node=None, control_node=None, supervisory_node=None + ): from pretor.utils.logger import get_logger - self.logger = get_logger('workflow_runner') + + self.logger = get_logger("workflow_runner") self.runner_engine = {} self.workflow_queue: asyncio.Queue[PretorEvent] = None self.consciousness_node = consciousness_node @@ -298,28 +333,31 @@ class WorkflowRunningEngine: async def run(self): # Move actor hook to async start so we don't race during __init__ across cluster """执行与 run 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 """ - self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine + 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。""" + self.global_state_machine = ray_actor_hook( + "global_state_machine" + ).global_state_machine self.workflow_queue = asyncio.Queue() self.runner_engine = { - f"runner_{i}": asyncio.create_task(self.runner(i)) - for i in range(10) + f"runner_{i}": asyncio.create_task(self.runner(i)) for i in range(10) } async def put_event(self, event: PretorEvent) -> None: """执行与 put event 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: event (PretorEvent): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" await self.workflow_queue.put(event) async def resume_workflow(self, event: PretorEvent) -> None: """Resume an incomplete workflow that was loaded from the database.""" self.logger.info(f"Resuming workflow {event.trace_id}") - workflow_engine = WorkflowEngine(event.workflow, - self.consciousness_node, - self.control_node, - self.supervisory_node) + workflow_engine = WorkflowEngine( + event.workflow, + self.consciousness_node, + self.control_node, + self.supervisory_node, + ) # Assuming you want to schedule it via a task asyncio.create_task(workflow_engine.run()) @@ -333,33 +371,37 @@ class WorkflowRunningEngine: while True: try: event = await self.workflow_queue.get() - self.logger.info(f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。") + self.logger.info( + f"WorkflowRunningEngine: runner_{i} 接收到事件 {event.trace_id} 准备生成工作流。" + ) if not self.consciousness_node: raise WorkflowError("未配置 consciousness_node,无法生成工作流") - workflow_template_name = event.context.get("workflow_template", "") - workflow_template = get_workflow_template(workflow_template_name) if workflow_template_name else None - available_skills = None if self.global_state_machine: try: - all_individuals = await self.global_state_machine.list_individuals.remote() + all_individuals = ( + await self.global_state_machine.list_individuals.remote() + ) available_skills = [] for agent_id, config in all_individuals.items(): - if config.get("agent_type") == "skill_individual" or config.get("type") == "skill_individual": - available_skills.append({ - "agent_id": agent_id, - "name": config.get("agent_name", "Unknown"), - "description": config.get("description", "") - }) + if ( + config.get("agent_type") == "skill_individual" + or config.get("type") == "skill_individual" + ): + available_skills.append( + { + "agent_id": agent_id, + "name": config.get("agent_name", "Unknown"), + "description": config.get("description", ""), + } + ) except Exception as e: self.logger.warning(f"获取Skill Individual列表失败: {e}") payload = ForWorkflowEngineInput( - original_command=event.message, - workflow_template=workflow_template, - available_skills=available_skills + original_command=event.message, available_skills=available_skills ) result_obj = await self.consciousness_node.working.remote(payload) @@ -369,25 +411,39 @@ class WorkflowRunningEngine: workflow.trace_id = event.trace_id workflow.command = event.message - workflow.event_info = EventInfo(platform=event.platform, - user_name=event.user_name,) + workflow.event_info = EventInfo( + platform=event.platform, + user_name=event.user_name, + ) self.logger.info( - f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}") + f"WorkflowRunningEngine: runner_{i} 成功生成工作流 {workflow.trace_id}:{workflow.title}" + ) - global_workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager - await global_workflow_manager.update_workflow.remote(event.trace_id, workflow) + global_workflow_manager = ray_actor_hook( + "global_workflow_manager" + ).global_workflow_manager + await global_workflow_manager.update_workflow.remote( + event.trace_id, workflow + ) - workflow_engine = WorkflowEngine(workflow, - self.consciousness_node, - self.control_node, - self.supervisory_node) + workflow_engine = WorkflowEngine( + workflow, + self.consciousness_node, + self.control_node, + self.supervisory_node, + ) await workflow_engine.run() else: - self.logger.error(f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}") + self.logger.error( + f"WorkflowRunningEngine: runner_{i} 无法生成工作流,返回类型为 {type(result_obj)}" + ) except asyncio.CancelledError: self.logger.info(f"WorkflowRunningEngine: runner_{i} 被取消。") raise except Exception as e: - self.logger.error(f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", exc_info=True) \ No newline at end of file + self.logger.error( + f"WorkflowRunningEngine: runner_{i} 遇到未捕获的异常: {e}", + exc_info=True, + ) diff --git a/pretor/plugin/tool_plugin/__init__.py b/pretor/plugin/tool_plugin/__init__.py index a997743..319b42c 100644 --- a/pretor/plugin/tool_plugin/__init__.py +++ b/pretor/plugin/tool_plugin/__init__.py @@ -10,4 +10,4 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. diff --git a/pretor/plugin/tool_plugin/approval/__init__.py b/pretor/plugin/tool_plugin/approval/__init__.py index b53f657..5440cdc 100644 --- a/pretor/plugin/tool_plugin/approval/__init__.py +++ b/pretor/plugin/tool_plugin/approval/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .approval import ApprovalToolData, approval + __all__ = ["ApprovalToolData", "approval"] diff --git a/pretor/plugin/tool_plugin/approval/approval.py b/pretor/plugin/tool_plugin/approval/approval.py index 1eeba85..a74ec89 100644 --- a/pretor/plugin/tool_plugin/approval/approval.py +++ b/pretor/plugin/tool_plugin/approval/approval.py @@ -16,12 +16,22 @@ from pretor.plugin.tool_plugin.base_tool import BaseToolData from pretor.utils.ray_hook import ray_actor_hook from typing import List, Literal, Dict + class ApprovalToolData(BaseToolData): """ApprovalToolData 核心组件类。 - 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。 """ + 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。""" + is_system: bool = True - action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [ - "control_node", "consciousness_node"] + action_scope: List[ + Literal[ + "control_node", + "consciousness_node", + "supervisory_node", + "growth_node", + "", + "", + ] + ] = ["control_node", "consciousness_node"] config_args: Dict[str, str] = {} @@ -38,4 +48,4 @@ async def approval(message: str, trace_id: str) -> str: actor_list = ray_actor_hook("global_state_machine") await actor_list.global_state_machine.put_pending.remote(trace_id, message) reply = await actor_list.global_state_machine.get_received.remote(trace_id) - return reply \ No newline at end of file + return reply diff --git a/pretor/plugin/tool_plugin/base_tool.py b/pretor/plugin/tool_plugin/base_tool.py index 92cae0e..f7c7513 100644 --- a/pretor/plugin/tool_plugin/base_tool.py +++ b/pretor/plugin/tool_plugin/base_tool.py @@ -16,10 +16,21 @@ from pydantic import BaseModel from typing import List, Literal, Dict from pydantic import ConfigDict + class BaseToolData(BaseModel): """BaseToolData 核心组件类。 - 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。 """ + 这是一个可被智能体动态调用的外部工具组件类。它定义了清晰的输入参数 Schema 与执行契约,赋予智能体与外界真实系统(如文件、网页、API)进行交互的能力。""" + model_config = ConfigDict(extra="allow") is_system: bool - action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [] - config_args: Dict[str, str] = {} \ No newline at end of file + action_scope: List[ + Literal[ + "control_node", + "consciousness_node", + "supervisory_node", + "growth_node", + "", + "", + ] + ] = [] + config_args: Dict[str, str] = {} diff --git a/pretor/plugin/tool_plugin/file_reader/file_reader.py b/pretor/plugin/tool_plugin/file_reader/file_reader.py index 71aec3a..607ab81 100644 --- a/pretor/plugin/tool_plugin/file_reader/file_reader.py +++ b/pretor/plugin/tool_plugin/file_reader/file_reader.py @@ -16,13 +16,16 @@ from pydantic_ai import RunContext from pretor.plugin.tool_plugin.base_tool import BaseToolData import os + class FileReaderData(BaseToolData): """FileReaderData 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 FileReaderData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 FileReaderData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + is_system: bool = True name: str = "file_reader" description: str = "读取本地文件的内容" + def file_reader(ctx: RunContext, filepath: str) -> str: """读取本地文件内容的工具。 @@ -38,7 +41,7 @@ def file_reader(ctx: RunContext, filepath: str) -> str: return f"Error: {filepath} 不是一个文件。" try: - with open(filepath, 'r', encoding='utf-8') as f: + with open(filepath, "r", encoding="utf-8") as f: content = f.read() return content except Exception as e: diff --git a/pretor/utils/access.py b/pretor/utils/access.py index e9548e2..7f5799f 100644 --- a/pretor/utils/access.py +++ b/pretor/utils/access.py @@ -24,11 +24,13 @@ from pwdlib import PasswordHash class TokenData(BaseModel): """TokenData 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 TokenData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 TokenData 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + user_id: str username: Optional[str] = None exp: Optional[int] = None + SECRET_KEY = os.getenv("SECRET_KEY") ALGORITHM = "HS256" ACCESS_TOKEN_EXPIRE_MINUTES = 60 * 24 @@ -41,19 +43,16 @@ password_hasher = PasswordHash.recommended() class Accessor: """Accessor 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 Accessor 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 Accessor 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + @staticmethod def _decode_token(token: str) -> TokenData: """执行与 decode token 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: token (str): 由认证中心颁发的 JWT 或长期访问令牌,用于跨服务调用时的身份自证与权限校验。 - Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" try: - payload = jwt.decode( - token, - SECRET_KEY, - algorithms=[ALGORITHM] - ) + payload = jwt.decode(token, SECRET_KEY, algorithms=[ALGORITHM]) return TokenData(**payload) except jwt.ExpiredSignatureError: raise HTTPException( @@ -71,9 +70,11 @@ class Accessor: """创建并持久化新的 access token 实体。 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 Args: data (dict): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ + Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。""" to_encode = data.copy() - expire = datetime.now(timezone.utc) + timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES) + expire = datetime.now(timezone.utc) + timedelta( + minutes=ACCESS_TOKEN_EXPIRE_MINUTES + ) to_encode.update({"exp": int(expire.timestamp())}) return jwt.encode(to_encode, SECRET_KEY, algorithm=ALGORITHM) @@ -82,7 +83,7 @@ class Accessor: """执行与 verify password 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: plain_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 plain password 内容。 hashed_password (str): 控制逻辑流向的具体字符串参数,指定了期望的 hashed password 内容。 - Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。 """ + Returns: (bool): 一个布尔型结果标志,明确返回 True 表示该操作成功应用或条件达成,False 则表示失败或被拒绝。""" return password_hasher.verify(plain_password, hashed_password) @staticmethod @@ -90,7 +91,7 @@ class Accessor: """检索并获取特定的 current user 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: request (Request): FastAPI 框架注入的原生 HTTP 请求对象,包含了完整的 Header 标头、查询参数和正文流。 - Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (TokenData): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith("Bearer "): raise HTTPException( @@ -105,7 +106,7 @@ class Accessor: """执行与 login hashed password 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: user (User): 当前已通过鉴权流程的访问者实体对象,内部包含用户角色、权限层级及租户归属等核心元信息。 password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ + Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。""" if not user: raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, @@ -116,10 +117,7 @@ class Accessor: status_code=status.HTTP_401_UNAUTHORIZED, detail="用户名或密码错误", ) - token_payload = { - "user_id": str(user.user_id), - "username": user.user_name - } + token_payload = {"user_id": str(user.user_id), "username": user.user_name} return Accessor._create_access_token(data=token_payload) @staticmethod @@ -127,9 +125,9 @@ class Accessor: """执行与 hash password 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: password (str): 控制逻辑流向的具体字符串参数,指定了期望的 password 内容。 - Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。 """ + Returns: (str): 处理流程所输出的具体字符串产物,可能是新生成的 ID 序列、格式化好的文本片段或 LLM 推理的回答内容。""" if not password: raise ValueError("密码不能为空") if len(password) < 6: raise ValueError("密码长度不能小于 6 位") - return password_hasher.hash(password) \ No newline at end of file + return password_hasher.hash(password) diff --git a/pretor/utils/agent_model.py b/pretor/utils/agent_model.py index 048c97e..92ae2ba 100644 --- a/pretor/utils/agent_model.py +++ b/pretor/utils/agent_model.py @@ -15,17 +15,23 @@ from pydantic import BaseModel + class ResponseModel(BaseModel): """ResponseModel 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ResponseModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ResponseModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + pass + class DepsModel(BaseModel): """DepsModel 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 DepsModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 DepsModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + pass + class InputModel(BaseModel): """InputModel 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 InputModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - pass \ No newline at end of file + 这是一个领域数据模型或功能封装类,承载了 InputModel 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + + pass diff --git a/pretor/utils/banner.py b/pretor/utils/banner.py index 8908c08..a065982 100644 --- a/pretor/utils/banner.py +++ b/pretor/utils/banner.py @@ -15,11 +15,13 @@ from rich.console import Console from rich.text import Text import yaml + + def print_banner() -> None: """执行与 print banner 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ - with open("config/config.yml","r") as config: + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + with open("config/config.yml", "r") as config: config = yaml.load(config, Loader=yaml.FullLoader) version = config.get("version", "unknown") pretor_banner = """ diff --git a/pretor/utils/check_user/role_check.py b/pretor/utils/check_user/role_check.py index c04a217..630f279 100644 --- a/pretor/utils/check_user/role_check.py +++ b/pretor/utils/check_user/role_check.py @@ -17,47 +17,53 @@ from pretor.utils.access import Accessor, TokenData from pretor.core.database.table.user import UserAuthority from pretor.utils.ray_hook import ray_actor_hook + async def get_authority(user_id: str) -> UserAuthority: """检索并获取特定的 authority 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: user_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 user 实例。 - Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (UserAuthority): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" from pretor.utils.error import UserNotExistError + postgres_database = ray_actor_hook("postgres_database").postgres_database try: - user_authority = await postgres_database.get_user_authority.remote(user_id=user_id) + user_authority = await postgres_database.get_user_authority.remote( + user_id=user_id + ) return user_authority except UserNotExistError: - raise HTTPException( - status_code=401, - detail="用户不存在或已被删除,请重新登录" - ) + raise HTTPException(status_code=401, detail="用户不存在或已被删除,请重新登录") except Exception as e: # Check if it's a RayTaskError wrapping UserNotExistError if "UserNotExistError" in str(e): raise HTTPException( - status_code=401, - detail="用户不存在或已被删除,请重新登录" + status_code=401, detail="用户不存在或已被删除,请重新登录" ) raise + class RoleChecker: """RoleChecker 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 RoleChecker 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ - def __init__(self, **kwargs): - self.allowed_roles = kwargs.get("allowed_roles", ) + 这是一个领域数据模型或功能封装类,承载了 RoleChecker 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" - async def __call__(self, - token_data: Annotated[TokenData, Depends(Accessor.get_current_user)]): + def __init__(self, **kwargs): + self.allowed_roles = kwargs.get( + "allowed_roles", + ) + + async def __call__( + self, token_data: Annotated[TokenData, Depends(Accessor.get_current_user)] + ): """执行与 call 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: token_data (Annotated[TokenData, Depends(Accessor.get_current_user)]): 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" user_authority = await get_authority(token_data.user_id) if user_authority < self.allowed_roles: raise HTTPException( status_code=403, - detail={"message": f"User {token_data.user_id} does not have allowed roles"}, + detail={ + "message": f"User {token_data.user_id} does not have allowed roles" + }, ) return token_data - diff --git a/pretor/utils/error.py b/pretor/utils/error.py index ebe1654..4dcaab8 100644 --- a/pretor/utils/error.py +++ b/pretor/utils/error.py @@ -12,57 +12,77 @@ # See the License for the specific language governing permissions and # limitations under the License. + class RetryableError(Exception): """基类:所有可重试错误(如网络断开、抖动等临时性故障)""" + pass + class NonRetryableError(Exception): """基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)""" + pass + class DemandError(NonRetryableError): """DemandError 核心组件类。 - 这是一个自定义异常类,专门用于在 Demand 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 Demand 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class ModelNotExistError(Exception): """ModelNotExistError 核心组件类。 - 这是一个自定义异常类,专门用于在 ModelNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 ModelNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class UserError(Exception): """UserError 核心组件类。 - 这是一个自定义异常类,专门用于在 User 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 User 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class UserNotExistError(UserError): """UserNotExistError 核心组件类。 - 这是一个自定义异常类,专门用于在 UserNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 UserNotExist 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class UserPasswordError(UserError): """UserPasswordError 核心组件类。 - 这是一个自定义异常类,专门用于在 UserPassword 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 UserPassword 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class ProviderError(Exception): """ProviderError 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + pass + class ProviderNotExistError(ProviderError): """ProviderNotExistError 核心组件类。 - 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """ + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + pass + class WorkflowError(Exception): - """WorkflowError 核心组件类。 - 这是一个自定义异常类,专门用于在 Workflow 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。 """ + 这是一个自定义异常类,专门用于在 Workflow 相关业务流程中触发中断。它携带了精确的错误上下文与追溯代码,帮助最外层网关能够统一捕获并返回友好的前端错误提示。""" + pass + class WorkflowExit(WorkflowError): - """WorkflowExit 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 WorkflowExit 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 WorkflowExit 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + pass diff --git a/pretor/utils/get_tool.py b/pretor/utils/get_tool.py index 13fa59b..2f5d634 100644 --- a/pretor/utils/get_tool.py +++ b/pretor/utils/get_tool.py @@ -18,7 +18,8 @@ import sys from typing import Callable, Dict, List from pretor.utils.logger import get_logger -logger = get_logger('get_tool') + +logger = get_logger("get_tool") _tool_cache: Dict[str, Callable] = {} @@ -26,13 +27,15 @@ def _get_tool_func(tool_name: str) -> Callable | None: """检索并获取特定的 tool func 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (Callable | None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Callable | None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" func = _tool_cache.get(tool_name, None) if func: return func app_root = "/app" - tool_plugin_dir = os.path.join(app_root, "pretor", "plugin", "tool_plugin", tool_name) + tool_plugin_dir = os.path.join( + app_root, "pretor", "plugin", "tool_plugin", tool_name + ) if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir): logger.error(f"Tool directory not found: {tool_plugin_dir}") @@ -57,7 +60,9 @@ def _get_tool_func(tool_name: str) -> Callable | None: func = getattr(module, tool_name, None) if not callable(func): - logger.error(f"Tool function '{tool_name}' not found or not callable in {module_name}") + logger.error( + f"Tool function '{tool_name}' not found or not callable in {module_name}" + ) return None _tool_cache[tool_name] = func return func @@ -65,19 +70,21 @@ def _get_tool_func(tool_name: str) -> Callable | None: logger.error(f"Failed to load module {module_name}: {e}") return None + def del_tool_cache(tool_name: str) -> None: """执行与 del tool cache 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: tool_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if tool_name in _tool_cache: del _tool_cache[tool_name] + def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]: """执行与 load tools from list 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: tool_names (List[str] | None): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 - Returns: (List[Callable]): 经过筛选、排序或分页处理后的实体对象列表集合。 """ + Returns: (List[Callable]): 经过筛选、排序或分页处理后的实体对象列表集合。""" if not tool_names: return [] @@ -87,4 +94,4 @@ def load_tools_from_list(tool_names: List[str] | None) -> List[Callable]: if tool_func: tool_list.append(tool_func) - return tool_list \ No newline at end of file + return tool_list diff --git a/pretor/utils/logger.py b/pretor/utils/logger.py index 0c6ad98..b4626a8 100644 --- a/pretor/utils/logger.py +++ b/pretor/utils/logger.py @@ -16,10 +16,11 @@ from loguru import logger from rich.logging import RichHandler from loguru._logger import Logger + def setup_logger() -> Logger: """对现有的 setup logger 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 - Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" logger.remove() def format_record(record): @@ -27,7 +28,7 @@ def setup_logger() -> Logger: """执行与 format record 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: record: 参与 format record 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" actor = record["extra"].get("actor_name", "System") trace_id = record["extra"].get("trace_id", "") @@ -37,19 +38,27 @@ def setup_logger() -> Logger: logger.configure(extra={"actor_name": "System", "trace_id": ""}) logger.add( - RichHandler(rich_tracebacks=True, markup=True, show_time=False, show_level=False, show_path=False), + RichHandler( + rich_tracebacks=True, + markup=True, + show_time=False, + show_level=False, + show_path=False, + ), format=format_record, level="DEBUG", - enqueue=True, # 异步记录 + enqueue=True, # 异步记录 ) return logger + global_logger = setup_logger() + def get_logger(actor_name: str, trace_id: str = "") -> Logger: """检索并获取特定的 logger 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: actor_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 trace_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 trace 实例。 - Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: (Logger): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return global_logger.bind(actor_name=actor_name, trace_id=trace_id) diff --git a/pretor/utils/pickle.py b/pretor/utils/pickle.py index d96c83f..8c58393 100644 --- a/pretor/utils/pickle.py +++ b/pretor/utils/pickle.py @@ -17,6 +17,7 @@ from pydantic import BaseModel T = TypeVar("T", bound=Type[BaseModel]) + def pickle(cls: T) -> T: """ 类装饰器pickle @@ -27,14 +28,15 @@ def pickle(cls: T) -> T: Returns: 返回被重写了__reduce__魔术方法的cls类 """ + def __reduce__(self): # 1. 序列化:触发 Pydantic-core (Rust) 的极速序列化 """执行与 reduce 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" data = self.model_dump_json() # 2. 反序列化:告诉 Pickle 重建时调用 cls.model_validate_json return cls.model_validate_json, (data,) + cls.__reduce__ = __reduce__ return cls - diff --git a/pretor/utils/ray_hook.py b/pretor/utils/ray_hook.py index 606e396..97f573c 100644 --- a/pretor/utils/ray_hook.py +++ b/pretor/utils/ray_hook.py @@ -14,23 +14,25 @@ import ray from functools import lru_cache + class ActorList: """ActorList 核心组件类。 - 这是一个领域数据模型或功能封装类,承载了 ActorList 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """ + 这是一个领域数据模型或功能封装类,承载了 ActorList 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。""" + def __init__(self): - super().__setattr__('dict', {}) + super().__setattr__("dict", {}) def __setattr__(self, key, value): """对现有的 setattr 进行状态更新或属性覆盖。 基于增量变更原则,合并最新的配置或数据,并触发相关依赖组件的缓存刷新或事件通知。 - Args: key: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 value: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 """ + Args: key: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。 value: 参与 setattr 逻辑运算或数据构建的上下文依赖对象。""" self.dict[key] = value def __getattr__(self, key): """检索并获取特定的 getattr 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 Args: key: 参与 getattr 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if key in self.dict: return self.dict[key] raise AttributeError(f"ActorList 对象没有属性 '{key}'") @@ -38,28 +40,30 @@ class ActorList: def __delattr__(self, key): """执行与 delattr 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Args: key: 参与 delattr 逻辑运算或数据构建的上下文依赖对象。 """ + Args: key: 参与 delattr 逻辑运算或数据构建的上下文依赖对象。""" if key in self.dict: del self.dict[key] else: raise AttributeError(f"ActorList对象没有属性 '{key}'") + @lru_cache(maxsize=128) def _get_cached_actor_handle(actor_name: str): """缓存接口""" return ray.get_actor(actor_name, namespace="pretor") + def clear_actor_cache(): """清理接口""" _get_cached_actor_handle.cache_clear() + def ray_actor_hook(*actor_names: str): """执行与 ray actor hook 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" actor_list = ActorList() for actor_name in actor_names: handle = _get_cached_actor_handle(actor_name) setattr(actor_list, actor_name, handle) return actor_list - diff --git a/pretor/utils/retry.py b/pretor/utils/retry.py index ed28ed4..610f7e3 100644 --- a/pretor/utils/retry.py +++ b/pretor/utils/retry.py @@ -17,43 +17,51 @@ import asyncio from functools import wraps from pretor.utils.error import RetryableError + def retry_on_retryable_error(max_retries=3, base_delay=1): """执行与 retry on retryable error 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: max_retries: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。 base_delay: 参与 retry on retryable error 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + def decorator(func): """执行与 decorator 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: func: 参与 decorator 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if asyncio.iscoroutinefunction(func): + @wraps(func) async def async_wrapper(*args, **kwargs): """执行与 async wrapper 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" for attempt in range(max_retries): try: return await func(*args, **kwargs) except RetryableError: if attempt == max_retries - 1: raise - await asyncio.sleep(base_delay * (2 ** attempt)) + await asyncio.sleep(base_delay * (2**attempt)) + return async_wrapper else: + @wraps(func) def sync_wrapper(*args, **kwargs): """执行与 sync wrapper 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" import time + for attempt in range(max_retries): try: return func(*args, **kwargs) except RetryableError: if attempt == max_retries - 1: raise - time.sleep(base_delay * (2 ** attempt)) + time.sleep(base_delay * (2**attempt)) + return sync_wrapper + return decorator diff --git a/pretor/worker_cluster/__init__.py b/pretor/worker_cluster/__init__.py new file mode 100644 index 0000000..35f451f --- /dev/null +++ b/pretor/worker_cluster/__init__.py @@ -0,0 +1,3 @@ +from pretor.worker_cluster.worker_cluster import WorkerCluster + +__all__ = ["WorkerCluster"] diff --git a/pretor/worker_individual/worker_cluster.py b/pretor/worker_cluster/worker_cluster.py similarity index 84% rename from pretor/worker_individual/worker_cluster.py rename to pretor/worker_cluster/worker_cluster.py index fc4af71..707f2e6 100644 --- a/pretor/worker_individual/worker_cluster.py +++ b/pretor/worker_cluster/worker_cluster.py @@ -42,14 +42,16 @@ class WorkerCluster: self.results_futures = {} self.runners = [] self.num_runners = num_runners - self.logger = get_logger('worker_cluster') + self.logger = get_logger("worker_cluster") async def start(self): """执行与 start 相关的核心业务流转操作。 - 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 """ + 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。""" if self.task_queue is None: self.task_queue = Queue() - self.runners = [asyncio.create_task(self._runner(i)) for i in range(self.num_runners)] + self.runners = [ + asyncio.create_task(self._runner(i)) for i in range(self.num_runners) + ] self.logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。") async def _recruit_worker(self, agent_id: str) -> BaseIndividual: @@ -58,8 +60,10 @@ class WorkerCluster: self._active_workers.move_to_end(agent_id) return self._active_workers[agent_id] - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - agent_config = await global_state_machine.get_individual.remote( agent_id) + global_state_machine = ray_actor_hook( + "global_state_machine" + ).global_state_machine + agent_config = await global_state_machine.get_individual.remote(agent_id) if not agent_config: raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") @@ -82,7 +86,7 @@ class WorkerCluster: async def _runner(self, runner_id: int): """执行与 runner 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 - Args: runner_id (int): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 runner 实例。 """ + Args: runner_id (int): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 runner 实例。""" while True: try: if self.task_queue is None: @@ -93,7 +97,9 @@ class WorkerCluster: agent_id = task.get("agent_id") task_event = task.get("task_event") - self.logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}") + self.logger.debug( + f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}" + ) start_time = time.time() try: @@ -105,40 +111,36 @@ class WorkerCluster: "success": True, "agent_id": agent_id, "data": result, - "metrics": {"cost_time_sec": round(cost_time, 2)} + "metrics": {"cost_time_sec": round(cost_time, 2)}, } except Exception as e: - self.logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}") - response = { - "success": False, - "agent_id": agent_id, - "error": str(e) - } + self.logger.exception( + f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}" + ) + response = {"success": False, "agent_id": agent_id, "error": str(e)} if task_id in self.results_futures: future = self.results_futures[task_id] if not future.done(): future.set_result(response) except Exception as e: - self.logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}") + self.logger.error( + f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}" + ) await asyncio.sleep(1) async def submit_task(self, task_id: str, agent_id: str, task_event: dict): """执行与 submit task 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: task_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 task 实例。 agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" if not self.runners: await self.start() future = asyncio.Future() self.results_futures[task_id] = future - task = { - "task_id": task_id, - "agent_id": agent_id, - "task_event": task_event - } + task = {"task_id": task_id, "agent_id": agent_id, "task_event": task_event} await self.task_queue.put_async(task) self.logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。") @@ -151,10 +153,10 @@ class WorkerCluster: def get_cluster_metrics(self): """检索并获取特定的 cluster metrics 数据集合或实例对象。 根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" return { "active_worker_count": len(self._active_workers), "max_capacity": self.max_capacity, "cached_agent_ids": list(self._active_workers.keys()), - "queue_size": self.task_queue.size() + "queue_size": self.task_queue.size(), } diff --git a/pretor/worker_individual/base_individual.py b/pretor/worker_individual/base_individual.py index 6283c9e..b054d8d 100644 --- a/pretor/worker_individual/base_individual.py +++ b/pretor/worker_individual/base_individual.py @@ -20,23 +20,31 @@ from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel from pretor.utils.ray_hook import ray_actor_hook from pretor.utils.logger import get_logger -logger = get_logger('worker_individual') + +logger = get_logger("worker_individual") + class WorkerIndividualResponse(ResponseModel): """WorkerIndividualResponse 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + output: str = Field(..., description="Worker执行任务的输出结果") + class WorkerIndividualDeps(DepsModel): """WorkerIndividualDeps 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + task_event: dict + class WorkerIndividualInput(InputModel): """WorkerIndividualInput 核心组件类。 - 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。 """ + 这是一个具体的 Worker 智能体实体类,代表着具备特定人设、领域技能或长文本处理能力的数字员工。它可以被控制器动态拉起,并在安全沙箱内执行复杂的工作流指令与多步骤推理任务。""" + task_event: dict + class BaseIndividual: """ Worker Individual 的基类 @@ -51,14 +59,21 @@ class BaseIndividual: """完成 agent 模块的启动与依赖初始化。 在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 system_prompt (str): 控制逻辑流向的具体字符串参数,指定了期望的 system prompt 内容。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" from pretor.utils.get_tool import load_tools_from_list - global_state_machine = ray_actor_hook("global_state_machine").global_state_machine - provider_title = self.agent_config.get("provider_title", "openai") # default fallback + + global_state_machine = ray_actor_hook( + "global_state_machine" + ).global_state_machine + provider_title = self.agent_config.get( + "provider_title", "openai" + ) # default fallback model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback tools_list = self.agent_config.get("tools", None) - provider: Provider = await global_state_machine.get_provider.remote( provider_title) + provider: Provider = await global_state_machine.get_provider.remote( + provider_title + ) agent_factory = AgentFactory() callables = load_tools_from_list(tools_list) @@ -70,7 +85,7 @@ class BaseIndividual: system_prompt=system_prompt, deps_type=WorkerIndividualDeps, agent_name=agent_name, - tools=callables + tools=callables, ) @self.agent.system_prompt @@ -78,17 +93,14 @@ class BaseIndividual: """执行与 dynamic prompt 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: ctx (RunContext[WorkerIndividualDeps]): 参与 dynamic prompt 逻辑运算或数据构建的上下文依赖对象。 - Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """ + Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" prompt = system_prompt + "\n\n" - prompt += ( - f"=== 当前任务上下文 ===\n" - f"{ctx.deps.task_event}\n" - ) + prompt += f"=== 当前任务上下文 ===\n{ctx.deps.task_event}\n" return prompt async def run(self, task_event: dict) -> dict: """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ + Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。""" raise NotImplementedError("子类必须实现 run 方法") diff --git a/pretor/worker_individual/ordinary_individual.py b/pretor/worker_individual/ordinary_individual.py index 3055fba..b1e2c8e 100644 --- a/pretor/worker_individual/ordinary_individual.py +++ b/pretor/worker_individual/ordinary_individual.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.worker_individual.base_individual import ( + BaseIndividual, + WorkerIndividualDeps, +) from pretor.utils.logger import get_logger -logger = get_logger('ordinary_individual') +logger = get_logger("ordinary_individual") + class OrdinaryIndividual(BaseIndividual): """ @@ -29,18 +33,17 @@ class OrdinaryIndividual(BaseIndividual): """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ + Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。""" if self.agent is None: - system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。") + system_prompt = self.agent_config.get( + "prompt", "你是一个普通的AI助手,请尽力完成给定的任务。" + ) await self._init_agent("ordinary_individual", system_prompt) deps = WorkerIndividualDeps(task_event=task_event) self.agent.retries = 3 try: - result = await self.agent.run( - f"请执行以下任务:\n{task_event}", - deps=deps - ) + result = await self.agent.run(f"请执行以下任务:\n{task_event}", deps=deps) return {"output": result.data.output} except Exception as e: logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}") diff --git a/pretor/worker_individual/skill_individual.py b/pretor/worker_individual/skill_individual.py index b5751b9..ce982dc 100644 --- a/pretor/worker_individual/skill_individual.py +++ b/pretor/worker_individual/skill_individual.py @@ -12,14 +12,18 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.worker_individual.base_individual import ( + BaseIndividual, + WorkerIndividualDeps, +) from pretor.utils.logger import get_logger import os import json from pydantic_ai import Tool import importlib.util -logger = get_logger('skill_individual') +logger = get_logger("skill_individual") + class SkillIndividual(BaseIndividual): """ @@ -43,7 +47,9 @@ class SkillIndividual(BaseIndividual): elif isinstance(bound_skill, dict): skill_mapper = bound_skill - skill_base_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "plugin", "skill")) + skill_base_dir = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "plugin", "skill") + ) for skill_name, _ in skill_mapper.items(): skill_path = os.path.join(skill_base_dir, skill_name) @@ -52,7 +58,7 @@ class SkillIndividual(BaseIndividual): continue try: - with open(metadata_path, 'r', encoding='utf-8') as f: + with open(metadata_path, "r", encoding="utf-8") as f: metadata = json.load(f) except Exception as e: logger.error(f"Failed to load metadata for skill {skill_name}: {e}") @@ -72,18 +78,28 @@ class SkillIndividual(BaseIndividual): func_name = func_info.get("name") try: # Dynamically load the python module - spec = importlib.util.spec_from_file_location(func_name, script_path) + spec = importlib.util.spec_from_file_location( + func_name, script_path + ) module = importlib.util.module_from_spec(spec) spec.loader.exec_module(module) func = getattr(module, func_name) if callable(func): # Convert to PydanticAI Tool - tool = Tool(func, name=func_name, description=func_info.get("docstring", "")) + tool = Tool( + func, + name=func_name, + description=func_info.get("docstring", ""), + ) tools.append(tool) - logger.info(f"Loaded skill tool: {func_name} from {skill_name}") + logger.info( + f"Loaded skill tool: {func_name} from {skill_name}" + ) except Exception as e: - logger.error(f"Failed to load function {func_name} from {script_path}: {e}") + logger.error( + f"Failed to load function {func_name} from {script_path}: {e}" + ) return tools @@ -91,10 +107,12 @@ class SkillIndividual(BaseIndividual): """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ + Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。""" if self.agent is None: - system_prompt = self.agent_config.get("prompt", - "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。") + system_prompt = self.agent_config.get( + "prompt", + "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。", + ) await self._init_agent("skill_individual", system_prompt) deps = WorkerIndividualDeps(task_event=task_event) @@ -106,7 +124,7 @@ class SkillIndividual(BaseIndividual): result = await self.agent.run( f"请执行以下任务:\n{task_event}", deps=deps, - tools=tools if tools else None + tools=tools if tools else None, ) return {"output": result.data.output} except Exception as e: diff --git a/pretor/worker_individual/special_individual.py b/pretor/worker_individual/special_individual.py index b92bdc7..5017f97 100644 --- a/pretor/worker_individual/special_individual.py +++ b/pretor/worker_individual/special_individual.py @@ -12,10 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. -from pretor.worker_individual.base_individual import BaseIndividual, WorkerIndividualDeps +from pretor.worker_individual.base_individual import ( + BaseIndividual, + WorkerIndividualDeps, +) from pretor.utils.logger import get_logger -logger = get_logger('special_individual') +logger = get_logger("special_individual") + class SpecialIndividual(BaseIndividual): """ @@ -29,18 +33,17 @@ class SpecialIndividual(BaseIndividual): """执行与 run 相关的核心业务流转操作。 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 Args: task_event (dict): 由事件总线或工作流引擎分发过来的事件载荷,封装了触发此次调用的上下文快照与任务目标指令。 - Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """ + Returns: (dict): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。""" if self.agent is None: - system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。") + system_prompt = self.agent_config.get( + "prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。" + ) await self._init_agent("special_individual", system_prompt) deps = WorkerIndividualDeps(task_event=task_event) self.agent.retries = 3 try: - result = await self.agent.run( - f"请执行以下任务:\n{task_event}", - deps=deps - ) + result = await self.agent.run(f"请执行以下任务:\n{task_event}", deps=deps) return {"output": result.data.output} except Exception as e: logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}") diff --git a/tests/adapter/model_adapter/agent_factory_test.py b/tests/adapter/model_adapter/agent_factory_test.py index a39380e..a7bc196 100644 --- a/tests/adapter/model_adapter/agent_factory_test.py +++ b/tests/adapter/model_adapter/agent_factory_test.py @@ -12,8 +12,12 @@ def test_create_agent_success_real(): mock_provider.provider_url = "url" with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls: - with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls: - with patch("pretor.adapter.model_adapter.agent_factory.OpenAIProvider") as mock_provider_cls: + with patch( + "pretor.adapter.model_adapter.agent_factory.OpenAIChatModel" + ) as mock_model_cls: + with patch( + "pretor.adapter.model_adapter.agent_factory.OpenAIProvider" + ) as mock_provider_cls: factory = AgentFactory() agent = factory.create_agent( provider=mock_provider, @@ -21,17 +25,19 @@ def test_create_agent_success_real(): output_type=str, system_prompt="You are an AI", deps_type=dict, - agent_name="myagent" + agent_name="myagent", ) mock_provider_cls.assert_called_once_with(api_key="key", base_url="url") - mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value) + mock_model_cls.assert_called_once_with( + "gpt-4", provider=mock_provider_cls.return_value + ) mock_agent_cls.assert_called_once_with( model=mock_model_cls.return_value, name="myagent", system_prompt="You are an AI", output_type=str, - deps_type=dict, - tools=None + deps_type=dict, + tools=None, ) assert agent == mock_agent_cls.return_value diff --git a/tests/core/database/database_exception_test.py b/tests/core/database/database_exception_test.py index 8d4de0c..8dd1ce3 100644 --- a/tests/core/database/database_exception_test.py +++ b/tests/core/database/database_exception_test.py @@ -5,34 +5,42 @@ from pydantic import ValidationError from pretor.utils.error import UserNotExistError from pretor.core.database.database_exception import database_exception + @database_exception async def success_func(): return "success" + @database_exception async def validation_error_func(): raise ValidationError.from_exception_data(title="Mock", line_errors=[]) + @database_exception async def integrity_error_func(): raise IntegrityError("mock_statement", "mock_params", "mock_orig") + @database_exception async def operational_error_func(): raise OperationalError("mock_statement", "mock_params", "mock_orig") + @database_exception async def user_not_exist_error_func(): raise UserNotExistError("mock user") + @database_exception async def exception_func(): raise Exception("mock generic exception") + @pytest.mark.asyncio async def test_success_func(): assert await success_func() == "success" + @pytest.mark.asyncio @patch("pretor.core.database.database_exception.logger") async def test_validation_error(mock_logger): @@ -41,6 +49,7 @@ async def test_validation_error(mock_logger): mock_logger.error.assert_called_once() assert "对象校验失败" in mock_logger.error.call_args[0][0] + @pytest.mark.asyncio @patch("pretor.core.database.database_exception.logger") async def test_integrity_error(mock_logger): @@ -49,6 +58,7 @@ async def test_integrity_error(mock_logger): mock_logger.error.assert_called_once() assert "数据库完整性错误" in mock_logger.error.call_args[0][0] + @pytest.mark.asyncio @patch("pretor.core.database.database_exception.logger") async def test_operational_error(mock_logger): @@ -57,6 +67,7 @@ async def test_operational_error(mock_logger): mock_logger.error.assert_called_once() assert "数据库连接异常" in mock_logger.error.call_args[0][0] + @pytest.mark.asyncio @patch("pretor.core.database.database_exception.logger") async def test_user_not_exist_error(mock_logger): @@ -65,6 +76,7 @@ async def test_user_not_exist_error(mock_logger): mock_logger.error.assert_called_once() assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0] + @pytest.mark.asyncio @patch("pretor.core.database.database_exception.logger") async def test_generic_exception(mock_logger): diff --git a/tests/core/database/module/user_test.py b/tests/core/database/module/user_test.py index 0556270..1e9fca1 100644 --- a/tests/core/database/module/user_test.py +++ b/tests/core/database/module/user_test.py @@ -26,6 +26,7 @@ def mock_session_maker(): async def test_add_user(mock_session_maker, mock_dependencies): mock_user_cls, _ = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -51,6 +52,7 @@ async def test_add_user(mock_session_maker, mock_dependencies): async def test_change_password_success(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -79,6 +81,7 @@ async def test_change_password_success(mock_session_maker, mock_dependencies): async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -94,10 +97,12 @@ async def test_change_password_user_not_exist(mock_session_maker, mock_dependenc async def test_change_password_wrong_password(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) from pretor.utils.access import Accessor + mock_user = MagicMock() mock_user.hashed_password = Accessor.hash_password("actual_password") mock_exec_result = MagicMock() @@ -105,6 +110,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc session.execute = AsyncMock(return_value=mock_exec_result) from pretor.utils.error import UserPasswordError + with pytest.raises(UserPasswordError): await db.change_password("testuser", "old_password", "new_password") @@ -113,6 +119,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc async def test_delete_user_success(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -134,6 +141,7 @@ async def test_delete_user_success(mock_session_maker, mock_dependencies): async def test_delete_user_not_exist(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -149,6 +157,7 @@ async def test_delete_user_not_exist(mock_session_maker, mock_dependencies): async def test_login_user_success(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) @@ -169,6 +178,7 @@ async def test_login_user_success(mock_session_maker, mock_dependencies): async def test_login_user_not_exist(mock_session_maker, mock_dependencies): mock_user_cls, mock_select = mock_dependencies from pretor.core.database.module.user import AuthDatabase + maker, session = mock_session_maker db = AuthDatabase(maker) diff --git a/tests/core/database/table/table_provider_test.py b/tests/core/database/table/table_provider_test.py index a31147a..060e7e6 100644 --- a/tests/core/database/table/table_provider_test.py +++ b/tests/core/database/table/table_provider_test.py @@ -1,5 +1,6 @@ from pretor.core.database.table.provider import Provider + def test_provider_table(): # Provide required fields provider = Provider( @@ -8,7 +9,7 @@ def test_provider_table(): provider_apikey="key", provider_models=["model_1"], provider_type="type", - provider_owner=1 + provider_owner=1, ) - assert Provider.__tablename__ == 'provider' + assert Provider.__tablename__ == "provider" assert provider.provider_title == "title" diff --git a/tests/core/database/table/table_user_test.py b/tests/core/database/table/table_user_test.py index 8c6ab7f..ef700c0 100644 --- a/tests/core/database/table/table_user_test.py +++ b/tests/core/database/table/table_user_test.py @@ -1,6 +1,7 @@ from pretor.core.database.table.user import User + def test_user_table(): user = User(user_id="id", user_name="name", hashed_password="pw") - assert User.__tablename__ == 'user' + assert User.__tablename__ == "user" assert user.user_name == "name" diff --git a/tests/core/global_state_machine/global_state_machine_test.py b/tests/core/global_state_machine/global_state_machine_test.py index a06ebce..216d8bc 100644 --- a/tests/core/global_state_machine/global_state_machine_test.py +++ b/tests/core/global_state_machine/global_state_machine_test.py @@ -7,14 +7,16 @@ real_import = builtins.__import__ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == 'ray': + if name == "ray": mock_ray = MagicMock() def mock_remote(*args, **kwargs): if len(args) == 1 and callable(args[0]): return args[0] + def decorator(cls): return cls + return decorator mock_ray.remote = mock_remote @@ -25,10 +27,10 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): builtins.__import__ = mock_import for mod in list(sys.modules.keys()): - if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod: + if "pretor.core.global_state_machine.global_state_machine" in mod or "ray" in mod: del sys.modules[mod] -from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine +from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine # noqa: E402 builtins.__import__ = real_import @@ -82,13 +84,17 @@ async def test_add_provider_unsupported(gsm): @pytest.mark.asyncio async def test_add_provider_request_error(gsm): from httpx import RequestError + mock_provider_class = AsyncMock() - mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock()) + mock_provider_class.create_provider.side_effect = RequestError( + "Network Error", request=MagicMock() + ) gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} with patch("pretor.utils.logger.global_logger.bind") as mock_bind: from pretor.utils.error import RetryableError import pytest + mock_logger = MagicMock() mock_bind.return_value = mock_logger with pytest.raises(RetryableError): @@ -117,3 +123,4 @@ def test_get_provider_list_and_get_provider(gsm): assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider} assert gsm._global_provider_manager.get_provider("p1") == mock_provider assert gsm._global_provider_manager.get_provider("missing") is None +# noqa: E402 diff --git a/tests/core/global_state_machine/model_provider/base_provider_test.py b/tests/core/global_state_machine/model_provider/base_provider_test.py index 46b7bcd..4641a55 100644 --- a/tests/core/global_state_machine/model_provider/base_provider_test.py +++ b/tests/core/global_state_machine/model_provider/base_provider_test.py @@ -1,25 +1,32 @@ -from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus +from pretor.core.global_state_machine.model_provider.base_provider import ( + Provider, + ProviderArgs, + ProviderStatus, +) + def test_provider_status(): assert ProviderStatus.UP == "up" assert ProviderStatus.DOWN == "down" + def test_provider_args(): args = ProviderArgs( provider_title="title", provider_url="url", provider_apikey="key", - provider_owner="1" + provider_owner="1", ) assert args.provider_title == "title" + def test_provider_model(): p = Provider( provider_title="title", provider_url="url", provider_apikey="key", provider_models=["model"], - provider_type="openai" + provider_type="openai", ) assert p.provider_status == ProviderStatus.UP assert p.provider_owner is None diff --git a/tests/core/global_state_machine/model_provider/claude_provider_test.py b/tests/core/global_state_machine/model_provider/claude_provider_test.py index 6f3a348..fe9e0a4 100644 --- a/tests/core/global_state_machine/model_provider/claude_provider_test.py +++ b/tests/core/global_state_machine/model_provider/claude_provider_test.py @@ -1,6 +1,9 @@ import pytest from unittest.mock import patch, MagicMock, AsyncMock -from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs +from pretor.core.global_state_machine.model_provider.claude_provider import ( + ClaudeProvider, + ProviderArgs, +) @pytest.fixture @@ -9,12 +12,14 @@ def provider_args(): provider_title="TestClaude", provider_url="https://api.anthropic.com", provider_apikey="testkey", - provider_owner="1" + provider_owner="1", ) @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient" +) async def test_load_models_success(mock_client, provider_args): mock_response = MagicMock() mock_response.status_code = 200 @@ -31,7 +36,9 @@ async def test_load_models_success(mock_client, provider_args): @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient" +) async def test_load_models_error(mock_client, provider_args): mock_client_instance = AsyncMock() mock_client_instance.get.side_effect = Exception("network error") @@ -42,8 +49,10 @@ async def test_load_models_error(mock_client, provider_args): @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models", - return_value=["claude-3"]) +@patch( + "pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models", + return_value=["claude-3"], +) async def test_create_provider(mock_load, provider_args): provider = await ClaudeProvider.create_provider(provider_args) assert provider.provider_title == "TestClaude" diff --git a/tests/core/global_state_machine/model_provider/openai_provider_test.py b/tests/core/global_state_machine/model_provider/openai_provider_test.py index 99246e3..0b5d0ff 100644 --- a/tests/core/global_state_machine/model_provider/openai_provider_test.py +++ b/tests/core/global_state_machine/model_provider/openai_provider_test.py @@ -1,6 +1,9 @@ import pytest from unittest.mock import patch, MagicMock, AsyncMock -from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs +from pretor.core.global_state_machine.model_provider.openai_provider import ( + OpenAIProvider, + ProviderArgs, +) @pytest.fixture @@ -9,7 +12,7 @@ def provider_args(): provider_title="TestOpenAI", provider_url="https://api.openai.com/v1", provider_apikey="testkey", - provider_owner="1" + provider_owner="1", ) @@ -19,12 +22,14 @@ def provider_args_no_v1(): provider_title="TestOpenAI", provider_url="https://api.openai.com", provider_apikey="testkey", - provider_owner="1" + provider_owner="1", ) @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient" +) async def test_load_models_success(mock_client, provider_args): mock_response = MagicMock() mock_response.status_code = 200 @@ -40,12 +45,14 @@ async def test_load_models_success(mock_client, provider_args): assert models == ["gpt-3.5-turbo", "gpt-4"] mock_client_instance.get.assert_called_once_with( "https://api.openai.com/v1/models", - headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"} + headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}, ) @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient" +) async def test_load_models_no_v1(mock_client, provider_args_no_v1): mock_response = MagicMock() mock_response.status_code = 200 @@ -59,12 +66,14 @@ async def test_load_models_no_v1(mock_client, provider_args_no_v1): assert models == [] mock_client_instance.get.assert_called_once_with( "https://api.openai.com/v1/models", - headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"} + headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}, ) @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient" +) async def test_load_models_status_error(mock_client, provider_args): mock_response = MagicMock() mock_response.status_code = 401 @@ -78,21 +87,29 @@ async def test_load_models_status_error(mock_client, provider_args): @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient" +) async def test_load_models_request_error(mock_client, provider_args): import httpx + mock_client_instance = AsyncMock() - mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock()) + mock_client_instance.get.side_effect = httpx.RequestError( + "network error", request=MagicMock() + ) mock_client.return_value.__aenter__.return_value = mock_client_instance import pytest from pretor.utils.error import RetryableError + with pytest.raises(RetryableError): await OpenAIProvider._load_models(provider_args) @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient") +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient" +) async def test_load_models_generic_error(mock_client, provider_args): mock_client_instance = AsyncMock() mock_client_instance.get.side_effect = Exception("generic error") @@ -103,8 +120,10 @@ async def test_load_models_generic_error(mock_client, provider_args): @pytest.mark.asyncio -@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models", - return_value=["gpt-4"]) +@patch( + "pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models", + return_value=["gpt-4"], +) async def test_create_provider(mock_load, provider_args): provider = await OpenAIProvider.create_provider(provider_args) assert provider.provider_title == "TestOpenAI" diff --git a/tests/core/global_state_machine/provider_manager_test.py b/tests/core/global_state_machine/provider_manager_test.py index 1111dd1..d318c12 100644 --- a/tests/core/global_state_machine/provider_manager_test.py +++ b/tests/core/global_state_machine/provider_manager_test.py @@ -13,11 +13,15 @@ async def test_provider_manager_init(): mock_provider2.provider_title = "title2" mock_postgres.get_provider = MagicMock() - mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) + mock_postgres.get_provider.remote = AsyncMock( + return_value=[mock_provider1, mock_provider2] + ) manager = ProviderManager(mock_postgres) mock_postgres.provider_database = MagicMock() - mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) + mock_postgres.provider_database.remote = AsyncMock( + return_value=[mock_provider1, mock_provider2] + ) await manager.init_provider_register(mock_postgres) assert "openai" in manager.provider_mapper diff --git a/tests/core/global_state_machine/tool_manager_test.py b/tests/core/global_state_machine/tool_manager_test.py index 53427fa..8f8c912 100644 --- a/tests/core/global_state_machine/tool_manager_test.py +++ b/tests/core/global_state_machine/tool_manager_test.py @@ -1,5 +1,6 @@ from pretor.core.global_state_machine.tool_manager import GlobalToolManager + def test_global_tool_manager_init(): manager = GlobalToolManager() assert isinstance(manager, GlobalToolManager) diff --git a/tests/core/database/postgres_test.py b/tests/core/postgres_database/postgres_test.py similarity index 69% rename from tests/core/database/postgres_test.py rename to tests/core/postgres_database/postgres_test.py index 4463fc9..be674dc 100644 --- a/tests/core/database/postgres_test.py +++ b/tests/core/postgres_database/postgres_test.py @@ -7,14 +7,16 @@ real_import = builtins.__import__ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == 'ray': + if name == "ray": mock_ray = MagicMock() def mock_remote(*args, **kwargs): if len(args) == 1 and callable(args[0]): return args[0] + def decorator(cls): return cls + return decorator mock_ray.remote = mock_remote @@ -24,28 +26,30 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): builtins.__import__ = mock_import for mod in list(sys.modules.keys()): - if 'pretor.core.database.postgres' in mod or 'ray' in mod: + if "pretor.core.postgres_database.postgres" in mod or "ray" in mod: del sys.modules[mod] -from pretor.core.database.postgres import PostgresDatabase +from pretor.core.postgres_database.postgres import PostgresDatabase # noqa: E402 builtins.__import__ = real_import -@patch("pretor.core.database.postgres.create_async_engine") -@patch("pretor.core.database.postgres.sessionmaker") -@patch("pretor.core.database.postgres.AuthDatabase") -@patch("pretor.core.database.postgres.ProviderDatabase") -@patch("pretor.core.database.postgres.os.environ.get") +@patch("pretor.core.postgres_database.postgres.create_async_engine") +@patch("pretor.core.postgres_database.postgres.sessionmaker") +@patch("pretor.core.postgres_database.postgres.AuthDatabase") +@patch("pretor.core.postgres_database.postgres.ProviderDatabase") +@patch("pretor.core.postgres_database.postgres.os.environ.get") @pytest.mark.asyncio -async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine): +async def test_postgres_database( + mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine +): def env_side_effect(key): return { "POSTGRES_USER": "testuser", "POSTGRES_PASSWORD": "testpassword", "POSTGRES_HOST": "localhost", "POSTGRES_PORT": "5432", - "POSTGRES_DB": "testdb" + "POSTGRES_DB": "testdb", }.get(key) mock_env_get.side_effect = env_side_effect @@ -53,6 +57,7 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m mock_engine = MagicMock() mock_conn = MagicMock() from unittest.mock import AsyncMock + mock_conn.run_sync = AsyncMock() mock_begin_ctx = MagicMock() @@ -64,15 +69,17 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m db = PostgresDatabase() mock_create_engine.assert_called_once_with( - "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", - echo=True + "postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", echo=True ) mock_auth_db.assert_called_once() mock_provider_db.assert_called_once() mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth") - with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all: + with patch( + "pretor.core.postgres_database.postgres.SQLModel.metadata.create_all" + ) as mock_create_all: await db.init_db() mock_conn.run_sync.assert_called_once_with(mock_create_all) assert await db.get_user_authority(user_id="123") == "test_auth" +# noqa: E402 diff --git a/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py b/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py deleted file mode 100644 index 3ad77ae..0000000 --- a/tests/core/workflow/workflow_template_generator/workflow_template_generator_test.py +++ /dev/null @@ -1,43 +0,0 @@ -from unittest.mock import patch, MagicMock -from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator - - -@patch("pretor.core.workflow.workflow_template_generator.workflow_template_generator.Path") -def test_generate_workflow_template(mock_path): - mock_dir = MagicMock() - mock_dir.exists.return_value = False - - mock_file = MagicMock() - mock_dir.__truediv__.return_value = mock_file - - mock_open_ctx = MagicMock() - mock_file.open.return_value.__enter__.return_value = mock_open_ctx - - mock_path_root = MagicMock() - mock_path_root.__truediv__.return_value = mock_dir - mock_path.return_value = mock_path_root - - from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate - generator = WorkflowTemplateGenerator() - mock_template = MagicMock(spec=WorkflowTemplate) - mock_template.name = "test_wf" - mock_template.desc = "test_desc" - import json - mock_template.model_dump_json.return_value = json.dumps({ - "name": "test_wf", - "desc": "test_desc", - "work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}] - }) - generator.generate_workflow_template( - workflow_template=mock_template - ) - - mock_dir.mkdir.assert_called_once_with(parents=True) - mock_file.open.assert_called_once_with("w", encoding="utf-8") - - mock_open_ctx.write.assert_called_once() - write_arg = mock_open_ctx.write.call_args[0][0] - written_data = json.loads(write_arg) - assert written_data["name"] == "test_wf" - assert written_data["desc"] == "test_desc" - assert written_data["work_link"][0]["step"] == 1 diff --git a/tests/core/workflow/workflow_template_generator/workflow_template_test.py b/tests/core/workflow/workflow_template_generator/workflow_template_test.py deleted file mode 100644 index 327a5f3..0000000 --- a/tests/core/workflow/workflow_template_generator/workflow_template_test.py +++ /dev/null @@ -1,36 +0,0 @@ -import pytest -from pydantic import ValidationError -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplateStep, WorkflowTemplate - -def test_workflow_template_step(): - step = WorkflowTemplateStep( - step=1, - node="node_type", - action="act", - desc="desc", - input=["in1"], - output=["out1"], - logic_gate={"if_pass": "next"} - ) - assert step.step == 1 - assert step.node == "node_type" - -def test_workflow_template_success(): - step1 = WorkflowTemplateStep( - step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={} - ) - step2 = WorkflowTemplateStep( - step=2, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={} - ) - wt = WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2]) - assert wt.name == "temp" - -def test_workflow_template_error_duplicate_steps(): - step1 = WorkflowTemplateStep( - step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={} - ) - step2 = WorkflowTemplateStep( - step=1, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={} - ) - with pytest.raises(ValidationError, match="Step numbers in work_link must be unique"): - WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2]) diff --git a/tests/core/workflow/workflow_template_manager_test.py b/tests/core/workflow/workflow_template_manager_test.py deleted file mode 100644 index 213ec81..0000000 --- a/tests/core/workflow/workflow_template_manager_test.py +++ /dev/null @@ -1,57 +0,0 @@ -import json -from unittest.mock import MagicMock, patch, mock_open -from pathlib import Path -from pretor.core.workflow.workflow_template_manager import WorkflowManager - - -def test_workflow_manager_init_success(): - mock_file1 = MagicMock(spec=Path) - mock_file1.open = mock_open(read_data=json.dumps({"name": "test1", "desc": "desc1"})) - - mock_file2 = MagicMock(spec=Path) - mock_file2.open = mock_open(read_data=json.dumps({"name": "test2", "desc": "desc2"})) - - with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1, mock_file2]): - with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"): - manager = WorkflowManager() - assert manager.workflow_templates_registry == {"test1": "desc1", "test2": "desc2"} - - -def test_workflow_manager_init_json_error(): - mock_file1 = MagicMock(spec=Path) - mock_file1.open = mock_open(read_data="{invalid_json}") - - with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1]): - with patch("pretor.core.workflow.workflow_template_manager.logger") as mock_logger: - with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"): - manager = WorkflowManager() - assert manager.workflow_templates_registry == {} - mock_logger.warning.assert_called_once() - assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0] - - -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate - -@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator") -@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[]) -def test_generate_workflow_template_success(mock_glob, mock_generator_cls): - manager = WorkflowManager() - mock_template = MagicMock(spec=WorkflowTemplate) - mock_template.name = "name" - mock_template.desc = "desc" - mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template - - manager.generate_workflow_template(mock_template) - mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template) - assert manager.workflow_templates_registry["name"] == "desc" - - -@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator") -@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[]) -@patch("pretor.core.workflow.workflow_template_manager.logger") -def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls): - mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error") - manager = WorkflowManager() - mock_template = MagicMock(spec=WorkflowTemplate) - manager.generate_workflow_template(mock_template) - mock_logger.exception.assert_called_once_with("Failed to generate workflow template") diff --git a/tests/core/workflow/workflow_test.py b/tests/core/workflow/workflow_test.py index 19a44ed..3105dd6 100644 --- a/tests/core/workflow/workflow_test.py +++ b/tests/core/workflow/workflow_test.py @@ -1,5 +1,11 @@ import pytest -from pretor.core.workflow.workflow import WorkStep, PretorWorkflow, WorkflowStatus, LogicGate +from pretor.core.workflow.workflow import ( + WorkStep, + PretorWorkflow, + WorkflowStatus, + LogicGate, +) + def test_work_step(): ws = WorkStep( @@ -7,7 +13,7 @@ def test_work_step(): name="step1", node="control_node", action="coding", - desc="Write some code" + desc="Write some code", ) assert ws.step == 1 assert ws.name == "step1" @@ -16,30 +22,59 @@ def test_work_step(): assert ws.desc == "Write some code" assert ws.status == "waiting" + def test_pretor_workflow_validation_success(): ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") - wf = PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) + wf = PretorWorkflow( + title="wf1", + work_link=[ws1, ws2], + trace_id="t", + event_info={"platform": "a", "user_name": "b"}, + ) assert wf.title == "wf1" + def test_pretor_workflow_validation_error_step_discontinuous(): ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1") ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2") with pytest.raises(ValueError, match="工作链步数不连续"): - PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) + PretorWorkflow( + title="wf1", + work_link=[ws1, ws2], + trace_id="t", + event_info={"platform": "a", "user_name": "b"}, + ) + def test_pretor_workflow_validation_error_jump_out_of_bounds(): lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue") - ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) + ws1 = WorkStep( + step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg + ) ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2") with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"): - PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) + PretorWorkflow( + title="wf1", + work_link=[ws1, ws2], + trace_id="t", + event_info={"platform": "a", "user_name": "b"}, + ) + def test_pretor_workflow_validation_error_jump_format_error(): lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue") - ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg) + ws1 = WorkStep( + step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg + ) with pytest.raises(ValueError, match="LogicGate 格式错误"): - PretorWorkflow(title="wf1", work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"}) + PretorWorkflow( + title="wf1", + work_link=[ws1], + trace_id="t", + event_info={"platform": "a", "user_name": "b"}, + ) + def test_workflow_status(): status = WorkflowStatus() diff --git a/tests/core/workflow/workflow_runner_test.py b/tests/core/workflow_running_engine/workflow_runner_test.py similarity index 80% rename from tests/core/workflow/workflow_runner_test.py rename to tests/core/workflow_running_engine/workflow_runner_test.py index 927344c..0081c8d 100644 --- a/tests/core/workflow/workflow_runner_test.py +++ b/tests/core/workflow_running_engine/workflow_runner_test.py @@ -9,14 +9,16 @@ real_import = builtins.__import__ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): - if name == 'ray': + if name == "ray": mock_ray = MagicMock() def mock_remote(*args, **kwargs): if len(args) == 1 and callable(args[0]): return args[0] + def decorator(cls): return cls + return decorator mock_ray.remote = mock_remote @@ -26,16 +28,19 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0): builtins.__import__ = mock_import for mod in list(sys.modules.keys()): - if 'pretor.core.workflow.workflow_runner' in mod or 'ray' in mod: + if "pretor.core.workflow_running_engine.workflow_runner" in mod or "ray" in mod: del sys.modules[mod] -from pretor.core.workflow.workflow_runner import WorkflowEngine, WorkflowRunningEngine +from pretor.core.workflow_running_engine.workflow_runner import ( # noqa: E402 + WorkflowEngine, + WorkflowRunningEngine, +) builtins.__import__ = real_import @pytest.fixture def mock_ray(): - with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray: + with patch("pretor.core.workflow_running_engine.workflow_runner.ray") as mock_ray: mock_ray.get = lambda x: x yield mock_ray @@ -91,7 +96,9 @@ async def test_workflow_engine_run(): engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor) - with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch: + with patch( + "pretor.core.workflow_running_engine.workflow_runner.ray" + ) as mock_ray_patch: mock_gsm = MagicMock() mock_ray_patch.get_actor.return_value = mock_gsm await engine.run() @@ -141,22 +148,36 @@ async def test_workflow_running_engine_runner(): user_id="test_user", user_name="test_user", message="test_message", - context={"workflow_template": "test_template"} + context={}, ) await engine.workflow_queue.put(mock_event) # Mock the global_state_machine get_skill_list.remote method properly mock_gsm = MagicMock() - mock_gsm.list_individuals.remote = AsyncMock(return_value={"test_skill": {"agent_type": "skill_individual", "agent_name": "TestSkill", "description": "desc"}}) + mock_gsm.list_individuals.remote = AsyncMock( + return_value={ + "test_skill": { + "agent_type": "skill_individual", + "agent_name": "TestSkill", + "description": "desc", + } + } + ) engine.global_state_machine = mock_gsm - with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open, \ - patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook: - + with ( + patch( + "pretor.core.workflow_running_engine.workflow_runner.WorkflowEngine" + ) as mock_wf_engine_cls, + patch("builtins.open", new_callable=MagicMock) as mock_open, + patch( + "pretor.core.workflow_running_engine.workflow_runner.ray_actor_hook" + ) as mock_hook, + ): # Instead of patching hook, we inject it directly # engine.global_state_machine = AsyncMock() - mock_open.return_value.__enter__.return_value.read.return_value = '{}' + mock_open.return_value.__enter__.return_value.read.return_value = "{}" mock_gwm = MagicMock() mock_gwm.update_workflow.remote = AsyncMock() @@ -170,4 +191,7 @@ async def test_workflow_running_engine_runner(): await asyncio.sleep(0.05) task.cancel() - mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor") + mock_wf_engine_cls.assert_called_with( + mock_wf, mock_consciousness, "control", "supervisor" + ) +# noqa: E402 diff --git a/tests/utils/access_test.py b/tests/utils/access_test.py index a7260c8..f65fdf2 100644 --- a/tests/utils/access_test.py +++ b/tests/utils/access_test.py @@ -28,9 +28,9 @@ sys.modules["passlib"] = MagicMock() sys.modules["passlib.context"] = MagicMock() sys.modules["pretor.core.database.table.user"] = MagicMock() -import pytest -import jwt -from pretor.utils.access import Accessor +import pytest # noqa: E402 +import jwt # noqa: E402 +from pretor.utils.access import Accessor # noqa: E402 def test_decode_token_success(): @@ -55,6 +55,7 @@ def test_decode_token_expired(): token = "expired.token.here" from fastapi import HTTPException + with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError): with patch("pretor.utils.access.HTTPException", HTTPException): with pytest.raises(HTTPException) as excinfo: @@ -69,6 +70,7 @@ def test_decode_token_invalid(): token = "invalid.token.here" from fastapi import HTTPException + with patch("jwt.decode", side_effect=jwt.InvalidTokenError): with patch("pretor.utils.access.HTTPException", HTTPException): with pytest.raises(HTTPException) as excinfo: @@ -93,4 +95,5 @@ def test_decode_token_validation_error(): Accessor._decode_token(token) assert excinfo.value.status_code == 401 - assert excinfo.value.detail == "无效的认证凭证" \ No newline at end of file + assert excinfo.value.detail == "无效的认证凭证" +# noqa: E402
Manage and create reusable workflow templates.
Provide the JSON definition for a new workflow template.