Pretor/pretor/worker_individual/worker_cluster.py

142 lines
5.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pretor.utils.error import RetryableError
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import ray
import time
import asyncio
from collections import OrderedDict
from ray.util.queue import Queue
from pretor.utils.ray_hook import ray_actor_hook
from pretor.worker_individual.worker_individual import BaseIndividual, SkillIndividual, OrdinaryIndividual, \
SpecialIndividual
from pretor.utils.logger import get_logger
# Module-level logger for this file, obtained from the project's get_logger
# helper under the name 'worker_cluster'; used by WorkerCluster below.
logger = get_logger('worker_cluster')
@ray.remote
class WorkerCluster:
    """Worker-cluster actor that manages and schedules all worker_individual agents.

    Design: agents are loaded on demand and kept in an in-memory LRU cache
    bounded by ``max_capacity`` (least recently used entries are evicted),
    so that the system does not need one actor per agent.

    Tasks are submitted through :meth:`submit_task`, placed on a queue and
    consumed by a fixed number of runner coroutines started by :meth:`start`.
    """

    def __init__(self, max_capacity: int = 200, num_runners: int = 10):
        """Initialise an empty cluster; no runner is started here.

        Args:
            max_capacity: Maximum number of worker objects kept in the cache.
            num_runners: Number of runner coroutines created by ``start``.
        """
        self.max_capacity = max_capacity
        # Insertion order doubles as recency order: most recent at the end.
        self._active_workers: OrderedDict[str, BaseIndividual] = OrderedDict()
        self.status = "running"
        self.task_queue = Queue()
        # task_id -> future awaited by the submit_task call for that task.
        self.results_futures = {}
        self.runners = []
        self.num_runners = num_runners

    async def start(self):
        """Spawn ``num_runners`` runner coroutines on the running event loop.

        The call is idempotent: while at least one previously started runner
        is still alive it returns without spawning anything, so a repeated
        call cannot double the number of consumers.
        """
        if any(not runner.done() for runner in self.runners):
            return
        self.runners = [
            asyncio.create_task(self._runner(index))
            for index in range(self.num_runners)
        ]
        logger.info(f"WorkerCluster 已启动 {self.num_runners} 个 runner 协程。")

    async def _recruit_worker(self, agent_id: str) -> BaseIndividual:
        """Internal: recruit / wake up the worker object for ``agent_id``.

        A cached worker is marked as most recently used and returned.
        Otherwise the agent config is fetched from the global state machine,
        a worker of the configured type is built and cached, and the least
        recently used entry is evicted if the cache exceeds ``max_capacity``.

        Raises:
            ValueError: If no config is returned for ``agent_id``.
        """
        if agent_id in self._active_workers:
            self._active_workers.move_to_end(agent_id)
            return self._active_workers[agent_id]

        global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
        agent_config = await global_state_machine.get_individual.remote(agent_id)
        if not agent_config:
            raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")

        # Another runner may have cached this agent while we were awaiting
        # the config; reuse that instance instead of overwriting it.
        if agent_id in self._active_workers:
            self._active_workers.move_to_end(agent_id)
            return self._active_workers[agent_id]

        worker_type = agent_config.get("type", "ordinary")
        if worker_type == "skill":
            worker = SkillIndividual(agent_config)
        elif worker_type == "special":
            worker = SpecialIndividual(agent_config)
        else:
            worker = OrdinaryIndividual(agent_config)

        self._active_workers[agent_id] = worker
        if len(self._active_workers) > self.max_capacity:
            # last=False pops the oldest (least recently used) entry.
            evicted_id, _ = self._active_workers.popitem(last=False)
            logger.info(f"[WorkerCluster] 内存池满,休眠老化 Agent: {evicted_id}")
        return worker

    async def _runner(self, runner_id: int):
        """Consume tasks from the queue forever and resolve their futures.

        Each task is a dict with ``task_id``, ``agent_id`` and ``task_event``.
        A failure while running a task is reported to the submitter as a
        ``success: False`` response; a failure of the loop itself (e.g. the
        queue read) is logged and retried after one second.

        Args:
            runner_id: Index of this runner, used only in log messages.
        """
        while True:
            try:
                task = await self.task_queue.get_async()
                task_id = task.get("task_id")
                agent_id = task.get("agent_id")
                task_event = task.get("task_event")
                logger.debug(f"[WorkerCluster Runner {runner_id}] 开始处理任务 {task_id} 给 Agent {agent_id}")
                # Monotonic clock: the measured duration cannot be skewed by
                # wall-clock adjustments.
                start_time = time.monotonic()
                try:
                    worker = await self._recruit_worker(agent_id)
                    result = await worker.run(task_event)
                    cost_time = time.monotonic() - start_time
                    response = {
                        "success": True,
                        "agent_id": agent_id,
                        "data": result,
                        "metrics": {"cost_time_sec": round(cost_time, 2)}
                    }
                except Exception as e:
                    logger.exception(f"[WorkerCluster Runner {runner_id}] 执行任务 {task_id} 时发生错误: {e}")
                    response = {
                        "success": False,
                        "agent_id": agent_id,
                        "error": str(e)
                    }
                # The submitter may already be gone (its entry removed) or its
                # future may already be finished; in both cases drop the result.
                future = self.results_futures.get(task_id)
                if future is not None and not future.done():
                    future.set_result(response)
            except Exception as e:
                logger.error(f"[WorkerCluster Runner {runner_id}] 循环发生异常: {e}")
                await asyncio.sleep(1)

    async def submit_task(self, task_id: str, agent_id: str, task_event: dict):
        """Queue a task for ``agent_id`` and wait for its response.

        Runners are started lazily on the first submission.

        Args:
            task_id: Key under which the pending future is registered.
                NOTE(review): two in-flight submissions with the same id
                share one slot, so callers should keep ids unique.
            agent_id: Agent that should execute the task.
            task_event: Payload passed to the worker's ``run``.

        Returns:
            The response dict built by the runner: ``success`` and
            ``agent_id``, plus ``data``/``metrics`` on success or ``error``
            on failure.
        """
        if not self.runners:
            await self.start()
        future = asyncio.get_running_loop().create_future()
        self.results_futures[task_id] = future
        task = {
            "task_id": task_id,
            "agent_id": agent_id,
            "task_event": task_event
        }
        try:
            # The enqueue is inside the try so that a failed put does not
            # leave a stale future registered.
            await self.task_queue.put_async(task)
            logger.debug(f"[WorkerCluster] 任务 {task_id} 已加入队列。")
            return await future
        finally:
            # Remove only our own registration, never a newer future that
            # was stored under the same task_id.
            if self.results_futures.get(task_id) is future:
                del self.results_futures[task_id]

    def get_cluster_metrics(self):
        """Return a snapshot of cache occupancy and queue backlog.

        Returns:
            Dict with ``active_worker_count``, ``max_capacity``,
            ``cached_agent_ids`` (in cache order, oldest first) and
            ``queue_size``.
        """
        return {
            "active_worker_count": len(self._active_workers),
            "max_capacity": self.max_capacity,
            "cached_agent_ids": list(self._active_workers.keys()),
            "queue_size": self.task_queue.size()
        }