refactor(core): decouple actors and remove workflow templates (#67)
Removes the deprecated `workflow_template` concept entirely across both backend API routers, internal logic handling within the `supervisory_node` and `consciousness_node`, and front-end components. Enables `consciousness_node` to work autonomously. Also refactors core package structure to enforce the "one python package, one Ray Actor" architectural rule. `GlobalWorkflowManager`, `WorkflowRunningEngine`, `PostgresDatabase`, and `WorkerCluster` have been moved to their own top-level decoupled package directories with properly exported `__init__.py` modules. Test suites have been relocated and import paths updated across the system. Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>
This commit is contained in:
@@ -1,14 +1,3 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
|
||||
__all__ = ["GlobalStateMachine"]
|
||||
|
||||
@@ -15,8 +15,7 @@
|
||||
import ray
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from pretor.core.database.postgres import PostgresDatabase
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.core.postgres_database import PostgresDatabase
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.core.global_state_machine.individual_manager import GlobalIndividualManager
|
||||
|
||||
@@ -24,17 +23,17 @@ from pretor.core.global_state_machine.individual_manager import GlobalIndividual
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
"""GlobalStateMachine 核心组件类。
|
||||
这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。 """
|
||||
这是一个领域数据模型或功能封装类,承载了 GlobalStateMachine 相关的内聚属性定义与状态维护。它的存在隔离了局部的业务复杂性,并对外提供了类型安全的访问接口。"""
|
||||
|
||||
def __init__(self, postgres_database: PostgresDatabase):
|
||||
import sys
|
||||
|
||||
print("GSM __init__ START", file=sys.stderr, flush=True)
|
||||
print(" event_dict done", file=sys.stderr, flush=True)
|
||||
self._global_provider_manager = ProviderManager(postgres_database)
|
||||
print(" provider_manager done", file=sys.stderr, flush=True)
|
||||
self._global_tool_manager = GlobalToolManager()
|
||||
print(" tool_manager done", file=sys.stderr, flush=True)
|
||||
self._global_workflow_template_manager = WorkflowManager()
|
||||
print(" workflow_template_manager done", file=sys.stderr, flush=True)
|
||||
self._global_skill_manager = GlobalSkillManager()
|
||||
print(" skill_manager done", file=sys.stderr, flush=True)
|
||||
self._global_individual_manager = GlobalIndividualManager()
|
||||
@@ -44,50 +43,63 @@ class GlobalStateMachine:
|
||||
|
||||
async def init_state_machine(self):
|
||||
"""完成 state machine 模块的启动与依赖初始化。
|
||||
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。 """
|
||||
await self._global_provider_manager.init_provider_register(self.postgres_database)
|
||||
await self._global_individual_manager.init_individual_register(self.postgres_database)
|
||||
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。"""
|
||||
await self._global_provider_manager.init_provider_register(
|
||||
self.postgres_database
|
||||
)
|
||||
await self._global_individual_manager.init_individual_register(
|
||||
self.postgres_database
|
||||
)
|
||||
|
||||
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||
async def add_provider_wrap(
|
||||
self,
|
||||
provider_type,
|
||||
provider_title,
|
||||
provider_url,
|
||||
provider_apikey,
|
||||
provider_owner,
|
||||
):
|
||||
"""创建并持久化新的 provider wrap 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database
|
||||
postgres_database=self.postgres_database,
|
||||
)
|
||||
|
||||
# Provider Manager Methods
|
||||
def get_provider_list(self):
|
||||
"""检索并获取特定的 provider list 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_provider_manager.get_provider_list()
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
"""检索并获取特定的 provider 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_provider_manager.get_provider(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str):
|
||||
"""安全地移除或注销 provider。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return await self._global_provider_manager.delete_provider(provider_title, self.postgres_database)
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return await self._global_provider_manager.delete_provider(
|
||||
provider_title, self.postgres_database
|
||||
)
|
||||
|
||||
# Tool Manager Methods
|
||||
def get_tool_mapper(self):
|
||||
"""检索并获取特定的 tool mapper 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_tool_manager.tool_mapper
|
||||
|
||||
def get_tool_list(self, agent_name: str):
|
||||
@@ -96,60 +108,32 @@ class GlobalStateMachine:
|
||||
"""检索并获取特定的 tool list 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Args: agent_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
|
||||
# also include default tools
|
||||
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
|
||||
merged_tools = {**default_tools, **tools}
|
||||
return merged_tools
|
||||
|
||||
# Workflow Template Manager Methods
|
||||
def get_all_workflow_templates(self):
|
||||
"""检索并获取特定的 all workflow templates 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return self._global_workflow_template_manager.get_all_workflow_templates()
|
||||
|
||||
def add_workflow_template(self, template_name: str, workflow_template):
|
||||
"""创建并持久化新的 workflow template 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。 workflow_template: 参与 add workflow template 逻辑运算或数据构建的上下文依赖对象。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return self._global_workflow_template_manager.add_workflow_template(template_name, workflow_template)
|
||||
|
||||
def delete_workflow_template(self, template_name: str):
|
||||
"""安全地移除或注销 workflow template。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: template_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return self._global_workflow_template_manager.delete_workflow_template(template_name)
|
||||
|
||||
def generate_workflow_template(self, workflow_template):
|
||||
"""执行与 generate workflow template 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: workflow_template: 参与 generate workflow template 逻辑运算或数据构建的上下文依赖对象。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return self._global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
||||
# Skill Manager Methods
|
||||
def add_skill(self, skill_name: str):
|
||||
"""创建并持久化新的 skill 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_skill_manager.add_skill(skill_name)
|
||||
|
||||
def get_skill_list(self):
|
||||
"""检索并获取特定的 skill list 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_skill_manager.get_skill_list()
|
||||
|
||||
def remove_skill(self, skill_name: str):
|
||||
"""安全地移除或注销 skill。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: skill_name (str): 赋予该实体的人类可读名称或标题字符串,主要用于前端 UI 展示、日志记录或模糊检索。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_skill_manager.remove_skill(skill_name)
|
||||
|
||||
# Individual Manager Methods
|
||||
@@ -157,26 +141,25 @@ class GlobalStateMachine:
|
||||
"""创建并持久化新的 individual 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。 config: 驱动该模块运行的核心配置字典或 Pydantic 数据模型,定义了重试策略、超时时间及模型参数等选项。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_individual_manager.add_individual(agent_id, config)
|
||||
|
||||
def get_individual(self, agent_id: str):
|
||||
"""检索并获取特定的 individual 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_individual_manager.get_individual(agent_id)
|
||||
|
||||
def remove_individual(self, agent_id: str):
|
||||
"""安全地移除或注销 individual。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_individual_manager.remove_individual(agent_id)
|
||||
|
||||
def list_individuals(self):
|
||||
"""执行与 list individuals 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self._global_individual_manager.list_individuals()
|
||||
|
||||
|
||||
@@ -1,187 +0,0 @@
|
||||
import ray
|
||||
import asyncio
|
||||
from typing import Dict
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from pretor.utils.logger import get_logger
|
||||
|
||||
@ray.remote
|
||||
class GlobalWorkflowManager:
|
||||
def __init__(self):
|
||||
self.event_dict: Dict[str, PretorEvent] = {}
|
||||
self.event_object_refs: Dict[str, ray.ObjectRef] = {}
|
||||
self.postgres_database = None
|
||||
self.logger = get_logger("GlobalWorkflowManager")
|
||||
|
||||
async def init_manager(self):
|
||||
self.postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
|
||||
# Load all events from database to memory
|
||||
try:
|
||||
records = await self.postgres_database.get_all_events.remote()
|
||||
for record in records:
|
||||
try:
|
||||
event = PretorEvent.model_validate_json(record.event_data_json)
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
self.event_dict[event.trace_id] = event
|
||||
|
||||
# Store in ray object store for cache
|
||||
event_copy = event.model_copy()
|
||||
event_copy.pending_queue = None
|
||||
event_copy.receive_queue = None
|
||||
self.event_object_refs[event.trace_id] = ray.put(event_copy.model_dump_json())
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to load event {record.trace_id}: {e}")
|
||||
self.logger.info(f"Loaded {len(self.event_dict)} events from database")
|
||||
|
||||
# Trigger resumption of incomplete workflows
|
||||
workflow_running_engine = None
|
||||
for trace_id, event in self.event_dict.items():
|
||||
if event.workflow and event.workflow.status.status in ["waiting_llm_working", "waiting_tool_working", "llm_working", "tool_working"]:
|
||||
self.logger.info(f"Resuming incomplete workflow {trace_id}")
|
||||
if not workflow_running_engine:
|
||||
try:
|
||||
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine
|
||||
except AttributeError:
|
||||
self.logger.warning("workflow_running_engine not found, cannot resume workflow")
|
||||
break
|
||||
await workflow_running_engine.resume_workflow.remote(event)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fetch events from database: {e}")
|
||||
|
||||
async def _upsert_event_to_db(self, event: PretorEvent):
|
||||
try:
|
||||
# Create a copy and remove non-serializable queues
|
||||
event_copy = event.model_copy()
|
||||
event_copy.pending_queue = None
|
||||
event_copy.receive_queue = None
|
||||
|
||||
event_json = event_copy.model_dump_json()
|
||||
# Update cache
|
||||
self.event_object_refs[event.trace_id] = ray.put(event_json)
|
||||
|
||||
await self.postgres_database.upsert_event.remote(
|
||||
event.trace_id,
|
||||
event_json
|
||||
)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to upsert event {event.trace_id} to database: {e}")
|
||||
|
||||
async def add_event(self, event: PretorEvent) -> None:
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
self.event_dict[event.trace_id] = event
|
||||
await self._upsert_event_to_db(event)
|
||||
|
||||
async def delete_event(self, trace_id: str) -> None:
|
||||
if trace_id in self.event_dict:
|
||||
del self.event_dict[trace_id]
|
||||
if trace_id in self.event_object_refs:
|
||||
del self.event_object_refs[trace_id]
|
||||
try:
|
||||
await self.postgres_database.delete_event.remote(trace_id)
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to delete event {trace_id} from database: {e}")
|
||||
|
||||
async def get_event(self, trace_id: str) -> PretorEvent | None:
|
||||
# First check memory dict
|
||||
if trace_id in self.event_dict:
|
||||
return self.event_dict[trace_id]
|
||||
|
||||
# Then check Ray object store cache
|
||||
if trace_id in self.event_object_refs:
|
||||
try:
|
||||
event_json = ray.get(self.event_object_refs[trace_id])
|
||||
return PretorEvent.model_validate_json(event_json)
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to fetch event from cache for trace {trace_id}: {e}")
|
||||
|
||||
# Fallback to database
|
||||
try:
|
||||
record = await self.postgres_database.get_event.remote(trace_id)
|
||||
if record:
|
||||
event = PretorEvent.model_validate_json(record.event_data_json)
|
||||
|
||||
# Restore to memory dict with missing transient queues
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
self.event_dict[trace_id] = event
|
||||
|
||||
# Restore to cache
|
||||
event_copy = event.model_copy()
|
||||
event_copy.pending_queue = None
|
||||
event_copy.receive_queue = None
|
||||
self.event_object_refs[trace_id] = ray.put(event_copy.model_dump_json())
|
||||
|
||||
return event
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to fetch event {trace_id} from database fallback: {e}")
|
||||
|
||||
return None
|
||||
|
||||
async def update_attachment(self, trace_id: str, attachment: Dict[str, str]) -> None:
|
||||
if trace_id in self.event_dict:
|
||||
self.event_dict[trace_id].attachment = attachment
|
||||
await self._upsert_event_to_db(self.event_dict[trace_id])
|
||||
|
||||
async def update_workflow(self, trace_id: str, workflow: PretorWorkflow) -> None:
|
||||
if trace_id in self.event_dict:
|
||||
self.event_dict[trace_id].workflow = workflow
|
||||
await self._upsert_event_to_db(self.event_dict[trace_id])
|
||||
|
||||
async def get_workflow(self, trace_id: str) -> PretorWorkflow | None:
|
||||
event = await self.get_event(trace_id)
|
||||
return event.workflow if event else None
|
||||
|
||||
async def list_events(self) -> list[dict]:
|
||||
result = []
|
||||
|
||||
# Read strictly from the database to ensure we get all events,
|
||||
# and ignore the cache to prevent frontend missing items.
|
||||
try:
|
||||
records = await self.postgres_database.get_all_events.remote()
|
||||
for record in records:
|
||||
try:
|
||||
event = PretorEvent.model_validate_json(record.event_data_json)
|
||||
workflow_title = event.workflow.title if event.workflow else None
|
||||
workflow_status = event.workflow.status.status if event.workflow and event.workflow.status else None
|
||||
result.append({
|
||||
"event_id": event.trace_id,
|
||||
"workflow_title": workflow_title,
|
||||
"status": workflow_status,
|
||||
"user_name": event.user_name,
|
||||
"message": event.message,
|
||||
"create_time": event.create_time,
|
||||
})
|
||||
# Best-effort cache population
|
||||
self.event_object_refs[event.trace_id] = ray.put(record.event_data_json)
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
self.logger.error(f"Failed to list_events from DB: {e}")
|
||||
|
||||
return result
|
||||
|
||||
async def put_pending(self, trace_id, item) -> None:
|
||||
if trace_id in self.event_dict and self.event_dict[trace_id].pending_queue:
|
||||
await self.event_dict[trace_id].pending_queue.put(item)
|
||||
|
||||
async def get_pending(self, trace_id) -> str:
|
||||
if trace_id in self.event_dict and self.event_dict[trace_id].pending_queue:
|
||||
return await self.event_dict[trace_id].pending_queue.get()
|
||||
await asyncio.sleep(1) # Prevent CPU spinning if not found
|
||||
return ""
|
||||
|
||||
async def put_received(self, trace_id, item) -> None:
|
||||
if trace_id in self.event_dict and self.event_dict[trace_id].receive_queue:
|
||||
await self.event_dict[trace_id].receive_queue.put(item)
|
||||
|
||||
async def get_received(self, trace_id) -> str:
|
||||
if trace_id in self.event_dict and self.event_dict[trace_id].receive_queue:
|
||||
return await self.event_dict[trace_id].receive_queue.get()
|
||||
await asyncio.sleep(1) # Prevent CPU spinning if not found
|
||||
return ""
|
||||
@@ -14,11 +14,14 @@
|
||||
|
||||
from typing import Dict, Any
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('individual_manager')
|
||||
|
||||
logger = get_logger("individual_manager")
|
||||
|
||||
|
||||
class GlobalIndividualManager:
|
||||
"""GlobalIndividualManager 核心组件类。
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalIndividual 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
|
||||
|
||||
def __init__(self):
|
||||
self._individuals: Dict[str, Dict[str, Any]] = {}
|
||||
|
||||
@@ -26,21 +29,31 @@ class GlobalIndividualManager:
|
||||
"""完成 individual register 模块的启动与依赖初始化。
|
||||
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。
|
||||
Args: postgres: 参与 init individual register 逻辑运算或数据构建的上下文依赖对象。
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
try:
|
||||
try:
|
||||
individuals = await postgres.get_all_worker_individual.remote()
|
||||
for ind in individuals:
|
||||
agent_id = getattr(ind, 'agent_id', None)
|
||||
agent_id = getattr(ind, "agent_id", None)
|
||||
if agent_id:
|
||||
self._individuals[agent_id] = ind.model_dump() if hasattr(ind, 'model_dump') else dict(ind)
|
||||
logger.info(f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。")
|
||||
self._individuals[agent_id] = (
|
||||
ind.model_dump()
|
||||
if hasattr(ind, "model_dump")
|
||||
else dict(ind)
|
||||
)
|
||||
logger.info(
|
||||
f"成功从数据库拉取了 {len(self._individuals)} 个 Worker Individual 配置。"
|
||||
)
|
||||
except AttributeError:
|
||||
logger.warning("数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。")
|
||||
logger.warning(
|
||||
"数据库中 get_all_worker_individual 方法未实现,跳过全量加载。可以在将来完善该接口。"
|
||||
)
|
||||
except Exception as e:
|
||||
# 捕获因 Ray 调用目标方法不存在引发的异常
|
||||
if "has no attribute 'get_all_worker_individual'" in str(e):
|
||||
logger.warning("数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。")
|
||||
logger.warning(
|
||||
"数据库 individual_database 中缺少 get_all_worker_individual 方法,无法全量拉取。"
|
||||
)
|
||||
else:
|
||||
raise e
|
||||
except Exception as e:
|
||||
@@ -64,12 +77,12 @@ class GlobalIndividualManager:
|
||||
"""安全地移除或注销 individual。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: agent_id (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 agent 实例。
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
if agent_id in self._individuals:
|
||||
del self._individuals[agent_id]
|
||||
|
||||
def list_individuals(self) -> Dict[str, Dict[str, Any]]:
|
||||
"""执行与 list individuals 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。 """
|
||||
Returns: (Dict[str, Dict[str, Any]]): 高度聚合的字典结构数据,将多维度的属性特征或统计指标组合后一并返回。"""
|
||||
return self._individuals
|
||||
|
||||
@@ -12,8 +12,24 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
|
||||
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
|
||||
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
|
||||
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import (
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
)
|
||||
from pretor.core.global_state_machine.model_provider.openai_provider import (
|
||||
OpenAIProvider,
|
||||
)
|
||||
from pretor.core.global_state_machine.model_provider.claude_provider import (
|
||||
ClaudeProvider,
|
||||
)
|
||||
from pretor.core.global_state_machine.model_provider.deepseek_provider import (
|
||||
DeepseekProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Provider",
|
||||
"ProviderArgs",
|
||||
"OpenAIProvider",
|
||||
"ClaudeProvider",
|
||||
"DeepseekProvider",
|
||||
]
|
||||
|
||||
@@ -17,15 +17,19 @@ from pydantic import BaseModel
|
||||
from typing import List
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class ProviderStatus(str, Enum):
|
||||
"""ProviderStatus 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
UP = "up"
|
||||
DOWN = "down"
|
||||
|
||||
|
||||
class Provider(BaseModel):
|
||||
"""Provider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
@@ -34,17 +38,21 @@ class Provider(BaseModel):
|
||||
provider_owner: str | None = None
|
||||
provider_status: ProviderStatus = ProviderStatus.UP
|
||||
|
||||
|
||||
class ProviderArgs(BaseModel):
|
||||
"""ProviderArgs 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
provider_owner: str
|
||||
|
||||
|
||||
class BaseProvider(ABC):
|
||||
"""BaseProvider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
@@ -83,7 +91,9 @@ class BaseProvider(ABC):
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
def _return_provider(
|
||||
provider_args: ProviderArgs, provider_models: List[str]
|
||||
) -> Provider:
|
||||
"""
|
||||
包装Provider对象并返回
|
||||
将provider_args和_load_models获取的方法包装为provider对象
|
||||
@@ -100,5 +110,3 @@ class BaseProvider(ABC):
|
||||
返回一个Provider对象
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@@ -14,21 +14,29 @@
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import (
|
||||
BaseProvider,
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
)
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
|
||||
class ClaudeProvider(BaseProvider):
|
||||
"""ClaudeProvider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""创建并持久化新的 provider 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
|
||||
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
|
||||
provider: Provider = ClaudeProvider._return_provider(
|
||||
provider_args, provider_models
|
||||
)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@@ -38,11 +46,11 @@ class ClaudeProvider(BaseProvider):
|
||||
"""执行与 load models 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
|
||||
headers = {
|
||||
"x-api-key": provider_args.provider_apikey,
|
||||
"anthropic-version": "2023-06-01",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
# 如果是官方 API,通常使用 /v1/models (如果支持)
|
||||
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
|
||||
@@ -57,19 +65,27 @@ class ClaudeProvider(BaseProvider):
|
||||
return sorted(model_ids)
|
||||
else:
|
||||
# 如果官方列表接口不可用,fallback 到已知常用模型
|
||||
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
|
||||
return [
|
||||
"claude-3-5-sonnet-20240620",
|
||||
"claude-3-opus-20240229",
|
||||
"claude-3-haiku-20240307",
|
||||
]
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
def _return_provider(
|
||||
provider_args: ProviderArgs, provider_models: List[str]
|
||||
) -> Provider:
|
||||
"""执行与 return provider 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="claude")
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return Provider(
|
||||
provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="claude",
|
||||
)
|
||||
|
||||
@@ -13,21 +13,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import (
|
||||
BaseProvider,
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
)
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
|
||||
class DeepseekProvider(BaseProvider):
|
||||
"""DeepseekProvider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""创建并持久化新的 provider 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
provider_models: List[str] = await DeepseekProvider._load_models(provider_args)
|
||||
provider: Provider = DeepseekProvider._return_provider(provider_args, provider_models)
|
||||
provider: Provider = DeepseekProvider._return_provider(
|
||||
provider_args, provider_models
|
||||
)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@@ -36,17 +44,23 @@ class DeepseekProvider(BaseProvider):
|
||||
"""执行与 load models 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
url = (
|
||||
f"{provider_args.provider_url}/models"
|
||||
if "/v1" in provider_args.provider_url
|
||||
else f"{provider_args.provider_url}/v1/models"
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
print(
|
||||
f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}"
|
||||
)
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
@@ -54,20 +68,27 @@ class DeepseekProvider(BaseProvider):
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
raise RetryableError(
|
||||
f"[{provider_args.provider_title}] 网络请求异常: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
def _return_provider(
|
||||
provider_args: ProviderArgs, provider_models: List[str]
|
||||
) -> Provider:
|
||||
"""执行与 return provider 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="deepseek")
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return Provider(
|
||||
provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="deepseek",
|
||||
)
|
||||
|
||||
@@ -13,21 +13,29 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import (
|
||||
BaseProvider,
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
)
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
"""OpenAIProvider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。 """
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 OpenAI、Anthropic 等)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""创建并持久化新的 provider 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||
provider: Provider = OpenAIProvider._return_provider(
|
||||
provider_args, provider_models
|
||||
)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@@ -36,17 +44,23 @@ class OpenAIProvider(BaseProvider):
|
||||
"""执行与 load models 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。 """
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
url = (
|
||||
f"{provider_args.provider_url}/models"
|
||||
if "/v1" in provider_args.provider_url
|
||||
else f"{provider_args.provider_url}/v1/models"
|
||||
)
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
print(
|
||||
f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}"
|
||||
)
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
@@ -54,20 +68,27 @@ class OpenAIProvider(BaseProvider):
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
raise RetryableError(
|
||||
f"[{provider_args.provider_title}] 网络请求异常: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
def _return_provider(
|
||||
provider_args: ProviderArgs, provider_models: List[str]
|
||||
) -> Provider:
|
||||
"""执行与 return provider 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="openai")
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return Provider(
|
||||
provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="openai",
|
||||
)
|
||||
|
||||
@@ -12,51 +12,73 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
|
||||
from pretor.core.global_state_machine.model_provider import (
|
||||
Provider,
|
||||
OpenAIProvider,
|
||||
ClaudeProvider,
|
||||
DeepseekProvider,
|
||||
)
|
||||
from typing import Dict, Type
|
||||
|
||||
|
||||
class ProviderManager:
|
||||
"""
|
||||
模型供应商管理器 (ProviderManager)。
|
||||
负责维护不同的 LLM 协议适配器,提供从配置注册到 Agent 实例化的全生命周期管理。
|
||||
"""
|
||||
|
||||
# --- 类属性显式标注 (IDE 友好) ---
|
||||
provider_mapper: Dict[str, Type[Provider]]
|
||||
"""协议映射表:键为协议名(如 'openai'),值为对应的 Provider 类。"""
|
||||
|
||||
provider_register: Dict[str, Provider]
|
||||
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
|
||||
|
||||
def __init__(self, postgres):
|
||||
self.provider_mapper = {"openai": OpenAIProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"deepseek": DeepseekProvider}
|
||||
self.provider_mapper = {
|
||||
"openai": OpenAIProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"deepseek": DeepseekProvider,
|
||||
}
|
||||
self.provider_register = {}
|
||||
|
||||
async def init_provider_register(self, postgres) -> None:
|
||||
"""完成 provider register 模块的启动与依赖初始化。
|
||||
在系统引导或服务拉起阶段被调用,负责建立网络连接、分配基础内存资源及注册核心服务组件。
|
||||
Args: postgres: 参与 init provider register 逻辑运算或数据构建的上下文依赖对象。
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
providers = await postgres.get_provider.remote()
|
||||
for provider in providers:
|
||||
self.provider_register[provider.provider_title] = provider
|
||||
|
||||
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
|
||||
async def add_provider(
|
||||
self,
|
||||
provider_type,
|
||||
provider_title,
|
||||
provider_url,
|
||||
provider_apikey,
|
||||
provider_owner,
|
||||
postgres_database,
|
||||
) -> None:
|
||||
"""创建并持久化新的 provider 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_type: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_type 实例。 provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 provider_url: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_url 实例。 provider_apikey: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_apikey 实例。 provider_owner: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_owner 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
from pretor.core.global_state_machine.model_provider import ProviderArgs
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('provider_manager')
|
||||
|
||||
logger = get_logger("provider_manager")
|
||||
import httpx
|
||||
|
||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner)
|
||||
provider_args: ProviderArgs = ProviderArgs(
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner,
|
||||
)
|
||||
try:
|
||||
import ulid
|
||||
|
||||
provider_class = self.provider_mapper.get(provider_type, None)
|
||||
if provider_class is None:
|
||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||
@@ -65,41 +87,49 @@ class ProviderManager:
|
||||
provider.provider_owner = provider_owner
|
||||
self.provider_register[provider_title] = provider
|
||||
await postgres_database.add_provider_db.remote(
|
||||
provider_id=str(ulid.ULID()),
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
provider_id=str(ulid.ULID()),
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner,
|
||||
)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
|
||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
raise RetryableError(
|
||||
f"[{provider_args.provider_title}] 网络请求异常: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
logger.warning(
|
||||
f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}"
|
||||
)
|
||||
|
||||
def get_provider_list(self):
|
||||
"""检索并获取特定的 provider list 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self.provider_register
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
"""检索并获取特定的 provider 数据集合或实例对象。
|
||||
根据提供的查询条件或上下文凭证,从数据库、缓存或第三方服务中读取对应的资源状态。
|
||||
Args: provider_title: 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: : 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return self.provider_register.get(provider_title)
|
||||
|
||||
async def delete_provider(self, provider_title: str, postgres_database) -> None:
|
||||
"""安全地移除或注销 provider。
|
||||
执行物理删除或逻辑删除操作,并妥善清理相关的关联数据及占用资源。
|
||||
Args: provider_title (str): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_title 实例。 postgres_database: 从客户端传递过来或由上游组件生成的核心业务数据体,通常需要进一步的清洗和结构化解析。
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。 """
|
||||
Returns: (None): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
if provider_title in self.provider_register:
|
||||
provider = self.provider_register[provider_title]
|
||||
await postgres_database.delete_provider_db.remote( provider_id=provider.provider_id)
|
||||
del self.provider_register[provider_title]
|
||||
await postgres_database.delete_provider_db.remote(
|
||||
provider_id=provider.provider_id
|
||||
)
|
||||
del self.provider_register[provider_title]
|
||||
|
||||
@@ -17,22 +17,29 @@ from collections import defaultdict
|
||||
import pathlib
|
||||
import json
|
||||
|
||||
|
||||
class GlobalSkillManager:
|
||||
"""GlobalSkillManager 核心组件类。
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """
|
||||
skill_mapper = Dict[str,Tuple[str]]
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalSkill 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
|
||||
|
||||
skill_mapper = Dict[str, Tuple[str]]
|
||||
"""skill的存储表"""
|
||||
|
||||
def __init__(self):
|
||||
self.skill_mapper = defaultdict(tuple)
|
||||
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
|
||||
skill_plugin_dir = pathlib.Path(
|
||||
os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")
|
||||
)
|
||||
)
|
||||
if not skill_plugin_dir.exists() or not skill_plugin_dir.is_dir():
|
||||
return
|
||||
for item in skill_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json" # 拼接文件路径
|
||||
json_path = item / "skill.json" # 拼接文件路径
|
||||
if json_path.exists():
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
@@ -42,7 +49,7 @@ class GlobalSkillManager:
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
skill.get("instructions", ""),
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
@@ -50,7 +57,12 @@ class GlobalSkillManager:
|
||||
def add_skill(self, skill_name: str) -> None:
|
||||
"""Add a skill to the manager by reading its skill.json from the path"""
|
||||
import os
|
||||
skill_plugin_dir = pathlib.Path(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")))
|
||||
|
||||
skill_plugin_dir = pathlib.Path(
|
||||
os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "..", "plugin", "skill")
|
||||
)
|
||||
)
|
||||
item = skill_plugin_dir / skill_name
|
||||
if item.is_dir() and not item.name.startswith((".", "__")):
|
||||
json_path = item / "skill.json"
|
||||
@@ -62,7 +74,7 @@ class GlobalSkillManager:
|
||||
if name:
|
||||
self.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
skill.get("instructions", ""),
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {item.name} 失败: {e}")
|
||||
|
||||
@@ -19,17 +19,22 @@ from collections import defaultdict
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
from pretor.utils.logger import get_logger
|
||||
logger = get_logger('tool_manager')
|
||||
|
||||
logger = get_logger("tool_manager")
|
||||
|
||||
|
||||
class GlobalToolManager:
|
||||
"""GlobalToolManager 核心组件类。
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。 """
|
||||
这是一个管理器类,职责集中在维护整个系统内有关 GlobalTool 资源的全局生命周期。它提供了注册机制、状态同步以及跨组件的统一查询入口,确保系统中该类型资源的实例一致性与可控性。"""
|
||||
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
||||
def __init__(self):
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
|
||||
tool_plugin_dir = (
|
||||
pathlib.Path(__file__).parent.parent.parent / "plugin" / "tool_plugin"
|
||||
)
|
||||
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
|
||||
return
|
||||
|
||||
@@ -51,4 +56,4 @@ class GlobalToolManager:
|
||||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
|
||||
Reference in New Issue
Block a user