wip: 完善了worker_individual和workflow_running_engine的逻辑
This commit is contained in:
parent
c6025732c6
commit
6055606e2c
|
|
@ -3,7 +3,7 @@
|
|||
## 问题栏
|
||||
#### 🔴 核心缺陷与修复 (Bug Fixes & Stability)
|
||||
- [x] /pretor/core/individual每个template进行优化
|
||||
- [ ] /pretor/worker_individual待完善复合子个体和基础子个体
|
||||
- [x] /pretor/worker_individual待完善复合子个体和基础子个体
|
||||
|
||||
#### 🛡️ 安全与合规 (Security & Auth)
|
||||
- [ ] 优化安全架构防止模型注入
|
||||
|
|
@ -15,36 +15,10 @@
|
|||
- [ ] 优化import
|
||||
|
||||
#### 🏗️ 架构演进 (Architecture & Refactoring)
|
||||
- [ ] ~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
||||
- [ ] 升级auth功能
|
||||
- [x] /pretor/api的接口函数进行重构
|
||||
- [x] /dockerfile待完善
|
||||
- [ ] 完善沙箱功能
|
||||
- [ ] 完善爬虫功能
|
||||
- [ ] 对接更多的provider
|
||||
|
||||
---
|
||||
## 日志
|
||||
#### 2026/4/12
|
||||
- [x] /pretor/api的接口函数进行重构
|
||||
- [x] /pretor/core/individual每个template进行优化
|
||||
- [ ] /pretor/worker_individual待完善复合子个体和基础子个体
|
||||
- [ ] /pretor/api待完善
|
||||
- [x] /dockerfile待完善
|
||||
|
||||
#### 2026/4/16
|
||||
- [ ] 发布v0.1.0正式版
|
||||
- [ ] 增加对应全workflow的情况追踪,使得在任务运行中人机交互更加自然方便
|
||||
- [ ] ~~使用fastapi-users完善用户系统~~
|
||||
|
||||
#### 2026/4/19
|
||||
- [ ] 完善沙箱功能
|
||||
- [ ] 完善爬虫功能
|
||||
- [ ] 对接更多的provider
|
||||
- [ ] 优化import
|
||||
- [x] ~~使用fastapi-users完善用户系统~~(2026/4/19 fastapi-users会严重摧毁代码的优雅性)
|
||||
- [x] 升级auth功能
|
||||
|
||||
#### 2026/4/20
|
||||
- [ ] 优化安全架构防止模型注入
|
||||
- [ ] 设计workflowEngine的自动扩缩容设计
|
||||
- [ ] 完善错误捕获和日志系统
|
||||
- [x] /pretor/api的接口函数进行重构
|
||||
- [x] /dockerfile待完善
|
||||
- [ ] 完善沙箱功能
|
||||
- [ ] 完善爬虫功能
|
||||
- [ ] 对接更多的provider
|
||||
|
|
|
|||
|
|
@ -74,3 +74,10 @@ class IndividualDatabase:
|
|||
session.delete(individual)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@database_exception
async def get_all_worker_individual(self) -> List[WorkerIndividual]:
    """Return every stored ``WorkerIndividual`` row as a list.

    Runs an unfiltered ``select`` inside a fresh session obtained from
    ``self.async_session_maker``.
    """
    query = select(WorkerIndividual)
    async with self.async_session_maker() as session:
        rows = await session.execute(query)
        # Materialise the scalar results before the session closes.
        return list(rows.scalars().all())
|
||||
|
|
@ -22,6 +22,8 @@ import asyncio
|
|||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
|
|
@ -33,11 +35,13 @@ class GlobalStateMachine:
|
|||
self._global_tool_manager = GlobalToolManager()
|
||||
self._global_workflow_template_manager = WorkflowManager()
|
||||
self._global_skill_manager = GlobalSkillManager()
|
||||
self._global_individual_manager = GlobalIndividualManager()
|
||||
|
||||
self.postgres_database = postgres_database
|
||||
|
||||
async def init_state_machine(self):
    """Initialise the managers that need data from the database.

    Awaits the provider manager's ``init_provider_register`` first and then
    the individual manager's ``init_individual_register``, passing
    ``self.postgres_database`` to each.
    """
    await self._global_provider_manager.init_provider_register(self.postgres_database)
    await self._global_individual_manager.init_individual_register(self.postgres_database)
|
||||
|
||||
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||
return await self._global_provider_manager.add_provider(
|
||||
|
|
@ -73,6 +77,12 @@ class GlobalStateMachine:
|
|||
return await method(*args, **kwargs)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
async def individual_manager(self, method_name: str, *args, **kwargs):
    """Call a method of the global individual manager by name.

    Args:
        method_name: Attribute name looked up on
            ``self._global_individual_manager``.
        *args: Positional arguments forwarded to that method.
        **kwargs: Keyword arguments forwarded to that method.

    Returns:
        Whatever the target method returns; coroutine functions are awaited.

    Raises:
        AttributeError: If the manager has no attribute ``method_name``.
    """
    handler = getattr(self._global_individual_manager, method_name)
    # Plain callables are invoked directly; only coroutine functions are awaited.
    if not asyncio.iscoroutinefunction(handler):
        return handler(*args, **kwargs)
    return await handler(*args, **kwargs)
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
|
|
|
|||
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict, Any
|
||||
from loguru import logger
|
||||
|
||||
class GlobalIndividualManager:
    """In-memory registry of worker-individual configurations, keyed by agent id.

    Each entry is a plain ``dict`` (keys such as ``type``, ``prompt``,
    ``provider_title``, ``model_id``). Entries are bulk-loaded by
    ``init_individual_register`` and edited individually through
    ``add_individual`` / ``remove_individual``.
    """

    def __init__(self):
        self._individuals: Dict[str, Dict[str, Any]] = {}

    @staticmethod
    def _row_agent_id(row):
        """Return the agent id of a database row given as a dict or an object."""
        if isinstance(row, dict):
            return row.get("agent_id")
        return getattr(row, "agent_id", None)

    @staticmethod
    def _row_to_config(row) -> Dict[str, Any]:
        """Convert a database row into a plain config dict."""
        if hasattr(row, "model_dump"):
            return row.model_dump()
        return dict(row)

    async def init_individual_register(self, postgres) -> None:
        """Best-effort bulk load of all worker individuals from the database.

        Never raises: a missing ``get_all_worker_individual`` method is logged
        as a warning, any other failure as an error. Rows without an
        ``agent_id`` are skipped.

        Args:
            postgres: Handle exposing ``individual_database.remote(method_name)``
                (presumably a Ray actor handle; verify against the caller).
        """
        try:
            individuals = await postgres.individual_database.remote("get_all_worker_individual")
        except AttributeError:
            logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
            return
        except Exception as e:
            # A missing remote method may surface as a wrapped error whose
            # message carries the attribute name.
            if "has no attribute 'get_all_worker_individual'" in str(e):
                logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
            else:
                logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
            return

        try:
            for ind in individuals:
                agent_id = self._row_agent_id(ind)
                if agent_id:
                    self._individuals[agent_id] = self._row_to_config(ind)
        except Exception as e:
            # Row-level failures are reported as load errors rather than being
            # mistaken for a missing database method.
            logger.error(f"从数据库拉取 Worker Individual 配置失败: {e}")
            return

        logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")

    def add_individual(self, agent_id: str, config: Dict[str, Any]) -> None:
        """Register (or overwrite) a worker individual.

        The stored entry is a copy of ``config`` with ``agent_id`` injected, so
        the caller's dict is left untouched.

        Args:
            agent_id: Registry key for the individual.
            config: Settings such as type, prompt, provider_title, model_id.
        """
        self._individuals[agent_id] = {**config, "agent_id": agent_id}

    def get_individual(self, agent_id: str) -> "Dict[str, Any] | None":
        """Return the stored config for ``agent_id``, or ``None`` if unknown."""
        return self._individuals.get(agent_id)

    def remove_individual(self, agent_id: str) -> None:
        """Remove ``agent_id`` from the registry; unknown ids are ignored."""
        self._individuals.pop(agent_id, None)

    def list_individuals(self) -> Dict[str, Dict[str, Any]]:
        """Return the mapping of agent id to config held by this manager."""
        return self._individuals
|
||||
|
|
@ -187,7 +187,6 @@ class WorkflowEngine:
|
|||
elif step.node == "consciousness_node":
|
||||
if not self.consciousness_node:
|
||||
raise WorkflowError("未提供 consciousness_node 句柄!")
|
||||
# 这里将 command 作为 original_command,可根据业务调整
|
||||
original_cmd = self.workflow.command or ""
|
||||
payload = ConsciousnessForWorkflowInput(
|
||||
workflow_step=step,
|
||||
|
|
@ -199,11 +198,31 @@ class WorkflowEngine:
|
|||
return result_obj, True
|
||||
|
||||
elif step.node in ["primary_individual", "composite_individual"]:
|
||||
logger.warning(f"当前节点 {step.node} 暂未实现完整调度支持,这里将模拟执行。")
|
||||
await asyncio.sleep(1)
|
||||
simulated_result = f"这是 {step.node} 执行 {step.action} 产生的模拟结果 (输入: {input_data})"
|
||||
return simulated_result, True
|
||||
logger.info(f"正在通过 WorkerCluster 调度 {step.node} 的 {step.action} 动作。")
|
||||
try:
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
worker_cluster = ray_actor_hook("worker_cluster").worker_cluster
|
||||
task_id = f"{self.workflow.trace_id}_step_{step.step}"
|
||||
agent_id = getattr(step, 'agent_id', f"default_{step.node}")
|
||||
if isinstance(input_data, dict) and "agent_id" in input_data:
|
||||
agent_id = input_data.get("agent_id")
|
||||
task_event = {
|
||||
"action": step.action,
|
||||
"description": step.description,
|
||||
"input_data": input_data,
|
||||
"context_memory": self.workflow.context_memory
|
||||
}
|
||||
result_response = await worker_cluster.submit_task.remote(task_id, agent_id, task_event)
|
||||
|
||||
if result_response.get("success"):
|
||||
return result_response.get("data"), True
|
||||
else:
|
||||
logger.error(f"WorkerCluster 执行 {step.node} 失败: {result_response.get('error')}")
|
||||
return result_response.get("error"), False
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"调度 WorkerCluster 执行 {step.node} 时发生异常: {e}")
|
||||
raise WorkflowError(f"WorkerCluster 调度异常: {e}")
|
||||
else:
|
||||
raise WorkflowError(f"未知的节点类型:{step.node}")
|
||||
|
||||
|
|
|
|||
|
|
@ -13,6 +13,14 @@
|
|||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
import time
|
||||
import asyncio
|
||||
from collections import OrderedDict
|
||||
from loguru import logger
|
||||
from ray.util.queue import Queue
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from pretor.worker_individual.worker_individual import BaseIndividual, SkillIndividual, OrdinaryIndividual, \
|
||||
SpecialIndividual
|
||||
|
||||
|
||||
@ray.remote
|
||||
|
|
@ -22,35 +30,31 @@ class WorkerCluster:
|
|||
设计理念:按需加载,内存 LRU 淘汰,避免 Actor 爆炸
|
||||
"""
|
||||
|
||||
def __init__(self, db_actor, max_capacity: int = 200):
|
||||
self.db = db_actor
|
||||
def __init__(self, max_capacity: int = 200, num_runners: int = 10):
|
||||
self.max_capacity = max_capacity
|
||||
# 核心:LRU 活跃 Agent 缓存池
|
||||
self._active_workers: OrderedDict[str, BaseWorkerIndividual] = OrderedDict()
|
||||
self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
|
||||
self.status = "running"
|
||||
self.task_queue = Queue()
|
||||
self.results_futures = {}
|
||||
self.runners = []
|
||||
self.num_runners = num_runners
|
||||
|
||||
async def _recruit_worker(self, agent_id: str) -> BaseWorkerIndividual:
|
||||
async def start(self):
    """Spawn ``self.num_runners`` runner tasks and keep them in ``self.runners``."""
    spawned = []
    for runner_index in range(self.num_runners):
        # Each runner is an independent asyncio task running ``self._runner``.
        spawned.append(asyncio.create_task(self._runner(runner_index)))
    self.runners = spawned
    logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。")
|
||||
|
||||
async def _recruit_worker(self, agent_id: str) -> BaseIndividual:
|
||||
"""内部方法:招聘/唤醒一个具体的 Agent 对象"""
|
||||
|
||||
# 1. 尝试从缓存直接命中
|
||||
if agent_id in self._active_workers:
|
||||
self._active_workers.move_to_end(agent_id) # 标记为最近使用
|
||||
self._active_workers.move_to_end(agent_id)
|
||||
return self._active_workers[agent_id]
|
||||
|
||||
# 2. 缓存未命中,去数据库拉取 Agent 档案配置
|
||||
# agent_config = await self.db.get_agent_config.remote(agent_id)
|
||||
|
||||
# 模拟从数据库取出的配置数据
|
||||
agent_config = {
|
||||
"agent_id": agent_id,
|
||||
"type": "skill", # 取决于数据库里的设定:ordinary, skill, special
|
||||
"prompt": "你是一个资深架构师..."
|
||||
}
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
agent_config = await global_state_machine.individual_manager.remote("get_individual", agent_id)
|
||||
|
||||
if not agent_config:
|
||||
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
||||
|
||||
# 3. 工厂模式:根据类型动态装配不同量级的 Individual
|
||||
worker_type = agent_config.get("type", "ordinary")
|
||||
if worker_type == "skill":
|
||||
worker = SkillIndividual(agent_config)
|
||||
|
|
@ -59,49 +63,76 @@ class WorkerCluster:
|
|||
else:
|
||||
worker = OrdinaryIndividual(agent_config)
|
||||
|
||||
# 4. 放入内存池,如果爆满则淘汰最老的那个
|
||||
self._active_workers[agent_id] = worker
|
||||
if len(self._active_workers) > self.max_capacity:
|
||||
evicted_id, _ = self._active_workers.popitem(last=False)
|
||||
print(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}")
|
||||
logger.info(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}")
|
||||
|
||||
return worker
|
||||
|
||||
async def execute_task(self, agent_id: str, task_event: dict) -> dict:
|
||||
"""
|
||||
对外暴露的唯一干活接口。
|
||||
task_event 应该包含所有的上下文(Context、历史记忆、本次指令)
|
||||
"""
|
||||
async def _runner(self, runner_id: int):
|
||||
while True:
|
||||
try:
|
||||
# 1. 获取工作实体(秒级热启动或毫秒级缓存命中)
|
||||
worker = await self._recruit_worker(agent_id)
|
||||
task = await self.task_queue.get_async()
|
||||
task_id = task.get("task_id")
|
||||
agent_id = task.get("agent_id")
|
||||
task_event = task.get("task_event")
|
||||
|
||||
# 2. 注入上下文并执行
|
||||
# 这里的 run 方法内部不保存状态,所有记忆都从 task_event 传入
|
||||
logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}")
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
worker = await self._recruit_worker(agent_id)
|
||||
result = await worker.run(task_event)
|
||||
cost_time = time.time() - start_time
|
||||
|
||||
# 3. 封装标准回包
|
||||
return {
|
||||
response = {
|
||||
"success": True,
|
||||
"agent_id": agent_id,
|
||||
"data": result,
|
||||
"metrics": {"cost_time_sec": round(cost_time, 2)}
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
# 异常隔离:一个 Agent 报错,绝对不能把整个 Cluster 搞崩
|
||||
return {
|
||||
logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}")
|
||||
response = {
|
||||
"success": False,
|
||||
"agent_id": agent_id,
|
||||
"error": str(e)
|
||||
}
|
||||
if task_id in self.results_futures:
|
||||
future = self.results_futures[task_id]
|
||||
if not future.done():
|
||||
future.set_result(response)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
async def submit_task(self, task_id: str, agent_id: str, task_event: dict):
    """Queue a task for the runners and wait for its response.

    Starts the runners on first use, registers a future under ``task_id``,
    enqueues the task and awaits the future.

    Args:
        task_id: Key under which the result future is registered.
        agent_id: Id of the agent that should handle the task.
        task_event: Payload stored in the queued task under ``"task_event"``.

    Returns:
        The value the future is resolved with.

    Raises:
        Exception: Whatever enqueueing or awaiting the future raises; the
            future registration is cleaned up in every case.
    """
    if not self.runners:
        await self.start()

    # Preferred way to create a future bound to the running loop.
    future = asyncio.get_running_loop().create_future()
    self.results_futures[task_id] = future

    task = {
        "task_id": task_id,
        "agent_id": agent_id,
        "task_event": task_event
    }

    try:
        # Enqueue inside the try so a failed put cannot leak the registration.
        await self.task_queue.put_async(task)
        logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。")
        return await future
    finally:
        # Only drop our own registration; a later submission that reused the
        # same task_id keeps its future.
        if self.results_futures.get(task_id) is future:
            self.results_futures.pop(task_id, None)
|
||||
|
||||
def get_cluster_metrics(self):
|
||||
"""监控探针:用于查看当前集群负载"""
|
||||
return {
|
||||
"active_worker_count": len(self._active_workers),
|
||||
"max_capacity": self.max_capacity,
|
||||
"cached_agent_ids": list(self._active_workers.keys())
|
||||
"cached_agent_ids": list(self._active_workers.keys()),
|
||||
"queue_size": self.task_queue.size()
|
||||
}
|
||||
|
|
@ -11,3 +11,148 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from loguru import logger
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pydantic import Field
|
||||
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.utils.agent_model import ResponseModel, InputModel, DepsModel
|
||||
from pretor.utils.get_tool import get_tool
|
||||
import ray
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
||||
|
||||
# Structured result model for worker agents: it is passed as ``output_type``
# when the agent is created, and its ``output`` field is what the ``run``
# methods in this file return to the caller.
class WorkerIndividualResponse(ResponseModel):
    output: str = Field(..., description="Worker执行任务的输出结果")
|
||||
|
||||
|
||||
# Per-run dependencies handed to the agent: it is passed as ``deps_type`` when
# the agent is created, and the dynamic system prompt renders ``task_event``.
class WorkerIndividualDeps(DepsModel):
    task_event: dict
|
||||
|
||||
|
||||
# Input model carrying a task event.
# NOTE(review): not referenced by any code visible in this section — confirm
# where (or whether) it is used.
class WorkerIndividualInput(InputModel):
    task_event: dict
|
||||
|
||||
|
||||
class BaseIndividual:
    """Common base for worker individuals.

    Holds the agent configuration and a lazily built agent; subclasses supply
    the actual ``run`` behaviour.
    """

    def __init__(self, agent_config: dict):
        self.agent_config = agent_config
        self.agent_id = agent_config.get("agent_id")
        # Built on demand by ``_init_agent``.
        self.agent: Agent | None = None

    async def _init_agent(self, agent_name: str, system_prompt: str):
        """Create ``self.agent`` and attach a dynamic system prompt to it.

        Args:
            agent_name: Name handed to the agent factory.
            system_prompt: Base prompt, passed to the factory and also used as
                the prefix of the dynamic prompt.
        """
        state_machine = ray_actor_hook("global_state_machine").global_state_machine
        config = self.agent_config
        # Fallback provider / model when the config does not name them.
        provider_title = config.get("provider_title", "openai")
        model_id = config.get("model_id", "gpt-4o")

        provider: Provider = await state_machine.provider_manager.remote("get_provider", provider_title)
        self.agent = AgentFactory().create_agent(
            provider=provider,
            model_id=model_id,
            output_type=WorkerIndividualResponse,
            system_prompt=system_prompt,
            deps_type=WorkerIndividualDeps,
            agent_name=agent_name,
        )

        @self.agent.system_prompt
        async def dynamic_prompt(ctx: RunContext[WorkerIndividualDeps]):
            # Base prompt followed by the current task context.
            context_block = (
                f"=== 当前任务上下文 ===\n"
                f"{ctx.deps.task_event}\n"
            )
            return system_prompt + "\n\n" + context_block

    async def run(self, task_event: dict) -> dict:
        """Execute a task; must be overridden by subclasses.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError("子类必须实现 run 方法")
|
||||
|
||||
|
||||
class SkillIndividual(BaseIndividual):
    """Expert worker: an agent that carries a professional skill."""

    def __init__(self, agent_config: dict):
        super().__init__(agent_config)

    async def run(self, task_event: dict) -> dict:
        """Run the task through the agent, building the agent on first use.

        Args:
            task_event: Task payload; rendered into the user prompt and passed
                to the agent as dependencies.

        Returns:
            ``{"output": <text produced by the agent>}``.

        Raises:
            Exception: Any failure during the run is logged and re-raised.
        """
        if self.agent is None:
            # ``or`` rather than a ``.get`` default: a config that carries
            # ``prompt: None`` must still fall back to the built-in prompt.
            system_prompt = self.agent_config.get("prompt") or (
                "你是一个拥有专业技能的专家级AI助手,请利用你的专业知识完成给定的任务。"
            )
            await self._init_agent("skill_individual", system_prompt)

        deps = WorkerIndividualDeps(task_event=task_event)
        # NOTE(review): assigning ``retries`` after construction may not change
        # the agent's retry policy — confirm against AgentFactory / pydantic_ai.
        self.agent.retries = 3
        # TODO: tools could be loaded dynamically here based on agent_config.
        try:
            result = await self.agent.run(
                f"请执行以下任务:\n{task_event}",
                deps=deps
            )
            # Prefer the current ``output`` accessor; fall back to the legacy
            # ``data`` accessor on older pydantic_ai releases.
            response = result.output if hasattr(result, "output") else result.data
            return {"output": response.output}
        except Exception as e:
            logger.exception(f"SkillIndividual {self.agent_id} 执行失败: {e}")
            raise
|
||||
|
||||
|
||||
class OrdinaryIndividual(BaseIndividual):
    """Ordinary worker: a general-purpose agent."""

    def __init__(self, agent_config: dict):
        super().__init__(agent_config)

    async def run(self, task_event: dict) -> dict:
        """Run the task through the agent, building the agent on first use.

        Args:
            task_event: Task payload; rendered into the user prompt and passed
                to the agent as dependencies.

        Returns:
            ``{"output": <text produced by the agent>}``.

        Raises:
            Exception: Any failure during the run is logged and re-raised.
        """
        if self.agent is None:
            # ``or`` rather than a ``.get`` default: a config that carries
            # ``prompt: None`` must still fall back to the built-in prompt.
            system_prompt = self.agent_config.get("prompt") or "你是一个普通的AI助手,请尽力完成给定的任务。"
            await self._init_agent("ordinary_individual", system_prompt)

        deps = WorkerIndividualDeps(task_event=task_event)
        # NOTE(review): assigning ``retries`` after construction may not change
        # the agent's retry policy — confirm against AgentFactory / pydantic_ai.
        self.agent.retries = 3
        try:
            result = await self.agent.run(
                f"请执行以下任务:\n{task_event}",
                deps=deps
            )
            # Prefer the current ``output`` accessor; fall back to the legacy
            # ``data`` accessor on older pydantic_ai releases.
            response = result.output if hasattr(result, "output") else result.data
            return {"output": response.output}
        except Exception as e:
            logger.exception(f"OrdinaryIndividual {self.agent_id} 执行失败: {e}")
            raise
|
||||
|
||||
|
||||
class SpecialIndividual(BaseIndividual):
    """Special worker: an agent for special tasks such as speech or video generation."""

    def __init__(self, agent_config: dict):
        super().__init__(agent_config)

    async def run(self, task_event: dict) -> dict:
        """Run the task through the agent, building the agent on first use.

        Args:
            task_event: Task payload; rendered into the user prompt and passed
                to the agent as dependencies.

        Returns:
            ``{"output": <text produced by the agent>}``.

        Raises:
            Exception: Any failure during the run is logged and re-raised.
        """
        if self.agent is None:
            # ``or`` rather than a ``.get`` default: a config that carries
            # ``prompt: None`` must still fall back to the built-in prompt.
            system_prompt = self.agent_config.get("prompt") or "你是一个特殊的AI助手,负责处理特殊类型的任务。"
            await self._init_agent("special_individual", system_prompt)

        deps = WorkerIndividualDeps(task_event=task_event)
        # NOTE(review): assigning ``retries`` after construction may not change
        # the agent's retry policy — confirm against AgentFactory / pydantic_ai.
        self.agent.retries = 3
        try:
            result = await self.agent.run(
                f"请执行以下任务:\n{task_event}",
                deps=deps
            )
            # Prefer the current ``output`` accessor; fall back to the legacy
            # ``data`` accessor on older pydantic_ai releases.
            response = result.output if hasattr(result, "output") else result.data
            return {"output": response.output}
        except Exception as e:
            logger.exception(f"SpecialIndividual {self.agent_id} 执行失败: {e}")
            raise
|
||||
|
|
|
|||
Loading…
Reference in New Issue