From 6055606e2c7e3c6c2f492f36e69cb2058fc81f81 Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Wed, 22 Apr 2026 21:16:43 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E5=AE=8C=E5=96=84=E4=BA=86worker=5Findi?= =?UTF-8?q?vidual=E5=92=8Cworkflow=5Frunning=5Fengine=E7=9A=84=E9=80=BB?= =?UTF-8?q?=E8=BE=91?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/problem.md | 40 +---- pretor/core/database/module/individual.py | 7 + .../global_state_machine.py | 10 ++ .../individual_manager.py | 61 ++++++++ pretor/core/workflow/workflow_runner.py | 29 +++- pretor/worker_individual/worker_cluster.py | 135 +++++++++------- pretor/worker_individual/worker_individual.py | 147 +++++++++++++++++- 7 files changed, 338 insertions(+), 91 deletions(-) create mode 100644 pretor/core/global_state_machine/individual_manager.py diff --git a/docs/problem.md b/docs/problem.md index 81f8042..d817906 100644 --- a/docs/problem.md +++ b/docs/problem.md @@ -3,7 +3,7 @@ ## 问题栏 #### 🔴 核心缺陷与修复 (Bug Fixes & Stability) - [x] /pretor/core/individual每个template进行优化 -- [ ] /pretor/worker_individual待完善复合子个体和基础子个体 +- [x] /pretor/worker_individual待完善复合子个体和基础子个体 #### 🛡️ 安全与合规 (Security & Auth) - [ ] 优化安全架构防止模型注入 @@ -15,36 +15,10 @@ - [ ] 优化import #### 🏗️ 架构演进 (Architecture & Refactoring) -- 】~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性) -- [ ] 升级auth功能 -- [x] /pretor/api的接口函数进行重构 -- [x] /dockerfile待完善 -- [ ] 完善沙箱功能 -- [ ] 完善爬虫功能 -- [ ] 对接更多的provider - ---- -## 日志 -#### 2026/4/12 -- [x] /pretor/api的接口函数进行重构 -- [x] /pretor/core/individual每个template进行优化 -- [ ] /pretor/worker_individual待完善复合子个体和基础子个体 -- [ ] /pretor/api待完善 -- [x] /dockerfile待完善 - -#### 2026/4/16 -- [ ] 发布v0.1.0正式版 -- [ ] 增加对应全workflow的情况追踪,使得在任务运行中人机交互更加自然方便 -- ~~[ ] 使用fastapi-users完善用户系统~~ - -#### 2026/4/19 -- [ ] 完善沙箱功能 -- [ ] 完善爬虫功能 -- [ ] 对接更多的provider -- [ ] 优化import +- [x] ~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性) - [x] 升级auth功能 - -#### 2026/4/20 -- [ ] 优化安全架构防止模型注入 -- [ ] 设计workflowEngine的自动扩缩容设计 -- [ ] 完善错误捕获和日志系统 \ No newline at end of file +- [x] /pretor/api的接口函数进行重构 +- [x] /dockerfile待完善 +- [ ] 完善沙箱功能 +- [ ] 完善爬虫功能 +- [ ] 对接更多的provider diff --git a/pretor/core/database/module/individual.py b/pretor/core/database/module/individual.py index 97cb80c..5e0968f 100644 --- a/pretor/core/database/module/individual.py +++ b/pretor/core/database/module/individual.py @@ -74,3 +74,10 @@ class IndividualDatabase: session.delete(individual) await session.commit() return True + + @database_exception + async def get_all_worker_individual(self) -> List[WorkerIndividual]: + async with self.async_session_maker() as session: + statement = select(WorkerIndividual) + results = await session.execute(statement) + return list(results.scalars().all()) \ No newline at end of file diff --git a/pretor/core/global_state_machine/global_state_machine.py b/pretor/core/global_state_machine/global_state_machine.py index 4faefbf..f1c8ffb 100644 --- a/pretor/core/global_state_machine/global_state_machine.py +++ b/pretor/core/global_state_machine/global_state_machine.py @@ -22,6 +22,8 @@ import asyncio from pretor.core.workflow.workflow import PretorWorkflow from pretor.core.workflow.workflow_template_manager import WorkflowManager from pretor.core.global_state_machine.skill_manager import GlobalSkillManager +from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager + @ray.remote class GlobalStateMachine: @@ -33,11 +35,13 @@ class GlobalStateMachine: self._global_tool_manager = GlobalToolManager() self._global_workflow_template_manager = WorkflowManager() self._global_skill_manager = GlobalSkillManager() + self._global_individual_manager = GlobalIndividualManager() self.postgres_database = postgres_database async def init_state_machine(self): await self._global_provider_manager.init_provider_register(self.postgres_database) + await self._global_individual_manager.init_individual_register(self.postgres_database) async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner): return await self._global_provider_manager.add_provider( @@ -73,6 +77,12 @@ class GlobalStateMachine: return await method(*args, **kwargs) return method(*args, **kwargs) + async def individual_manager(self, method_name: str, *args, **kwargs): + method = getattr(self._global_individual_manager, method_name) + if asyncio.iscoroutinefunction(method): + return await method(*args, **kwargs) + return method(*args, **kwargs) + ###以下方法为event_dict方法 def add_event(self, event: PretorEvent) -> None: event.pending_queue = asyncio.Queue() diff --git a/pretor/core/global_state_machine/individual_manager.py b/pretor/core/global_state_machine/individual_manager.py new file mode 100644 index 0000000..2c7e3ba --- /dev/null +++ b/pretor/core/global_state_machine/individual_manager.py @@ -0,0 +1,61 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Dict, Any +from loguru import logger + +class GlobalIndividualManager: + def __init__(self): + self._individuals: Dict[str, Dict[str, Any]] = {} + + async def init_individual_register(self, postgres) -> None: + try: + try: + individuals = await postgres.individual_database.remote("get_all_worker_individual") + for ind in individuals: + agent_id = getattr(ind, 'agent_id', None) + if agent_id: + self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind) + logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。") + except AttributeError: + logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。") + except Exception as e: + # 捕获因 Ray 调用目标方法不存在引发的异常 + if "has no attribute 'get_all_worker_individual'" in str(e): + logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。") + else: + raise e + except Exception as e: + logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}") + + def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None: + """ + 注册一个 worker individual + config 可以包含 type, prompt, provider_title, model_id 等 + """ + config["agent_id"] = agent_id + self._individuals[agent_id] = config + + def get_individual(self, agent_id: str) -> Dict[str, Any]: + """ + 获取一个 worker individual 的配置 + """ + return self._individuals.get(agent_id, None) + + def remove_individual(self, agent_id: str) -> None: + if agent_id in self._individuals: + del self._individuals[agent_id] + + def list_individuals(self) -> Dict[str, Dict[str, Any]]: + return self._individuals diff --git a/pretor/core/workflow/workflow_runner.py b/pretor/core/workflow/workflow_runner.py index 761708b..8f88238 100644 --- a/pretor/core/workflow/workflow_runner.py +++ b/pretor/core/workflow/workflow_runner.py @@ -187,7 +187,6 @@ class WorkflowEngine: elif step.node == "consciousness_node": if not self.consciousness_node: raise WorkflowError("未提供 consciousness_node 句柄!") - # 这里将 command 作为 original_command,可根据业务调整 original_cmd = self.workflow.command or "" payload = ConsciousnessForWorkflowInput( workflow_step=step, @@ -199,11 +198,31 @@ class WorkflowEngine: return result_obj, True elif step.node in ["primary_individual", "composite_individual"]: - logger.warning(f"当前节点 {step.node} 暂未实现完整调度支持,这里将模拟执行。") - await asyncio.sleep(1) - simulated_result = f"这是 {step.node} 执行 {step.action} 产生的模拟结果 (输入: {input_data})" - return simulated_result, True + logger.info(f"正在通过 WorkerCluster 调度 {step.node} 的 {step.action} 动作。") + try: + from pretor.utils.ray_hook import ray_actor_hook + worker_cluster = ray_actor_hook("worker_cluster").worker_cluster + task_id = f"{self.workflow.trace_id}_step_{step.step}" + agent_id = getattr(step, 'agent_id', f"default_{step.node}") + if isinstance(input_data, dict) and "agent_id" in input_data: + agent_id = input_data.get("agent_id") + task_event = { + "action": step.action, + "description": step.description, + "input_data": input_data, + "context_memory": self.workflow.context_memory + } + result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event) + if result_response.get("success"): + return result_response.get("data"), True + else: + logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}") + return result_response.get("error"), False + + except Exception as e: + logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}") + raise WorkflowError(f"WorkerCluster 调度异常: {e}") else: raise WorkflowError(f"未知的节点类型:{step.node}") diff --git a/pretor/worker_individual/worker_cluster.py b/pretor/worker_individual/worker_cluster.py index e99ccb0..3ee6a2e 100644 --- a/pretor/worker_individual/worker_cluster.py +++ b/pretor/worker_individual/worker_cluster.py @@ -13,6 +13,14 @@ # limitations under the License. import ray +import time +import asyncio +from collections import OrderedDict +from loguru import logger +from ray.util.queue import Queue +from pretor.utils.ray_hook import ray_actor_hook +from pretor.worker_individual.worker_individual import BaseIndividual, SkillIndividual, OrdinaryIndividual, \ + SpecialIndividual @ray.remote @@ -22,35 +30,31 @@ class WorkerCluster: 设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸 """ - def __init__(self, db_actor, max_capacity: int = 200): - self.db = db_actor + def __init__(self, max_capacity: int = 200, num_runners: int = 10): self.max_capacity = max_capacity - # 核心:LRU 活跃 Agent 缓存池 - self._active_workers: OrderedDict[str, BaseWorkerIndividual] = OrderedDict() + self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict() self.status = "running" + self.task_queue = Queue() + self.results_futures = {} + self.runners = [] + self.num_runners = num_runners - async def _recruit_worker(self, agent_id: str) -> BaseWorkerIndividual: + async def start(self): + self.runners = [asyncio.create_task(self._runner(i)) for i in range(self.num_runners)] + logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。") + + async def _recruit_worker(self, agent_id: str) -> BaseIndividual: """内部方法:招聘/唤醒一个具体的 Agent 对象""" - - # 1. 尝试从缓存直接命中 if agent_id in self._active_workers: - self._active_workers.move_to_end(agent_id) # 标记为最近使用 + self._active_workers.move_to_end(agent_id) return self._active_workers[agent_id] - # 2. 缓存未命中,去数据库拉取 Agent 档案配置 - # agent_config = await self.db.get_agent_config.remote(agent_id) - - # 模拟从数据库取出的配置数据 - agent_config = { - "agent_id": agent_id, - "type": "skill", # 取决于数据库里的设定:ordinary, skill, special - "prompt": "你是一个资深架构师..." - } + global_state_machine = ray_actor_hook("global_state_machine").global_state_machine + agent_config = await global_state_machine.individual_manager.remote("get_individual", agent_id) if not agent_config: raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案") - # 3. 工厂模式:根据类型动态装配不同量级的 Individual worker_type = agent_config.get("type", "ordinary") if worker_type == "skill": worker = SkillIndividual(agent_config) @@ -59,49 +63,76 @@ class WorkerCluster: else: worker = OrdinaryIndividual(agent_config) - # 4. 放入内存池,如果爆满则淘汰最老的那个 self._active_workers[agent_id] = worker if len(self._active_workers) > self.max_capacity: evicted_id, _ = self._active_workers.popitem(last=False) - print(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}") + logger.info(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}") return worker - async def execute_task(self, agent_id: str, task_event: dict) -> dict: - """ - 对外暴露的唯一干活接口。 - task_event 应该包含所有的上下文(Context、历史记忆、本次指令) - """ + async def _runner(self, runner_id: int): + while True: + try: + task = await self.task_queue.get_async() + task_id = task.get("task_id") + agent_id = task.get("agent_id") + task_event = task.get("task_event") + + logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}") + start_time = time.time() + + try: + worker = await self._recruit_worker(agent_id) + result = await worker.run(task_event) + cost_time = time.time() - start_time + + response = { + "success": True, + "agent_id": agent_id, + "data": result, + "metrics": {"cost_time_sec": round(cost_time, 2)} + } + except Exception as e: + logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}") + response = { + "success": False, + "agent_id": agent_id, + "error": str(e) + } + if task_id in self.results_futures: + future = self.results_futures[task_id] + if not future.done(): + future.set_result(response) + + except Exception as e: + logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}") + await asyncio.sleep(1) + + async def submit_task(self, task_id: str, agent_id: str, task_event: dict): + if not self.runners: + await self.start() + + future = asyncio.Future() + self.results_futures[task_id] = future + + task = { + "task_id": task_id, + "agent_id": agent_id, + "task_event": task_event + } + await self.task_queue.put_async(task) + logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。") + try: - # 1. 获取工作实体(秒级热启动或毫秒级缓存命中) - worker = await self._recruit_worker(agent_id) - - # 2. 注入上下文并执行 - # 这里的 run 方法内部不保存状态,所有记忆都从 task_event 传入 - start_time = time.time() - result = await worker.run(task_event) - cost_time = time.time() - start_time - - # 3. 封装标准回包 - return { - "success": True, - "agent_id": agent_id, - "data": result, - "metrics": {"cost_time_sec": round(cost_time, 2)} - } - - except Exception as e: - # 异常隔离:一个 Agent 报错,绝对不能把整个 Cluster 搞崩 - return { - "success": False, - "agent_id": agent_id, - "error": str(e) - } + result = await future + return result + finally: + self.results_futures.pop(task_id, None) def get_cluster_metrics(self): - """监控探针:用于查看当前集群负载""" return { "active_worker_count": len(self._active_workers), "max_capacity": self.max_capacity, - "cached_agent_ids": list(self._active_workers.keys()) - } \ No newline at end of file + "cached_agent_ids": list(self._active_workers.keys()), + "queue_size": self.task_queue.size() + } diff --git a/pretor/worker_individual/worker_individual.py b/pretor/worker_individual/worker_individual.py index a997743..18bd01d 100644 --- a/pretor/worker_individual/worker_individual.py +++ b/pretor/worker_individual/worker_individual.py @@ -10,4 +10,149 @@ # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and -# limitations under the License. \ No newline at end of file +# limitations under the License. + + +from loguru import logger +from pydantic_ai import Agent, RunContext +from pydantic import Field +from pretor.adapter.model_adapter.agent_factory import AgentFactory +from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine +from pretor.core.global_state_machine.model_provider.base_provider import Provider +from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel +from pretor.utils.get_tool import get_tool +import ray +from pretor.utils.ray_hook import ray_actor_hook + + +class WorkerIndividualResponse(ResponseModel): + output: str = Field(..., description="Worker执行任务的输出结果") + + +class WorkerIndividualDeps(DepsModel): + task_event: dict + + +class WorkerIndividualInput(InputModel): + task_event: dict + + +class BaseIndividual: + """ + Worker Individual 的基类 + """ + + def __init__(self, agent_config: dict): + self.agent_config = agent_config + self.agent_id = agent_config.get("agent_id") + self.agent: Agent | None = None + + async def _init_agent(self, agent_name: str, system_prompt: str): + global_state_machine = ray_actor_hook("global_state_machine").global_state_machine + provider_title = self.agent_config.get("provider_title", "openai") # default fallback + model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback + + provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title) + agent_factory = AgentFactory() + self.agent = agent_factory.create_agent( + provider=provider, + model_id=model_id, + output_type=WorkerIndividualResponse, + system_prompt=system_prompt, + deps_type=WorkerIndividualDeps, + agent_name=agent_name + ) + + @self.agent.system_prompt + async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]): + prompt = system_prompt + "\n\n" + prompt += ( + f"=== 当前任务上下文 ===\n" + f"{ctx.deps.task_event}\n" + ) + return prompt + + async def run(self, task_event: dict) -> dict: + raise NotImplementedError("子类必须实现 run 方法") + + +class SkillIndividual(BaseIndividual): + """ + 专家子个体:拥有专业 skill 的 agent。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", + "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。") + await self._init_agent("skill_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + # In actual usage, tools could be dynamically loaded here based on agent_config + # tool = get_tool("skill_individual") + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps + # tools=tool + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}") + raise + + +class OrdinaryIndividual(BaseIndividual): + """ + 普通子个体:普通的 agent。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", "你是一个普通的AI助手,请尽力完成给定的任务。") + await self._init_agent("ordinary_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}") + raise + + +class SpecialIndividual(BaseIndividual): + """ + 特殊子个体:执行特殊任务的 agent,如生成语音、视频等。 + """ + + def __init__(self, agent_config: dict): + super().__init__(agent_config) + + async def run(self, task_event: dict) -> dict: + if self.agent is None: + system_prompt = self.agent_config.get("prompt", "你是一个特殊的AI助手,负责处理特殊类型的任务。") + await self._init_agent("special_individual", system_prompt) + + deps = WorkerIndividualDeps(task_event=task_event) + self.agent.retries = 3 + try: + result = await self.agent.run( + f"请执行以下任务:\n{task_event}", + deps=deps + ) + return {"output": result.data.output} + except Exception as e: + logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}") + raise