From be0d78c8556d9986b75c00257e323a8266827736 Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Wed, 22 Apr 2026 11:36:59 +0800 Subject: [PATCH] =?UTF-8?q?wip:=20=E5=A2=9E=E5=8A=A0=E4=BA=86api=E7=9A=84?= =?UTF-8?q?=E7=AE=80=E5=8D=95=E6=9D=83=E9=99=90=E6=A0=A1=E9=AA=8C=EF=BC=8C?= =?UTF-8?q?=E4=BD=BF=E7=94=A8=E5=8F=8D=E5=B0=84=E5=AF=B9=E4=BA=8Eactor?= =?UTF-8?q?=E8=BF=9B=E8=A1=8C=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pretor/api/agent.py | 21 +-- pretor/api/auth.py | 4 +- pretor/api/provider.py | 16 +- pretor/api/resource.py | 10 +- pretor/core/database/postgres.py | 51 ++---- pretor/core/database/table/provider.py | 4 +- .../global_state_machine.py | 159 +++++------------- .../global_state_machine/provider_manager.py | 43 ++++- .../consciousness_node/consciousness_node.py | 4 +- .../individual/control_node/control_node.py | 2 +- .../supervisory_node/supervisory_node.py | 4 +- pretor/core/workflow/workflow_runner.py | 6 +- .../workflow/workflow_template_manager.py | 5 +- pretor/utils/check_user/role_check.py | 2 +- pretor/utils/get_tool.py | 2 +- tests/core/database/postgres_test.py | 2 + tests/core/database/table/table_user_test.py | 2 +- .../global_state_machine_test.py | 41 +++-- .../provider_manager_test.py | 2 + tests/core/workflow/workflow_runner_test.py | 20 +-- tests/core/workflow/workflow_test.py | 8 +- 21 files changed, 173 insertions(+), 235 deletions(-) diff --git a/pretor/api/agent.py b/pretor/api/agent.py index ec84a7d..79f1fe6 100644 --- a/pretor/api/agent.py +++ b/pretor/api/agent.py @@ -21,6 +21,8 @@ from pretor.utils.access import Accessor, TokenData from pretor.core.database.table.individual import AgentType from fastapi import HTTPException from typing import Optional, List, Dict +from pretor.utils.check_user.role_check import RoleChecker +from pretor.core.database.table.user import UserAuthority agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"]) @@ -35,7 +37,7 @@ class AgentLocalRegister(BaseModel): @agent_router.post("") async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister], - _: TokenData = Depends(Accessor.get_current_user)): + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): global_state_machine = ray_actor_hook("global_state_machine") if isinstance(agent_register, AgentLocalRegister): pass @@ -82,18 +84,18 @@ class WorkerIndividualUpdate(BaseModel): @agent_router.post("/worker") async def create_worker_individual(worker_data: WorkerIndividualCreate, - token_data: TokenData = Depends(Accessor.get_current_user)): + token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): postgres_database = ray_actor_hook("postgres_database") data_dict = worker_data.model_dump() data_dict["owner_id"] = token_data.user_id - worker = await postgres_database.add_worker_individual.remote(**data_dict) + worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict) return {"message": "success", "agent_id": worker.agent_id} @agent_router.get("/worker") async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)): postgres_database = ray_actor_hook("postgres_database") - workers = await postgres_database.get_worker_individual_list.remote(owner_id=token_data.user_id) + workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id) return {"workers": workers} @@ -101,7 +103,7 @@ async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.ge async def get_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)): postgres_database = ray_actor_hook("postgres_database") - worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: @@ -114,14 +116,14 @@ async def update_worker_individual(agent_id: str, worker_data: WorkerIndividualUpdate, token_data: TokenData = Depends(Accessor.get_current_user)): postgres_database = ray_actor_hook("postgres_database") - worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") update_data = worker_data.model_dump(exclude_unset=True) - updated_worker = await postgres_database.update_worker_individual.remote(agent_id=agent_id, **update_data) + updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data) return {"message": "success", "worker": updated_worker} @@ -129,11 +131,10 @@ async def update_worker_individual(agent_id: str, async def delete_worker_individual(agent_id: str, token_data: TokenData = Depends(Accessor.get_current_user)): postgres_database = ray_actor_hook("postgres_database") - worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id) + worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id) if not worker: raise HTTPException(status_code=404, detail="Agent not found") if worker.owner_id != token_data.user_id: raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent") - - await postgres_database.delete_worker_individual.remote(agent_id=agent_id) + await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id) return {"message": "success"} \ No newline at end of file diff --git a/pretor/api/auth.py b/pretor/api/auth.py index c3dde40..7f3d6ce 100644 --- a/pretor/api/auth.py +++ b/pretor/api/auth.py @@ -28,7 +28,7 @@ class UserRegister(BaseModel): async def create_user(user_register: UserRegister): postgres_database = ray_actor_hook("postgres_database") hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password) - user = await postgres_database.add_user.remote(user_register.user_name, hashed_password) + user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password) return {"message": "success", "user_id": user.user_id} class UserLogin(BaseModel): @@ -38,7 +38,7 @@ class UserLogin(BaseModel): @auth_router.post("/login") async def login_user(user_login: UserLogin): postgres_database = ray_actor_hook("postgres_database") - user = await postgres_database.login_user.remote(user_login.user_name) + user = await postgres_database.auth_database.remote("login_user", user_login.user_name) if user.user_name != user_login.user_name: pass token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password) diff --git a/pretor/api/provider.py b/pretor/api/provider.py index 84d899b..d9d9ff5 100644 --- a/pretor/api/provider.py +++ b/pretor/api/provider.py @@ -16,6 +16,8 @@ from fastapi import APIRouter, Depends from pydantic import BaseModel from typing import Literal from pretor.utils.access import TokenData, Accessor +from pretor.utils.check_user.role_check import RoleChecker +from pretor.core.database.table.user import UserAuthority from typing import Dict from pretor.core.global_state_machine.model_provider.base_provider import Provider from pretor.utils.ray_hook import ray_actor_hook @@ -30,17 +32,17 @@ class ProviderRegister(BaseModel): @provider_router.post("") async def create_provider(provider_register: ProviderRegister, - token_data: TokenData = Depends(Accessor.get_current_user)) -> None: + token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None: global_state_machine = ray_actor_hook("global_state_machine") - await global_state_machine.add_provider.remote(provider_type=provider_register.provider_type, - provider_title=provider_register.provider_title, - provider_url=provider_register.provider_url, - provider_apikey=provider_register.provider_apikey, - provider_owner=token_data.user_id) + await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type, + provider_title=provider_register.provider_title, + provider_url=provider_register.provider_url, + provider_apikey=provider_register.provider_apikey, + provider_owner=token_data.user_id) @provider_router.get("/list") async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Provider]: global_state_machine = ray_actor_hook("global_state_machine") - provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote() + provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list") return {"provider_list": provider_list} \ No newline at end of file diff --git a/pretor/api/resource.py b/pretor/api/resource.py index 0754cdf..a03c102 100644 --- a/pretor/api/resource.py +++ b/pretor/api/resource.py @@ -17,13 +17,15 @@ import viceroy from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate from pretor.utils.ray_hook import ray_actor_hook from fastapi import APIRouter, Depends -from pretor.utils.access import TokenData, Accessor +from pretor.utils.access import TokenData +from pretor.utils.check_user.role_check import RoleChecker +from pretor.core.database.table.user import UserAuthority resource_router = APIRouter(prefix="/api/v1/resource") @resource_router.post("/workflow_template") async def create_workflow_template(workflow_template: WorkflowTemplate, - _: TokenData = Depends(Accessor.get_current_user)): + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): global_state_machine = ray_actor_hook("global_state_machine") await global_state_machine.workflow_template_generate.remote(workflow_template) return {"message": "创建成功"} @@ -36,7 +38,7 @@ class Skill(BaseModel): @resource_router.post("/skill") async def install_skill(skill: Skill, - _: TokenData = Depends(Accessor.get_current_user)): + _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))): global_state_machine = ray_actor_hook("global_state_machine") # noinspection PyUnresolvedReferences await viceroy.install_skill_async(url = skill.repo_url, @@ -46,6 +48,6 @@ async def install_skill(skill: Skill, skill_name = skill.path.split("/")[-1] else: skill_name = skill.repo_url.split("/")[-1] - await global_state_machine.add_skill.remote(skill_name) + await global_state_machine.skill_manager.remote("add_skill", skill_name) return {"message": "创建成功"} diff --git a/pretor/core/database/postgres.py b/pretor/core/database/postgres.py index 6b95f89..d550950 100644 --- a/pretor/core/database/postgres.py +++ b/pretor/core/database/postgres.py @@ -35,49 +35,22 @@ class PostgresDatabase: self.async_engine = create_async_engine(database_url, echo=True) self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False) - self.auth_database = AuthDatabase(self.async_session_maker) - self.provider_database = ProviderDatabase(self.async_session_maker) - self.individual_database = IndividualDatabase(self.async_session_maker) + self._auth_database = AuthDatabase(self.async_session_maker) + self._provider_database = ProviderDatabase(self.async_session_maker) + self._individual_database = IndividualDatabase(self.async_session_maker) async def init_db(self) -> None: async with self.async_engine.begin() as conn: await conn.run_sync(SQLModel.metadata.create_all) - # provider_database操作 - async def get_providers(self): - return await self.provider_database.get_provider() + async def auth_database(self, method_name: str, *args, **kwargs): + method = getattr(self._auth_database, method_name) + return await method(*args, **kwargs) - async def add_provider(self, **kwargs): - return await self.provider_database.add_provider(**kwargs) + async def provider_database(self, method_name: str, *args, **kwargs): + method = getattr(self._provider_database, method_name) + return await method(*args, **kwargs) - # auth_database操作 - async def add_user(self, **kwargs): - return await self.auth_database.add_user(**kwargs) - - async def change_password(self, **kwargs): - return await self.auth_database.change_password(**kwargs) - - async def delete_user(self, **kwargs): - return await self.auth_database.delete_user(**kwargs) - - async def login_user(self, **kwargs): - return await self.auth_database.login_user(**kwargs) - - async def get_user_authority(self, **kwargs): - return await self.auth_database.get_user_authority(**kwargs) - - ##individual_database 操作 - async def add_worker_individual(self, **kwargs): - return await self.individual_database.add_worker_individual(**kwargs) - - async def get_worker_individual(self, agent_id: str): - return await self.individual_database.get_worker_individual(agent_id) - - async def get_worker_individual_list(self, owner_id: str): - return await self.individual_database.get_worker_individual_list(owner_id) - - async def update_worker_individual(self, agent_id: str, **kwargs): - return await self.individual_database.update_worker_individual(agent_id, **kwargs) - - async def delete_worker_individual(self, agent_id: str): - return await self.individual_database.delete_worker_individual(agent_id) \ No newline at end of file + async def individual_database(self, method_name: str, *args, **kwargs): + method = getattr(self._individual_database, method_name) + return await method(*args, **kwargs) \ No newline at end of file diff --git a/pretor/core/database/table/provider.py b/pretor/core/database/table/provider.py index 9c166d9..89ef42b 100644 --- a/pretor/core/database/table/provider.py +++ b/pretor/core/database/table/provider.py @@ -15,13 +15,13 @@ from sqlmodel import SQLModel, Field from typing import List from sqlalchemy import Column, JSON -from typing import Optional, Literal +from typing import Optional class Provider(SQLModel, table=True): __tablename__ = "provider" provider_id: str = Field(primary_key=True) provider_title: str = Field(index=True) - provider_type: Literal["openai", "vllm"] + provider_type: str provider_url: Optional[str] provider_apikey: Optional[str] diff --git a/pretor/core/global_state_machine/global_state_machine.py b/pretor/core/global_state_machine/global_state_machine.py index 22a2ae5..4faefbf 100644 --- a/pretor/core/global_state_machine/global_state_machine.py +++ b/pretor/core/global_state_machine/global_state_machine.py @@ -15,37 +15,63 @@ import ray from pretor.core.global_state_machine.provider_manager import ProviderManager from pretor.core.global_state_machine.tool_manager import GlobalToolManager -from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs -import httpx -import pathlib -import json -from loguru import logger -from typing import Dict, Literal, List +from typing import Dict from pretor.core.database.postgres import PostgresDatabase from pretor.api.platform.event import PretorEvent import asyncio from pretor.core.workflow.workflow import PretorWorkflow -from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate from pretor.core.workflow.workflow_template_manager import WorkflowManager from pretor.core.global_state_machine.skill_manager import GlobalSkillManager -from pretor.plugin.tool_plugin.base_tool import BaseToolData @ray.remote class GlobalStateMachine: def __init__(self, postgres_database: PostgresDatabase): - self.event_dict: Dict[int, PretorEvent] = {} - self.global_provider_manager = ProviderManager(postgres_database) - self.global_tool_manager = GlobalToolManager() - self.global_workflow_template_manager = WorkflowManager() - self.global_skill_manager = GlobalSkillManager() + self.event_dict: Dict[str, PretorEvent] = {} + self._global_provider_manager = ProviderManager(postgres_database) + self._global_tool_manager = GlobalToolManager() + self._global_workflow_template_manager = WorkflowManager() + self._global_skill_manager = GlobalSkillManager() self.postgres_database = postgres_database async def init_state_machine(self): - await self.global_provider_manager.init_provider_register(self.postgres_database) + await self._global_provider_manager.init_provider_register(self.postgres_database) + async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner): + return await self._global_provider_manager.add_provider( + provider_type=provider_type, + provider_title=provider_title, + provider_url=provider_url, + provider_apikey=provider_apikey, + provider_owner=provider_owner, + postgres_database=self.postgres_database + ) + + async def provider_manager(self, method_name: str, *args, **kwargs): + method = getattr(self._global_provider_manager, method_name) + if asyncio.iscoroutinefunction(method): + return await method(*args, **kwargs) + return method(*args, **kwargs) + + async def tool_manager(self, method_name: str, *args, **kwargs): + method = getattr(self._global_tool_manager, method_name) + if asyncio.iscoroutinefunction(method): + return await method(*args, **kwargs) + return method(*args, **kwargs) + + async def workflow_template_manager(self, method_name: str, *args, **kwargs): + method = getattr(self._global_workflow_template_manager, method_name) + if asyncio.iscoroutinefunction(method): + return await method(*args, **kwargs) + return method(*args, **kwargs) + + async def skill_manager(self, method_name: str, *args, **kwargs): + method = getattr(self._global_skill_manager, method_name) + if asyncio.iscoroutinefunction(method): + return await method(*args, **kwargs) + return method(*args, **kwargs) ###以下方法为event_dict方法 def add_event(self, event: PretorEvent) -> None: @@ -79,108 +105,3 @@ class GlobalStateMachine: async def get_received(self, event_id) -> str: return await self.event_dict[event_id].receive_queue.get() - - - ###以下方法为global_provider_manager方法 - async def add_provider(self, provider_type: Literal["openai", "gemini", "claude"], - provider_title: str, - provider_url: str, - provider_apikey: str, - provider_owner: int) -> None: - """ - add_provider方法,注册供应商适配器(provider_manager方法) - - Args - provider_type: 注册商接口类型,目前只支持openai,gemini和claude接口 - provider_title: 供应商名称,为供应商提供的别名 - provider_url: 供应商url - provider_apikey: 供应商所需要的apikey - Returns: - - """ - provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, - provider_url=provider_url, - provider_apikey=provider_apikey, - provider_owner=provider_owner) - try: - provider_class = self.global_provider_manager.provider_mapper.get(provider_type, None) - if provider_class is None: - logger.warning(f"Provider type {provider_type} is not supported.") - return None - provider: Provider = await provider_class.create_model(provider_args) - - provider.provider_owner = provider_owner - - self.global_provider_manager.provider_register[provider_title] = provider - - await self.postgres_database.add_provider.remote(provider_title=provider.provider_title, - provider_url=provider.provider_url, - provider_apikey=provider.provider_apikey, - provider_models=provider.provider_models, - provider_type=provider.provider_type, - provider_owner=provider.provider_owner) - - logger.info(f"已添加适配器{provider_title}") - except httpx.RequestError as e: - logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}") - except Exception as e: - logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") - - def get_provider_list(self) -> Dict[str, Provider]: - """ - get_provider_list方法,获取注册表(provider_manager方法) - Returns: - 返回provider_register属性字典 - """ - return self.global_provider_manager.provider_register - - def get_provider(self, provider_title) -> Provider: - """ - get_provider方法,获取供应商信息(provider_manager方法) - Args: - provider_title:provider名称 - - Returns: - Provider对象,返回注册在self.global_provider_manager.provider_register的供应商 - """ - provider = self.global_provider_manager.provider_register.get(provider_title) - return provider - - - ###以下为global_tool_manager方法 - def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]: - """ - 获取工具表方法 - Args: - agent_name: agent的名字 - - Returns: - 返回该agent的tool,类型为dict - """ - tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {}) - return tool_list - - - ###以下为workflow_template_manager方法 - def workflow_template_generate(self, workflow_template: WorkflowTemplate) -> None: - self.global_workflow_template_manager.generate_workflow_template(workflow_template) - - def get_workflow_template_list(self) -> List[Dict[str, str]]: - return self.global_workflow_template_manager.workflow_templates_registry - - - ###以下为skill_manager方法 - def add_skill(self, skill_name: str): - skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name - json_path = skill_plugin_dir / "skill.json" - try: - with open(json_path, "r", encoding="utf-8") as f: - skill = json.load(f) - name = skill.get("name") - if name: - self.global_skill_manager.skill_mapper[name] = ( - skill.get("description", ""), - skill.get("instructions", "") - ) - except (json.JSONDecodeError, OSError) as e: - print(f"警告: 加载插件 {skill_name} 失败: {e}") diff --git a/pretor/core/global_state_machine/provider_manager.py b/pretor/core/global_state_machine/provider_manager.py index 304b2e1..42b020f 100644 --- a/pretor/core/global_state_machine/provider_manager.py +++ b/pretor/core/global_state_machine/provider_manager.py @@ -33,6 +33,45 @@ class ProviderManager: self.provider_register = {} async def init_provider_register(self, postgres) -> None: - providers = await postgres.get_providers.remote() + providers = await postgres.provider_database.remote("get_provider") for provider in providers: - self.provider_register[provider.provider_title] = provider \ No newline at end of file + self.provider_register[provider.provider_title] = provider + + async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None: + from pretor.core.global_state_machine.model_provider import ProviderArgs + from loguru import logger + import httpx + + provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, + provider_url=provider_url, + provider_apikey=provider_apikey, + provider_owner=provider_owner) + try: + provider_class = self.provider_mapper.get(provider_type, None) + if provider_class is None: + logger.warning(f"Provider type {provider_type} is not supported.") + return None + provider: Provider = await provider_class.create_model(provider_args) + + provider.provider_owner = provider_owner + + self.provider_register[provider_title] = provider + + await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title, + provider_url=provider.provider_url, + provider_apikey=provider.provider_apikey, + provider_models=provider.provider_models, + provider_type=provider.provider_type, + provider_owner=provider.provider_owner) + + logger.info(f"已添加适配器{provider_title}") + except httpx.RequestError as e: + logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}") + except Exception as e: + logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") + + def get_provider_list(self): + return self.provider_register + + def get_provider(self, provider_title): + return self.provider_register.get(provider_title) \ No newline at end of file diff --git a/pretor/core/individual/consciousness_node/consciousness_node.py b/pretor/core/individual/consciousness_node/consciousness_node.py index 3f5931b..e6bb729 100644 --- a/pretor/core/individual/consciousness_node/consciousness_node.py +++ b/pretor/core/individual/consciousness_node/consciousness_node.py @@ -55,7 +55,7 @@ class ConsciousnessNode: "请确保所有的思考和生成过程符合逻辑,严密且高质量。" ) output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine] - provider: Provider = global_state_machine.get_provider.remote(provider_title) + provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title) agent_factory = AgentFactory() self.agent = agent_factory.create_agent(provider=provider, model_id=model_id, @@ -85,7 +85,7 @@ class ConsciousnessNode: else: logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}") return None - except Exception as e: + except Exception: logger.exception("ConsciousnessNode在执行working时发生严重错误") return None diff --git a/pretor/core/individual/control_node/control_node.py b/pretor/core/individual/control_node/control_node.py index cb28404..cf3dec2 100644 --- a/pretor/core/individual/control_node/control_node.py +++ b/pretor/core/individual/control_node/control_node.py @@ -51,7 +51,7 @@ class ControlNode: "请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。" ) output_type = ForWorkflow - provider: Provider = global_state_machine.get_provider.remote(provider_title) + provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title) agent_factory = AgentFactory() self.agent = agent_factory.create_agent(provider=provider, model_id=model_id, diff --git a/pretor/core/individual/supervisory_node/supervisory_node.py b/pretor/core/individual/supervisory_node/supervisory_node.py index 95265fd..8996bc5 100644 --- a/pretor/core/individual/supervisory_node/supervisory_node.py +++ b/pretor/core/individual/supervisory_node/supervisory_node.py @@ -54,7 +54,7 @@ class SupervisoryNode: "请保持冷静、专业,并严格遵循上述路由规则。" ) output_type = Union[ForConsciousnessNode, ForUser] - provider: Provider = await global_state_machine.get_provider.remote(provider_title) + provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title) agent_factory = AgentFactory() self.agent = agent_factory.create_agent(provider=provider, model_id=model_id, @@ -155,7 +155,7 @@ class SupervisoryNode: time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S") try: global_state_machine = ray_actor_hook("global_state_machine") - workflow_template_dict = await global_state_machine.get_workflow_template_list.remote() + workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list") available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板" deps = SupervisoryNodeDeps( diff --git a/pretor/core/workflow/workflow_runner.py b/pretor/core/workflow/workflow_runner.py index 1d65bea..761708b 100644 --- a/pretor/core/workflow/workflow_runner.py +++ b/pretor/core/workflow/workflow_runner.py @@ -158,7 +158,7 @@ class WorkflowEngine: logger.info(f"Supervisory 最终回复:{user_response}") else: logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。") - except Exception as e: + except Exception: logger.exception("生成工作流执行汇报时发生错误") async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]: @@ -207,7 +207,7 @@ class WorkflowEngine: else: raise WorkflowError(f"未知的节点类型:{step.node}") - except Exception as e: + except Exception: logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败") return None, False @@ -245,7 +245,7 @@ class WorkflowRunningEngine: self.consciousness_node = consciousness_node self.control_node = control_node self.supervisory_node = supervisory_node - self.global_state_machine = ray_actor_hook("global_state_machine") + self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine async def run(self): self.runner_engine = { diff --git a/pretor/core/workflow/workflow_template_manager.py b/pretor/core/workflow/workflow_template_manager.py index b86a6a6..fca50aa 100644 --- a/pretor/core/workflow/workflow_template_manager.py +++ b/pretor/core/workflow/workflow_template_manager.py @@ -40,6 +40,5 @@ class WorkflowManager: try: workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template) self.workflow_templates_registry[workflow_template.name] = workflow_template.desc - except Exception as e: - logger.exception("Failed to generate workflow template") - + except Exception: + logger.exception("Failed to generate workflow template") \ No newline at end of file diff --git a/pretor/utils/check_user/role_check.py b/pretor/utils/check_user/role_check.py index 60c555a..e83667f 100644 --- a/pretor/utils/check_user/role_check.py +++ b/pretor/utils/check_user/role_check.py @@ -21,7 +21,7 @@ from pretor.utils.ray_hook import ray_actor_hook @lru_cache async def get_authority(user_id: str) -> UserAuthority: postgres_database = ray_actor_hook("postgres_database") - user_authority = await postgres_database.get_user_authority.remote(user_id=user_id) + user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id) return user_authority class RoleChecker: diff --git a/pretor/utils/get_tool.py b/pretor/utils/get_tool.py index 162379c..b634816 100644 --- a/pretor/utils/get_tool.py +++ b/pretor/utils/get_tool.py @@ -49,7 +49,7 @@ def del_tool_cache(tool_name: str) -> None: @lru_cache(maxsize=1) async def get_tool(agent_name: str) -> List[Callable]: global_state_machine = ray_actor_hook("global_state_machine") - _tool_list = await global_state_machine.get_tool_list.remote(agent_name) + _tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name) tool_list = [] for tool_name in _tool_list.keys(): tool_func = _get_tool_func(tool_name) diff --git a/tests/core/database/postgres_test.py b/tests/core/database/postgres_test.py index b69ca51..d034e43 100644 --- a/tests/core/database/postgres_test.py +++ b/tests/core/database/postgres_test.py @@ -65,6 +65,8 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m ) mock_auth_db.assert_called_once() mock_provider_db.assert_called_once() + mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth") + assert await db.auth_database("get_user_authority", user_id="123") == "test_auth" with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all: await db.init_db() diff --git a/tests/core/database/table/table_user_test.py b/tests/core/database/table/table_user_test.py index 824070d..50ea1c3 100644 --- a/tests/core/database/table/table_user_test.py +++ b/tests/core/database/table/table_user_test.py @@ -2,6 +2,6 @@ import pytest from pretor.core.database.table.user import User def test_user_table(): - user = User(user_name="name", hashed_password="pw") + user = User(user_id="id", user_name="name", hashed_password="pw") assert User.__tablename__ == 'user' assert user.user_name == "name" diff --git a/tests/core/global_state_machine/global_state_machine_test.py b/tests/core/global_state_machine/global_state_machine_test.py index c3703f1..52588f8 100644 --- a/tests/core/global_state_machine/global_state_machine_test.py +++ b/tests/core/global_state_machine/global_state_machine_test.py @@ -40,9 +40,8 @@ def mock_postgres(): @pytest.fixture def gsm(mock_postgres): - with patch("pretor.core.global_state_machine.global_state_machine.ProviderManager") as mock_pm: - manager = GlobalStateMachine(mock_postgres) - return manager + manager = GlobalStateMachine(mock_postgres) + return manager def test_add_delete_get_event(gsm): @@ -103,24 +102,24 @@ async def test_add_provider_success(gsm, mock_postgres): mock_provider.provider_type = "openai" mock_provider_class.create_model.return_value = mock_provider - gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class} - gsm.global_provider_manager.provider_register = {} + gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} + gsm._global_provider_manager.provider_register = {} mock_add_provider = AsyncMock() - mock_postgres.provider_database.add_provider.remote = mock_add_provider + mock_postgres.provider_database.remote = mock_add_provider - await gsm.add_provider("openai", "title", "url", "key", 1) + await gsm.add_provider_wrap("openai", "title", "url", "key", 1) - assert gsm.global_provider_manager.provider_register["title"] == mock_provider + assert gsm._global_provider_manager.provider_register["title"] == mock_provider mock_add_provider.assert_called_once() assert mock_provider.provider_owner == 1 @pytest.mark.asyncio async def test_add_provider_unsupported(gsm): - gsm.global_provider_manager.provider_mapper = {} - with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger: - await gsm.add_provider("magic", "title", "url", "key", 1) + gsm._global_provider_manager.provider_mapper = {} + with patch("loguru.logger") as mock_logger: + await gsm.add_provider_wrap("magic", "title", "url", "key", 1) mock_logger.warning.assert_called_with("Provider type magic is not supported.") @@ -129,10 +128,10 @@ async def test_add_provider_request_error(gsm): from httpx import RequestError mock_provider_class = AsyncMock() mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock()) - gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class} + gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} - with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger: - await gsm.add_provider("openai", "title", "url", "key", 1) + with patch("loguru.logger") as mock_logger: + await gsm.add_provider_wrap("openai", "title", "url", "key", 1) mock_logger.warning.assert_called_once() assert "网络请求异常" in mock_logger.warning.call_args[0][0] @@ -141,18 +140,18 @@ async def test_add_provider_request_error(gsm): async def test_add_provider_generic_error(gsm): mock_provider_class = AsyncMock() mock_provider_class.create_model.side_effect = ValueError("Some Error") - gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class} + gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class} - with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger: - await gsm.add_provider("openai", "title", "url", "key", 1) + with patch("loguru.logger") as mock_logger: + await gsm.add_provider_wrap("openai", "title", "url", "key", 1) mock_logger.warning.assert_called_once() assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0] def test_get_provider_list_and_get_provider(gsm): mock_provider = MagicMock() - gsm.global_provider_manager.provider_register = {"p1": mock_provider} + gsm._global_provider_manager.provider_register = {"p1": mock_provider} - assert gsm.get_provider_list() == {"p1": mock_provider} - assert gsm.get_provider("p1") == mock_provider - assert gsm.get_provider("missing") is None + assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider} + assert gsm._global_provider_manager.get_provider("p1") == mock_provider + assert gsm._global_provider_manager.get_provider("missing") is None diff --git a/tests/core/global_state_machine/provider_manager_test.py b/tests/core/global_state_machine/provider_manager_test.py index 878567a..2e9fa93 100644 --- a/tests/core/global_state_machine/provider_manager_test.py +++ b/tests/core/global_state_machine/provider_manager_test.py @@ -16,6 +16,8 @@ async def test_provider_manager_init(): mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) manager = ProviderManager(mock_postgres) + mock_postgres.provider_database = MagicMock() + mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2]) await manager.init_provider_register(mock_postgres) assert "openai" in manager.provider_mapper diff --git a/tests/core/workflow/workflow_runner_test.py b/tests/core/workflow/workflow_runner_test.py index 7735fee..5fae511 100644 --- a/tests/core/workflow/workflow_runner_test.py +++ b/tests/core/workflow/workflow_runner_test.py @@ -119,20 +119,17 @@ async def test_workflow_running_engine_runner(): from pretor.core.individual.consciousness_node.template import ForWorkflowEngine mock_consciousness = MagicMock() - mock_wf = MagicMock() mock_wf.trace_id = "test_trace" mock_wf.title = "test_title" mock_result = MagicMock(spec=ForWorkflowEngine) mock_result.workflow = mock_wf - mock_consciousness.working.remote = AsyncMock(return_value=mock_result) engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor") engine.workflow_queue = asyncio.Queue() - # Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties mock_event = PretorEvent( platform="test_platform", user_id="test_user", @@ -142,18 +139,19 @@ async def test_workflow_running_engine_runner(): ) await engine.workflow_queue.put(mock_event) - with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \ - patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook: - mock_gsm = MagicMock() - mock_gsm.update_workflow.remote = AsyncMock() - mock_hook.return_value = mock_gsm + with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open: + + # Instead of patching hook, we inject it directly + engine.global_state_machine = AsyncMock() + + mock_open.return_value.__enter__.return_value.read.return_value = '{}' mock_engine_instance = MagicMock() mock_engine_instance.run = AsyncMock() mock_wf_engine_cls.return_value = mock_engine_instance task = asyncio.create_task(engine.runner(1)) - await asyncio.sleep(0.05) # Give runner time to process one item - task.cancel() # Stop the infinite loop + await asyncio.sleep(0.05) + task.cancel() - mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor") \ No newline at end of file + mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor") diff --git a/tests/core/workflow/workflow_test.py b/tests/core/workflow/workflow_test.py index 8254839..b052019 100644 --- a/tests/core/workflow/workflow_test.py +++ b/tests/core/workflow/workflow_test.py @@ -24,7 +24,7 @@ def test_pretor_workflow_validation_success(): ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1") ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={}) - wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"}) + wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) assert wf.title == "wf1" def test_pretor_workflow_validation_error_step_discontinuous(): @@ -32,7 +32,7 @@ def test_pretor_workflow_validation_error_step_discontinuous(): ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="工作链步数不连续"): - PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"}) + PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) def test_pretor_workflow_validation_error_jump_out_of_bounds(): lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue") @@ -40,14 +40,14 @@ def test_pretor_workflow_validation_error_jump_out_of_bounds(): ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2") wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"): - PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"}) + PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"}) def test_pretor_workflow_validation_error_jump_format_error(): lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue") ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg) wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={}) with pytest.raises(ValueError, match="LogicGate 格式错误"): - PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "username":"b"}) + PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"}) def test_workflow_status(): status = WorkflowStatus()