wip: 增加了api的简单权限校验,使用反射对于actor进行重构

This commit is contained in:
朝夕 2026-04-22 11:36:59 +08:00
parent 81da2e9f81
commit be0d78c855
21 changed files with 173 additions and 235 deletions

View File

@ -21,6 +21,8 @@ from pretor.utils.access import Accessor, TokenData
from pretor.core.database.table.individual import AgentType
from fastapi import HTTPException
from typing import Optional, List, Dict
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
@ -35,7 +37,7 @@ class AgentLocalRegister(BaseModel):
@agent_router.post("")
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
_: TokenData = Depends(Accessor.get_current_user)):
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine")
if isinstance(agent_register, AgentLocalRegister):
pass
@ -82,18 +84,18 @@ class WorkerIndividualUpdate(BaseModel):
@agent_router.post("/worker")
async def create_worker_individual(worker_data: WorkerIndividualCreate,
token_data: TokenData = Depends(Accessor.get_current_user)):
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
postgres_database = ray_actor_hook("postgres_database")
data_dict = worker_data.model_dump()
data_dict["owner_id"] = token_data.user_id
worker = await postgres_database.add_worker_individual.remote(**data_dict)
worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict)
return {"message": "success", "agent_id": worker.agent_id}
@agent_router.get("/worker")
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database")
workers = await postgres_database.get_worker_individual_list.remote(owner_id=token_data.user_id)
workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id)
return {"workers": workers}
@ -101,7 +103,7 @@ async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.ge
async def get_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
@ -114,14 +116,14 @@ async def update_worker_individual(agent_id: str,
worker_data: WorkerIndividualUpdate,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
update_data = worker_data.model_dump(exclude_unset=True)
updated_worker = await postgres_database.update_worker_individual.remote(agent_id=agent_id, **update_data)
updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data)
return {"message": "success", "worker": updated_worker}
@ -129,11 +131,10 @@ async def update_worker_individual(agent_id: str,
async def delete_worker_individual(agent_id: str,
token_data: TokenData = Depends(Accessor.get_current_user)):
postgres_database = ray_actor_hook("postgres_database")
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
if not worker:
raise HTTPException(status_code=404, detail="Agent not found")
if worker.owner_id != token_data.user_id:
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
await postgres_database.delete_worker_individual.remote(agent_id=agent_id)
await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id)
return {"message": "success"}

View File

@ -28,7 +28,7 @@ class UserRegister(BaseModel):
async def create_user(user_register: UserRegister):
postgres_database = ray_actor_hook("postgres_database")
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
user = await postgres_database.add_user.remote(user_register.user_name, hashed_password)
user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password)
return {"message": "success", "user_id": user.user_id}
class UserLogin(BaseModel):
@ -38,7 +38,7 @@ class UserLogin(BaseModel):
@auth_router.post("/login")
async def login_user(user_login: UserLogin):
postgres_database = ray_actor_hook("postgres_database")
user = await postgres_database.login_user.remote(user_login.user_name)
user = await postgres_database.auth_database.remote("login_user", user_login.user_name)
if user.user_name != user_login.user_name:
pass
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)

View File

@ -16,6 +16,8 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import Literal
from pretor.utils.access import TokenData, Accessor
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
from typing import Dict
from pretor.core.global_state_machine.model_provider.base_provider import Provider
from pretor.utils.ray_hook import ray_actor_hook
@ -30,9 +32,9 @@ class ProviderRegister(BaseModel):
@provider_router.post("")
async def create_provider(provider_register: ProviderRegister,
token_data: TokenData = Depends(Accessor.get_current_user)) -> None:
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None:
global_state_machine = ray_actor_hook("global_state_machine")
await global_state_machine.add_provider.remote(provider_type=provider_register.provider_type,
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type,
provider_title=provider_register.provider_title,
provider_url=provider_register.provider_url,
provider_apikey=provider_register.provider_apikey,
@ -42,5 +44,5 @@ async def create_provider(provider_register: ProviderRegister,
@provider_router.get("/list")
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Provider]:
global_state_machine = ray_actor_hook("global_state_machine")
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list")
return {"provider_list": provider_list}

View File

@ -17,13 +17,15 @@ import viceroy
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends
from pretor.utils.access import TokenData, Accessor
from pretor.utils.access import TokenData
from pretor.utils.check_user.role_check import RoleChecker
from pretor.core.database.table.user import UserAuthority
resource_router = APIRouter(prefix="/api/v1/resource")
@resource_router.post("/workflow_template")
async def create_workflow_template(workflow_template: WorkflowTemplate,
_: TokenData = Depends(Accessor.get_current_user)):
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine")
await global_state_machine.workflow_template_generate.remote(workflow_template)
return {"message": "创建成功"}
@ -36,7 +38,7 @@ class Skill(BaseModel):
@resource_router.post("/skill")
async def install_skill(skill: Skill,
_: TokenData = Depends(Accessor.get_current_user)):
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
global_state_machine = ray_actor_hook("global_state_machine")
# noinspection PyUnresolvedReferences
await viceroy.install_skill_async(url = skill.repo_url,
@ -46,6 +48,6 @@ async def install_skill(skill: Skill,
skill_name = skill.path.split("/")[-1]
else:
skill_name = skill.repo_url.split("/")[-1]
await global_state_machine.add_skill.remote(skill_name)
await global_state_machine.skill_manager.remote("add_skill", skill_name)
return {"message": "创建成功"}

View File

@ -35,49 +35,22 @@ class PostgresDatabase:
self.async_engine = create_async_engine(database_url, echo=True)
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
self.auth_database = AuthDatabase(self.async_session_maker)
self.provider_database = ProviderDatabase(self.async_session_maker)
self.individual_database = IndividualDatabase(self.async_session_maker)
self._auth_database = AuthDatabase(self.async_session_maker)
self._provider_database = ProviderDatabase(self.async_session_maker)
self._individual_database = IndividualDatabase(self.async_session_maker)
async def init_db(self) -> None:
async with self.async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
# provider_database操作
async def get_providers(self):
return await self.provider_database.get_provider()
async def auth_database(self, method_name: str, *args, **kwargs):
    """Reflectively dispatch *method_name* on the wrapped auth sub-database.

    Args:
        method_name: name of a public coroutine method on ``self._auth_database``.
        *args/**kwargs: forwarded verbatim to that method.

    Returns:
        Whatever the delegated coroutine returns.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._auth_database, method_name)
    return await method(*args, **kwargs)
async def add_provider(self, **kwargs):
return await self.provider_database.add_provider(**kwargs)
async def provider_database(self, method_name: str, *args, **kwargs):
    """Reflectively dispatch *method_name* on the wrapped provider sub-database.

    Args:
        method_name: name of a public coroutine method on ``self._provider_database``.
        *args/**kwargs: forwarded verbatim to that method.

    Returns:
        Whatever the delegated coroutine returns.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._provider_database, method_name)
    return await method(*args, **kwargs)
# auth_database操作
async def add_user(self, **kwargs):
return await self.auth_database.add_user(**kwargs)
async def change_password(self, **kwargs):
return await self.auth_database.change_password(**kwargs)
async def delete_user(self, **kwargs):
return await self.auth_database.delete_user(**kwargs)
async def login_user(self, **kwargs):
return await self.auth_database.login_user(**kwargs)
async def get_user_authority(self, **kwargs):
return await self.auth_database.get_user_authority(**kwargs)
##individual_database 操作
async def add_worker_individual(self, **kwargs):
return await self.individual_database.add_worker_individual(**kwargs)
async def get_worker_individual(self, agent_id: str):
return await self.individual_database.get_worker_individual(agent_id)
async def get_worker_individual_list(self, owner_id: str):
return await self.individual_database.get_worker_individual_list(owner_id)
async def update_worker_individual(self, agent_id: str, **kwargs):
return await self.individual_database.update_worker_individual(agent_id, **kwargs)
async def delete_worker_individual(self, agent_id: str):
return await self.individual_database.delete_worker_individual(agent_id)
async def individual_database(self, method_name: str, *args, **kwargs):
    """Reflectively dispatch *method_name* on the wrapped individual sub-database.

    Args:
        method_name: name of a public coroutine method on ``self._individual_database``.
        *args/**kwargs: forwarded verbatim to that method.

    Returns:
        Whatever the delegated coroutine returns.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._individual_database, method_name)
    return await method(*args, **kwargs)

View File

@ -15,13 +15,13 @@
from sqlmodel import SQLModel, Field
from typing import List
from sqlalchemy import Column, JSON
from typing import Optional, Literal
from typing import Optional
class Provider(SQLModel, table=True):
__tablename__ = "provider"
provider_id: str = Field(primary_key=True)
provider_title: str = Field(index=True)
provider_type: Literal["openai", "vllm"]
provider_type: str
provider_url: Optional[str]
provider_apikey: Optional[str]

View File

@ -15,37 +15,63 @@
import ray
from pretor.core.global_state_machine.provider_manager import ProviderManager
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
import httpx
import pathlib
import json
from loguru import logger
from typing import Dict, Literal, List
from typing import Dict
from pretor.core.database.postgres import PostgresDatabase
from pretor.api.platform.event import PretorEvent
import asyncio
from pretor.core.workflow.workflow import PretorWorkflow
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
from pretor.core.workflow.workflow_template_manager import WorkflowManager
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
from pretor.plugin.tool_plugin.base_tool import BaseToolData
@ray.remote
class GlobalStateMachine:
def __init__(self, postgres_database: PostgresDatabase):
self.event_dict: Dict[int, PretorEvent] = {}
self.global_provider_manager = ProviderManager(postgres_database)
self.global_tool_manager = GlobalToolManager()
self.global_workflow_template_manager = WorkflowManager()
self.global_skill_manager = GlobalSkillManager()
self.event_dict: Dict[str, PretorEvent] = {}
self._global_provider_manager = ProviderManager(postgres_database)
self._global_tool_manager = GlobalToolManager()
self._global_workflow_template_manager = WorkflowManager()
self._global_skill_manager = GlobalSkillManager()
self.postgres_database = postgres_database
async def init_state_machine(self):
await self.global_provider_manager.init_provider_register(self.postgres_database)
await self._global_provider_manager.init_provider_register(self.postgres_database)
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
    """Register a model provider by delegating to the provider manager.

    Bundles every argument, plus this actor's postgres database handle,
    into keyword arguments for ``ProviderManager.add_provider``.
    """
    registration = dict(
        provider_type=provider_type,
        provider_title=provider_title,
        provider_url=provider_url,
        provider_apikey=provider_apikey,
        provider_owner=provider_owner,
        postgres_database=self.postgres_database,
    )
    return await self._global_provider_manager.add_provider(**registration)
async def provider_manager(self, method_name: str, *args, **kwargs):
    """Reflectively invoke *method_name* on the global provider manager.

    Coroutine methods are awaited; synchronous methods are called directly.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._global_provider_manager, method_name)
    if asyncio.iscoroutinefunction(method):
        return await method(*args, **kwargs)
    return method(*args, **kwargs)
async def tool_manager(self, method_name: str, *args, **kwargs):
    """Reflectively invoke *method_name* on the global tool manager.

    Coroutine methods are awaited; synchronous methods are called directly.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._global_tool_manager, method_name)
    if asyncio.iscoroutinefunction(method):
        return await method(*args, **kwargs)
    return method(*args, **kwargs)
async def workflow_template_manager(self, method_name: str, *args, **kwargs):
    """Reflectively invoke *method_name* on the global workflow template manager.

    Coroutine methods are awaited; synchronous methods are called directly.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._global_workflow_template_manager, method_name)
    if asyncio.iscoroutinefunction(method):
        return await method(*args, **kwargs)
    return method(*args, **kwargs)
async def skill_manager(self, method_name: str, *args, **kwargs):
    """Reflectively invoke *method_name* on the global skill manager.

    Coroutine methods are awaited; synchronous methods are called directly.

    Raises:
        AttributeError: if the name is private/dunder or does not exist.
    """
    # Guard the reflective surface: never allow private or dunder access.
    if method_name.startswith("_"):
        raise AttributeError(f"illegal method name: {method_name!r}")
    method = getattr(self._global_skill_manager, method_name)
    if asyncio.iscoroutinefunction(method):
        return await method(*args, **kwargs)
    return method(*args, **kwargs)
###以下方法为event_dict方法
def add_event(self, event: PretorEvent) -> None:
@ -79,108 +105,3 @@ class GlobalStateMachine:
async def get_received(self, event_id) -> str:
return await self.event_dict[event_id].receive_queue.get()
###以下方法为global_provider_manager方法
async def add_provider(self, provider_type: Literal["openai", "gemini", "claude"],
provider_title: str,
provider_url: str,
provider_apikey: str,
provider_owner: int) -> None:
"""
add_provider方法注册供应商适配器(provider_manager方法)
Args
provider_type: 注册商接口类型目前只支持openai,gemini和claude接口
provider_title: 供应商名称为供应商提供的别名
provider_url: 供应商url
provider_apikey: 供应商所需要的apikey
Returns:
"""
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
provider_url=provider_url,
provider_apikey=provider_apikey,
provider_owner=provider_owner)
try:
provider_class = self.global_provider_manager.provider_mapper.get(provider_type, None)
if provider_class is None:
logger.warning(f"Provider type {provider_type} is not supported.")
return None
provider: Provider = await provider_class.create_model(provider_args)
provider.provider_owner = provider_owner
self.global_provider_manager.provider_register[provider_title] = provider
await self.postgres_database.add_provider.remote(provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner)
logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e:
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
def get_provider_list(self) -> Dict[str, Provider]:
"""
get_provider_list方法获取注册表(provider_manager方法)
Returns:
返回provider_register属性字典
"""
return self.global_provider_manager.provider_register
def get_provider(self, provider_title) -> Provider:
"""
get_provider方法获取供应商信息(provider_manager方法)
Args:
provider_title:provider名称
Returns:
Provider对象返回注册在self.global_provider_manager.provider_register的供应商
"""
provider = self.global_provider_manager.provider_register.get(provider_title)
return provider
###以下为global_tool_manager方法
def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]:
"""
获取工具表方法
Args:
agent_name: agent的名字
Returns:
返回该agent的tool类型为dict
"""
tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {})
return tool_list
###以下为workflow_template_manager方法
def workflow_template_generate(self, workflow_template: WorkflowTemplate) -> None:
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
def get_workflow_template_list(self) -> List[Dict[str, str]]:
return self.global_workflow_template_manager.workflow_templates_registry
###以下为skill_manager方法
def add_skill(self, skill_name: str):
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
json_path = skill_plugin_dir / "skill.json"
try:
with open(json_path, "r", encoding="utf-8") as f:
skill = json.load(f)
name = skill.get("name")
if name:
self.global_skill_manager.skill_mapper[name] = (
skill.get("description", ""),
skill.get("instructions", "")
)
except (json.JSONDecodeError, OSError) as e:
print(f"警告: 加载插件 {skill_name} 失败: {e}")

View File

@ -33,6 +33,45 @@ class ProviderManager:
self.provider_register = {}
async def init_provider_register(self, postgres) -> None:
providers = await postgres.get_providers.remote()
providers = await postgres.provider_database.remote("get_provider")
for provider in providers:
self.provider_register[provider.provider_title] = provider
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
    """Create, register and persist a provider adapter.

    Looks up the adapter class for ``provider_type`` in ``self.provider_mapper``,
    builds the provider via its async ``create_model`` factory, caches it in
    ``self.provider_register`` under ``provider_title``, and persists it through
    the remote ``provider_database`` handle on ``postgres_database``. Failures
    are logged as warnings (best-effort semantics); the method returns ``None``.
    """
    # Local imports: presumably to keep the manager module import-light — TODO confirm.
    from pretor.core.global_state_machine.model_provider import ProviderArgs
    from loguru import logger
    import httpx
    provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
                                               provider_url=provider_url,
                                               provider_apikey=provider_apikey,
                                               provider_owner=provider_owner)
    try:
        provider_class = self.provider_mapper.get(provider_type, None)
        if provider_class is None:
            # Unknown provider type: warn and bail out without raising.
            logger.warning(f"Provider type {provider_type} is not supported.")
            return None
        provider: Provider = await provider_class.create_model(provider_args)
        provider.provider_owner = provider_owner
        # In-memory registry keyed by the human-readable provider title.
        self.provider_register[provider_title] = provider
        # Persist via the reflective remote dispatch on the database actor.
        await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title,
                                                         provider_url=provider.provider_url,
                                                         provider_apikey=provider.provider_apikey,
                                                         provider_models=provider.provider_models,
                                                         provider_type=provider.provider_type,
                                                         provider_owner=provider.provider_owner)
        logger.info(f"已添加适配器{provider_title}")
    except httpx.RequestError as e:
        # Network failure while probing the provider endpoint.
        logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
    except Exception as e:
        # NOTE(review): broad catch looks deliberate (best-effort registration,
        # errors only logged) — confirm no caller depends on exceptions here.
        logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
def get_provider_list(self):
    """Return the full provider registry (title -> provider object)."""
    registry = self.provider_register
    return registry
def get_provider(self, provider_title):
    """Look up a registered provider by title; return None when unknown."""
    try:
        return self.provider_register[provider_title]
    except KeyError:
        return None

View File

@ -55,7 +55,7 @@ class ConsciousnessNode:
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
)
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
provider: Provider = global_state_machine.get_provider.remote(provider_title)
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,
@ -85,7 +85,7 @@ class ConsciousnessNode:
else:
logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
return None
except Exception as e:
except Exception:
logger.exception("ConsciousnessNode在执行working时发生严重错误")
return None

View File

@ -51,7 +51,7 @@ class ControlNode:
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
)
output_type = ForWorkflow
provider: Provider = global_state_machine.get_provider.remote(provider_title)
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,

View File

@ -54,7 +54,7 @@ class SupervisoryNode:
"请保持冷静、专业,并严格遵循上述路由规则。"
)
output_type = Union[ForConsciousnessNode, ForUser]
provider: Provider = await global_state_machine.get_provider.remote(provider_title)
provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
agent_factory = AgentFactory()
self.agent = agent_factory.create_agent(provider=provider,
model_id=model_id,
@ -155,7 +155,7 @@ class SupervisoryNode:
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
try:
global_state_machine = ray_actor_hook("global_state_machine")
workflow_template_dict = await global_state_machine.get_workflow_template_list.remote()
workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list")
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
deps = SupervisoryNodeDeps(

View File

@ -158,7 +158,7 @@ class WorkflowEngine:
logger.info(f"Supervisory 最终回复:{user_response}")
else:
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
except Exception as e:
except Exception:
logger.exception("生成工作流执行汇报时发生错误")
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
@ -207,7 +207,7 @@ class WorkflowEngine:
else:
raise WorkflowError(f"未知的节点类型:{step.node}")
except Exception as e:
except Exception:
logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
return None, False
@ -245,7 +245,7 @@ class WorkflowRunningEngine:
self.consciousness_node = consciousness_node
self.control_node = control_node
self.supervisory_node = supervisory_node
self.global_state_machine = ray_actor_hook("global_state_machine")
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
async def run(self):
self.runner_engine = {

View File

@ -40,6 +40,5 @@ class WorkflowManager:
try:
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
except Exception as e:
except Exception:
logger.exception("Failed to generate workflow template")

View File

@ -21,7 +21,7 @@ from pretor.utils.ray_hook import ray_actor_hook
@lru_cache
async def get_authority(user_id: str) -> UserAuthority:
postgres_database = ray_actor_hook("postgres_database")
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id)
return user_authority
class RoleChecker:

View File

@ -49,7 +49,7 @@ def del_tool_cache(tool_name: str) -> None:
@lru_cache(maxsize=1)
async def get_tool(agent_name: str) -> List[Callable]:
global_state_machine = ray_actor_hook("global_state_machine")
_tool_list = await global_state_machine.get_tool_list.remote(agent_name)
_tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name)
tool_list = []
for tool_name in _tool_list.keys():
tool_func = _get_tool_func(tool_name)

View File

@ -65,6 +65,8 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
)
mock_auth_db.assert_called_once()
mock_provider_db.assert_called_once()
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
await db.init_db()

View File

@ -2,6 +2,6 @@ import pytest
from pretor.core.database.table.user import User
def test_user_table():
user = User(user_name="name", hashed_password="pw")
user = User(user_id="id", user_name="name", hashed_password="pw")
assert User.__tablename__ == 'user'
assert user.user_name == "name"

View File

@ -40,7 +40,6 @@ def mock_postgres():
@pytest.fixture
def gsm(mock_postgres):
with patch("pretor.core.global_state_machine.global_state_machine.ProviderManager") as mock_pm:
manager = GlobalStateMachine(mock_postgres)
return manager
@ -103,24 +102,24 @@ async def test_add_provider_success(gsm, mock_postgres):
mock_provider.provider_type = "openai"
mock_provider_class.create_model.return_value = mock_provider
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm.global_provider_manager.provider_register = {}
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm._global_provider_manager.provider_register = {}
mock_add_provider = AsyncMock()
mock_postgres.provider_database.add_provider.remote = mock_add_provider
mock_postgres.provider_database.remote = mock_add_provider
await gsm.add_provider("openai", "title", "url", "key", 1)
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
assert gsm.global_provider_manager.provider_register["title"] == mock_provider
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
mock_add_provider.assert_called_once()
assert mock_provider.provider_owner == 1
@pytest.mark.asyncio
async def test_add_provider_unsupported(gsm):
gsm.global_provider_manager.provider_mapper = {}
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
await gsm.add_provider("magic", "title", "url", "key", 1)
gsm._global_provider_manager.provider_mapper = {}
with patch("loguru.logger") as mock_logger:
await gsm.add_provider_wrap("magic", "title", "url", "key", 1)
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
@ -129,10 +128,10 @@ async def test_add_provider_request_error(gsm):
from httpx import RequestError
mock_provider_class = AsyncMock()
mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
await gsm.add_provider("openai", "title", "url", "key", 1)
with patch("loguru.logger") as mock_logger:
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
mock_logger.warning.assert_called_once()
assert "网络请求异常" in mock_logger.warning.call_args[0][0]
@ -141,18 +140,18 @@ async def test_add_provider_request_error(gsm):
async def test_add_provider_generic_error(gsm):
mock_provider_class = AsyncMock()
mock_provider_class.create_model.side_effect = ValueError("Some Error")
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
await gsm.add_provider("openai", "title", "url", "key", 1)
with patch("loguru.logger") as mock_logger:
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
mock_logger.warning.assert_called_once()
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
def test_get_provider_list_and_get_provider(gsm):
mock_provider = MagicMock()
gsm.global_provider_manager.provider_register = {"p1": mock_provider}
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
assert gsm.get_provider_list() == {"p1": mock_provider}
assert gsm.get_provider("p1") == mock_provider
assert gsm.get_provider("missing") is None
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
assert gsm._global_provider_manager.get_provider("missing") is None

View File

@ -16,6 +16,8 @@ async def test_provider_manager_init():
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
manager = ProviderManager(mock_postgres)
mock_postgres.provider_database = MagicMock()
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
await manager.init_provider_register(mock_postgres)
assert "openai" in manager.provider_mapper

View File

@ -119,20 +119,17 @@ async def test_workflow_running_engine_runner():
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
mock_consciousness = MagicMock()
mock_wf = MagicMock()
mock_wf.trace_id = "test_trace"
mock_wf.title = "test_title"
mock_result = MagicMock(spec=ForWorkflowEngine)
mock_result.workflow = mock_wf
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
engine.workflow_queue = asyncio.Queue()
# Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties
mock_event = PretorEvent(
platform="test_platform",
user_id="test_user",
@ -142,18 +139,19 @@ async def test_workflow_running_engine_runner():
)
await engine.workflow_queue.put(mock_event)
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
mock_gsm = MagicMock()
mock_gsm.update_workflow.remote = AsyncMock()
mock_hook.return_value = mock_gsm
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open:
# Instead of patching hook, we inject it directly
engine.global_state_machine = AsyncMock()
mock_open.return_value.__enter__.return_value.read.return_value = '{}'
mock_engine_instance = MagicMock()
mock_engine_instance.run = AsyncMock()
mock_wf_engine_cls.return_value = mock_engine_instance
task = asyncio.create_task(engine.runner(1))
await asyncio.sleep(0.05) # Give runner time to process one item
task.cancel() # Stop the infinite loop
await asyncio.sleep(0.05)
task.cancel()
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")

View File

@ -24,7 +24,7 @@ def test_pretor_workflow_validation_success():
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
assert wf.title == "wf1"
def test_pretor_workflow_validation_error_step_discontinuous():
@ -32,7 +32,7 @@ def test_pretor_workflow_validation_error_step_discontinuous():
ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="工作链步数不连续"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_out_of_bounds():
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
@ -40,14 +40,14 @@ def test_pretor_workflow_validation_error_jump_out_of_bounds():
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_format_error():
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg)
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
with pytest.raises(ValueError, match="LogicGate 格式错误"):
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "username":"b"})
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_workflow_status():
status = WorkflowStatus()