wip: 增加了api的简单权限校验,使用反射对于actor进行重构
This commit is contained in:
parent
81da2e9f81
commit
be0d78c855
|
|
@ -21,6 +21,8 @@ from pretor.utils.access import Accessor, TokenData
|
||||||
from pretor.core.database.table.individual import AgentType
|
from pretor.core.database.table.individual import AgentType
|
||||||
from fastapi import HTTPException
|
from fastapi import HTTPException
|
||||||
from typing import Optional, List, Dict
|
from typing import Optional, List, Dict
|
||||||
|
from pretor.utils.check_user.role_check import RoleChecker
|
||||||
|
from pretor.core.database.table.user import UserAuthority
|
||||||
|
|
||||||
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
||||||
|
|
||||||
|
|
@ -35,7 +37,7 @@ class AgentLocalRegister(BaseModel):
|
||||||
|
|
||||||
@agent_router.post("")
|
@agent_router.post("")
|
||||||
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
|
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||||
_: TokenData = Depends(Accessor.get_current_user)):
|
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
if isinstance(agent_register, AgentLocalRegister):
|
if isinstance(agent_register, AgentLocalRegister):
|
||||||
pass
|
pass
|
||||||
|
|
@ -82,18 +84,18 @@ class WorkerIndividualUpdate(BaseModel):
|
||||||
|
|
||||||
@agent_router.post("/worker")
|
@agent_router.post("/worker")
|
||||||
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
data_dict = worker_data.model_dump()
|
data_dict = worker_data.model_dump()
|
||||||
data_dict["owner_id"] = token_data.user_id
|
data_dict["owner_id"] = token_data.user_id
|
||||||
worker = await postgres_database.add_worker_individual.remote(**data_dict)
|
worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict)
|
||||||
return {"message": "success", "agent_id": worker.agent_id}
|
return {"message": "success", "agent_id": worker.agent_id}
|
||||||
|
|
||||||
|
|
||||||
@agent_router.get("/worker")
|
@agent_router.get("/worker")
|
||||||
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
workers = await postgres_database.get_worker_individual_list.remote(owner_id=token_data.user_id)
|
workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id)
|
||||||
return {"workers": workers}
|
return {"workers": workers}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -101,7 +103,7 @@ async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.ge
|
||||||
async def get_worker_individual(agent_id: str,
|
async def get_worker_individual(agent_id: str,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
|
|
@ -114,14 +116,14 @@ async def update_worker_individual(agent_id: str,
|
||||||
worker_data: WorkerIndividualUpdate,
|
worker_data: WorkerIndividualUpdate,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||||
|
|
||||||
update_data = worker_data.model_dump(exclude_unset=True)
|
update_data = worker_data.model_dump(exclude_unset=True)
|
||||||
updated_worker = await postgres_database.update_worker_individual.remote(agent_id=agent_id, **update_data)
|
updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data)
|
||||||
return {"message": "success", "worker": updated_worker}
|
return {"message": "success", "worker": updated_worker}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -129,11 +131,10 @@ async def update_worker_individual(agent_id: str,
|
||||||
async def delete_worker_individual(agent_id: str,
|
async def delete_worker_individual(agent_id: str,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||||
if not worker:
|
if not worker:
|
||||||
raise HTTPException(status_code=404, detail="Agent not found")
|
raise HTTPException(status_code=404, detail="Agent not found")
|
||||||
if worker.owner_id != token_data.user_id:
|
if worker.owner_id != token_data.user_id:
|
||||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||||
|
await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id)
|
||||||
await postgres_database.delete_worker_individual.remote(agent_id=agent_id)
|
|
||||||
return {"message": "success"}
|
return {"message": "success"}
|
||||||
|
|
@ -28,7 +28,7 @@ class UserRegister(BaseModel):
|
||||||
async def create_user(user_register: UserRegister):
|
async def create_user(user_register: UserRegister):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
||||||
user = await postgres_database.add_user.remote(user_register.user_name, hashed_password)
|
user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password)
|
||||||
return {"message": "success", "user_id": user.user_id}
|
return {"message": "success", "user_id": user.user_id}
|
||||||
|
|
||||||
class UserLogin(BaseModel):
|
class UserLogin(BaseModel):
|
||||||
|
|
@ -38,7 +38,7 @@ class UserLogin(BaseModel):
|
||||||
@auth_router.post("/login")
|
@auth_router.post("/login")
|
||||||
async def login_user(user_login: UserLogin):
|
async def login_user(user_login: UserLogin):
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
user = await postgres_database.login_user.remote(user_login.user_name)
|
user = await postgres_database.auth_database.remote("login_user", user_login.user_name)
|
||||||
if user.user_name != user_login.user_name:
|
if user.user_name != user_login.user_name:
|
||||||
pass
|
pass
|
||||||
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ from fastapi import APIRouter, Depends
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pretor.utils.access import TokenData, Accessor
|
from pretor.utils.access import TokenData, Accessor
|
||||||
|
from pretor.utils.check_user.role_check import RoleChecker
|
||||||
|
from pretor.core.database.table.user import UserAuthority
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
|
|
@ -30,9 +32,9 @@ class ProviderRegister(BaseModel):
|
||||||
|
|
||||||
@provider_router.post("")
|
@provider_router.post("")
|
||||||
async def create_provider(provider_register: ProviderRegister,
|
async def create_provider(provider_register: ProviderRegister,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user)) -> None:
|
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
await global_state_machine.add_provider.remote(provider_type=provider_register.provider_type,
|
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type,
|
||||||
provider_title=provider_register.provider_title,
|
provider_title=provider_register.provider_title,
|
||||||
provider_url=provider_register.provider_url,
|
provider_url=provider_register.provider_url,
|
||||||
provider_apikey=provider_register.provider_apikey,
|
provider_apikey=provider_register.provider_apikey,
|
||||||
|
|
@ -42,5 +44,5 @@ async def create_provider(provider_register: ProviderRegister,
|
||||||
@provider_router.get("/list")
|
@provider_router.get("/list")
|
||||||
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Provider]:
|
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Provider]:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
|
provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list")
|
||||||
return {"provider_list": provider_list}
|
return {"provider_list": provider_list}
|
||||||
|
|
@ -17,13 +17,15 @@ import viceroy
|
||||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from pretor.utils.access import TokenData, Accessor
|
from pretor.utils.access import TokenData
|
||||||
|
from pretor.utils.check_user.role_check import RoleChecker
|
||||||
|
from pretor.core.database.table.user import UserAuthority
|
||||||
|
|
||||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||||
|
|
||||||
@resource_router.post("/workflow_template")
|
@resource_router.post("/workflow_template")
|
||||||
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
||||||
_: TokenData = Depends(Accessor.get_current_user)):
|
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
await global_state_machine.workflow_template_generate.remote(workflow_template)
|
await global_state_machine.workflow_template_generate.remote(workflow_template)
|
||||||
return {"message": "创建成功"}
|
return {"message": "创建成功"}
|
||||||
|
|
@ -36,7 +38,7 @@ class Skill(BaseModel):
|
||||||
|
|
||||||
@resource_router.post("/skill")
|
@resource_router.post("/skill")
|
||||||
async def install_skill(skill: Skill,
|
async def install_skill(skill: Skill,
|
||||||
_: TokenData = Depends(Accessor.get_current_user)):
|
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
# noinspection PyUnresolvedReferences
|
# noinspection PyUnresolvedReferences
|
||||||
await viceroy.install_skill_async(url = skill.repo_url,
|
await viceroy.install_skill_async(url = skill.repo_url,
|
||||||
|
|
@ -46,6 +48,6 @@ async def install_skill(skill: Skill,
|
||||||
skill_name = skill.path.split("/")[-1]
|
skill_name = skill.path.split("/")[-1]
|
||||||
else:
|
else:
|
||||||
skill_name = skill.repo_url.split("/")[-1]
|
skill_name = skill.repo_url.split("/")[-1]
|
||||||
await global_state_machine.add_skill.remote(skill_name)
|
await global_state_machine.skill_manager.remote("add_skill", skill_name)
|
||||||
return {"message": "创建成功"}
|
return {"message": "创建成功"}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -35,49 +35,22 @@ class PostgresDatabase:
|
||||||
self.async_engine = create_async_engine(database_url, echo=True)
|
self.async_engine = create_async_engine(database_url, echo=True)
|
||||||
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
self.auth_database = AuthDatabase(self.async_session_maker)
|
self._auth_database = AuthDatabase(self.async_session_maker)
|
||||||
self.provider_database = ProviderDatabase(self.async_session_maker)
|
self._provider_database = ProviderDatabase(self.async_session_maker)
|
||||||
self.individual_database = IndividualDatabase(self.async_session_maker)
|
self._individual_database = IndividualDatabase(self.async_session_maker)
|
||||||
|
|
||||||
async def init_db(self) -> None:
|
async def init_db(self) -> None:
|
||||||
async with self.async_engine.begin() as conn:
|
async with self.async_engine.begin() as conn:
|
||||||
await conn.run_sync(SQLModel.metadata.create_all)
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
# provider_database操作
|
async def auth_database(self, method_name: str, *args, **kwargs):
|
||||||
async def get_providers(self):
|
method = getattr(self._auth_database, method_name)
|
||||||
return await self.provider_database.get_provider()
|
return await method(*args, **kwargs)
|
||||||
|
|
||||||
async def add_provider(self, **kwargs):
|
async def provider_database(self, method_name: str, *args, **kwargs):
|
||||||
return await self.provider_database.add_provider(**kwargs)
|
method = getattr(self._provider_database, method_name)
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
|
||||||
# auth_database操作
|
async def individual_database(self, method_name: str, *args, **kwargs):
|
||||||
async def add_user(self, **kwargs):
|
method = getattr(self._individual_database, method_name)
|
||||||
return await self.auth_database.add_user(**kwargs)
|
return await method(*args, **kwargs)
|
||||||
|
|
||||||
async def change_password(self, **kwargs):
|
|
||||||
return await self.auth_database.change_password(**kwargs)
|
|
||||||
|
|
||||||
async def delete_user(self, **kwargs):
|
|
||||||
return await self.auth_database.delete_user(**kwargs)
|
|
||||||
|
|
||||||
async def login_user(self, **kwargs):
|
|
||||||
return await self.auth_database.login_user(**kwargs)
|
|
||||||
|
|
||||||
async def get_user_authority(self, **kwargs):
|
|
||||||
return await self.auth_database.get_user_authority(**kwargs)
|
|
||||||
|
|
||||||
##individual_database 操作
|
|
||||||
async def add_worker_individual(self, **kwargs):
|
|
||||||
return await self.individual_database.add_worker_individual(**kwargs)
|
|
||||||
|
|
||||||
async def get_worker_individual(self, agent_id: str):
|
|
||||||
return await self.individual_database.get_worker_individual(agent_id)
|
|
||||||
|
|
||||||
async def get_worker_individual_list(self, owner_id: str):
|
|
||||||
return await self.individual_database.get_worker_individual_list(owner_id)
|
|
||||||
|
|
||||||
async def update_worker_individual(self, agent_id: str, **kwargs):
|
|
||||||
return await self.individual_database.update_worker_individual(agent_id, **kwargs)
|
|
||||||
|
|
||||||
async def delete_worker_individual(self, agent_id: str):
|
|
||||||
return await self.individual_database.delete_worker_individual(agent_id)
|
|
||||||
|
|
@ -15,13 +15,13 @@
|
||||||
from sqlmodel import SQLModel, Field
|
from sqlmodel import SQLModel, Field
|
||||||
from typing import List
|
from typing import List
|
||||||
from sqlalchemy import Column, JSON
|
from sqlalchemy import Column, JSON
|
||||||
from typing import Optional, Literal
|
from typing import Optional
|
||||||
|
|
||||||
class Provider(SQLModel, table=True):
|
class Provider(SQLModel, table=True):
|
||||||
__tablename__ = "provider"
|
__tablename__ = "provider"
|
||||||
provider_id: str = Field(primary_key=True)
|
provider_id: str = Field(primary_key=True)
|
||||||
provider_title: str = Field(index=True)
|
provider_title: str = Field(index=True)
|
||||||
provider_type: Literal["openai", "vllm"]
|
provider_type: str
|
||||||
|
|
||||||
provider_url: Optional[str]
|
provider_url: Optional[str]
|
||||||
provider_apikey: Optional[str]
|
provider_apikey: Optional[str]
|
||||||
|
|
|
||||||
|
|
@ -15,37 +15,63 @@
|
||||||
import ray
|
import ray
|
||||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
from typing import Dict
|
||||||
import httpx
|
|
||||||
import pathlib
|
|
||||||
import json
|
|
||||||
from loguru import logger
|
|
||||||
from typing import Dict, Literal, List
|
|
||||||
from pretor.core.database.postgres import PostgresDatabase
|
from pretor.core.database.postgres import PostgresDatabase
|
||||||
from pretor.api.platform.event import PretorEvent
|
from pretor.api.platform.event import PretorEvent
|
||||||
import asyncio
|
import asyncio
|
||||||
from pretor.core.workflow.workflow import PretorWorkflow
|
from pretor.core.workflow.workflow import PretorWorkflow
|
||||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
|
||||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
|
||||||
|
|
||||||
@ray.remote
|
@ray.remote
|
||||||
class GlobalStateMachine:
|
class GlobalStateMachine:
|
||||||
def __init__(self, postgres_database: PostgresDatabase):
|
def __init__(self, postgres_database: PostgresDatabase):
|
||||||
|
|
||||||
self.event_dict: Dict[int, PretorEvent] = {}
|
self.event_dict: Dict[str, PretorEvent] = {}
|
||||||
self.global_provider_manager = ProviderManager(postgres_database)
|
|
||||||
self.global_tool_manager = GlobalToolManager()
|
|
||||||
self.global_workflow_template_manager = WorkflowManager()
|
|
||||||
self.global_skill_manager = GlobalSkillManager()
|
|
||||||
|
|
||||||
|
self._global_provider_manager = ProviderManager(postgres_database)
|
||||||
|
self._global_tool_manager = GlobalToolManager()
|
||||||
|
self._global_workflow_template_manager = WorkflowManager()
|
||||||
|
self._global_skill_manager = GlobalSkillManager()
|
||||||
|
|
||||||
self.postgres_database = postgres_database
|
self.postgres_database = postgres_database
|
||||||
|
|
||||||
async def init_state_machine(self):
|
async def init_state_machine(self):
|
||||||
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
await self._global_provider_manager.init_provider_register(self.postgres_database)
|
||||||
|
|
||||||
|
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||||
|
return await self._global_provider_manager.add_provider(
|
||||||
|
provider_type=provider_type,
|
||||||
|
provider_title=provider_title,
|
||||||
|
provider_url=provider_url,
|
||||||
|
provider_apikey=provider_apikey,
|
||||||
|
provider_owner=provider_owner,
|
||||||
|
postgres_database=self.postgres_database
|
||||||
|
)
|
||||||
|
|
||||||
|
async def provider_manager(self, method_name: str, *args, **kwargs):
|
||||||
|
method = getattr(self._global_provider_manager, method_name)
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
async def tool_manager(self, method_name: str, *args, **kwargs):
|
||||||
|
method = getattr(self._global_tool_manager, method_name)
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
async def workflow_template_manager(self, method_name: str, *args, **kwargs):
|
||||||
|
method = getattr(self._global_workflow_template_manager, method_name)
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
|
async def skill_manager(self, method_name: str, *args, **kwargs):
|
||||||
|
method = getattr(self._global_skill_manager, method_name)
|
||||||
|
if asyncio.iscoroutinefunction(method):
|
||||||
|
return await method(*args, **kwargs)
|
||||||
|
return method(*args, **kwargs)
|
||||||
|
|
||||||
###以下方法为event_dict方法
|
###以下方法为event_dict方法
|
||||||
def add_event(self, event: PretorEvent) -> None:
|
def add_event(self, event: PretorEvent) -> None:
|
||||||
|
|
@ -79,108 +105,3 @@ class GlobalStateMachine:
|
||||||
|
|
||||||
async def get_received(self, event_id) -> str:
|
async def get_received(self, event_id) -> str:
|
||||||
return await self.event_dict[event_id].receive_queue.get()
|
return await self.event_dict[event_id].receive_queue.get()
|
||||||
|
|
||||||
|
|
||||||
###以下方法为global_provider_manager方法
|
|
||||||
async def add_provider(self, provider_type: Literal["openai", "gemini", "claude"],
|
|
||||||
provider_title: str,
|
|
||||||
provider_url: str,
|
|
||||||
provider_apikey: str,
|
|
||||||
provider_owner: int) -> None:
|
|
||||||
"""
|
|
||||||
add_provider方法,注册供应商适配器(provider_manager方法)
|
|
||||||
|
|
||||||
Args
|
|
||||||
provider_type: 注册商接口类型,目前只支持openai,gemini和claude接口
|
|
||||||
provider_title: 供应商名称,为供应商提供的别名
|
|
||||||
provider_url: 供应商url
|
|
||||||
provider_apikey: 供应商所需要的apikey
|
|
||||||
Returns:
|
|
||||||
|
|
||||||
"""
|
|
||||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
|
||||||
provider_url=provider_url,
|
|
||||||
provider_apikey=provider_apikey,
|
|
||||||
provider_owner=provider_owner)
|
|
||||||
try:
|
|
||||||
provider_class = self.global_provider_manager.provider_mapper.get(provider_type, None)
|
|
||||||
if provider_class is None:
|
|
||||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
|
||||||
return None
|
|
||||||
provider: Provider = await provider_class.create_model(provider_args)
|
|
||||||
|
|
||||||
provider.provider_owner = provider_owner
|
|
||||||
|
|
||||||
self.global_provider_manager.provider_register[provider_title] = provider
|
|
||||||
|
|
||||||
await self.postgres_database.add_provider.remote(provider_title=provider.provider_title,
|
|
||||||
provider_url=provider.provider_url,
|
|
||||||
provider_apikey=provider.provider_apikey,
|
|
||||||
provider_models=provider.provider_models,
|
|
||||||
provider_type=provider.provider_type,
|
|
||||||
provider_owner=provider.provider_owner)
|
|
||||||
|
|
||||||
logger.info(f"已添加适配器{provider_title}")
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
|
||||||
|
|
||||||
def get_provider_list(self) -> Dict[str, Provider]:
|
|
||||||
"""
|
|
||||||
get_provider_list方法,获取注册表(provider_manager方法)
|
|
||||||
Returns:
|
|
||||||
返回provider_register属性字典
|
|
||||||
"""
|
|
||||||
return self.global_provider_manager.provider_register
|
|
||||||
|
|
||||||
def get_provider(self, provider_title) -> Provider:
|
|
||||||
"""
|
|
||||||
get_provider方法,获取供应商信息(provider_manager方法)
|
|
||||||
Args:
|
|
||||||
provider_title:provider名称
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Provider对象,返回注册在self.global_provider_manager.provider_register的供应商
|
|
||||||
"""
|
|
||||||
provider = self.global_provider_manager.provider_register.get(provider_title)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
|
|
||||||
###以下为global_tool_manager方法
|
|
||||||
def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]:
|
|
||||||
"""
|
|
||||||
获取工具表方法
|
|
||||||
Args:
|
|
||||||
agent_name: agent的名字
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
返回该agent的tool,类型为dict
|
|
||||||
"""
|
|
||||||
tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {})
|
|
||||||
return tool_list
|
|
||||||
|
|
||||||
|
|
||||||
###以下为workflow_template_manager方法
|
|
||||||
def workflow_template_generate(self, workflow_template: WorkflowTemplate) -> None:
|
|
||||||
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
|
|
||||||
|
|
||||||
def get_workflow_template_list(self) -> List[Dict[str, str]]:
|
|
||||||
return self.global_workflow_template_manager.workflow_templates_registry
|
|
||||||
|
|
||||||
|
|
||||||
###以下为skill_manager方法
|
|
||||||
def add_skill(self, skill_name: str):
|
|
||||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
|
||||||
json_path = skill_plugin_dir / "skill.json"
|
|
||||||
try:
|
|
||||||
with open(json_path, "r", encoding="utf-8") as f:
|
|
||||||
skill = json.load(f)
|
|
||||||
name = skill.get("name")
|
|
||||||
if name:
|
|
||||||
self.global_skill_manager.skill_mapper[name] = (
|
|
||||||
skill.get("description", ""),
|
|
||||||
skill.get("instructions", "")
|
|
||||||
)
|
|
||||||
except (json.JSONDecodeError, OSError) as e:
|
|
||||||
print(f"警告: 加载插件 {skill_name} 失败: {e}")
|
|
||||||
|
|
|
||||||
|
|
@ -33,6 +33,45 @@ class ProviderManager:
|
||||||
self.provider_register = {}
|
self.provider_register = {}
|
||||||
|
|
||||||
async def init_provider_register(self, postgres) -> None:
|
async def init_provider_register(self, postgres) -> None:
|
||||||
providers = await postgres.get_providers.remote()
|
providers = await postgres.provider_database.remote("get_provider")
|
||||||
for provider in providers:
|
for provider in providers:
|
||||||
self.provider_register[provider.provider_title] = provider
|
self.provider_register[provider.provider_title] = provider
|
||||||
|
|
||||||
|
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
|
||||||
|
from pretor.core.global_state_machine.model_provider import ProviderArgs
|
||||||
|
from loguru import logger
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||||
|
provider_url=provider_url,
|
||||||
|
provider_apikey=provider_apikey,
|
||||||
|
provider_owner=provider_owner)
|
||||||
|
try:
|
||||||
|
provider_class = self.provider_mapper.get(provider_type, None)
|
||||||
|
if provider_class is None:
|
||||||
|
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||||
|
return None
|
||||||
|
provider: Provider = await provider_class.create_model(provider_args)
|
||||||
|
|
||||||
|
provider.provider_owner = provider_owner
|
||||||
|
|
||||||
|
self.provider_register[provider_title] = provider
|
||||||
|
|
||||||
|
await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title,
|
||||||
|
provider_url=provider.provider_url,
|
||||||
|
provider_apikey=provider.provider_apikey,
|
||||||
|
provider_models=provider.provider_models,
|
||||||
|
provider_type=provider.provider_type,
|
||||||
|
provider_owner=provider.provider_owner)
|
||||||
|
|
||||||
|
logger.info(f"已添加适配器{provider_title}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
|
|
||||||
|
def get_provider_list(self):
|
||||||
|
return self.provider_register
|
||||||
|
|
||||||
|
def get_provider(self, provider_title):
|
||||||
|
return self.provider_register.get(provider_title)
|
||||||
|
|
@ -55,7 +55,7 @@ class ConsciousnessNode:
|
||||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||||
)
|
)
|
||||||
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
||||||
provider: Provider = global_state_machine.get_provider.remote(provider_title)
|
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
@ -85,7 +85,7 @@ class ConsciousnessNode:
|
||||||
else:
|
else:
|
||||||
logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
|
logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
|
||||||
return None
|
return None
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("ConsciousnessNode在执行working时发生严重错误")
|
logger.exception("ConsciousnessNode在执行working时发生严重错误")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,7 @@ class ControlNode:
|
||||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||||
)
|
)
|
||||||
output_type = ForWorkflow
|
output_type = ForWorkflow
|
||||||
provider: Provider = global_state_machine.get_provider.remote(provider_title)
|
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
|
||||||
|
|
@ -54,7 +54,7 @@ class SupervisoryNode:
|
||||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||||
)
|
)
|
||||||
output_type = Union[ForConsciousnessNode, ForUser]
|
output_type = Union[ForConsciousnessNode, ForUser]
|
||||||
provider: Provider = await global_state_machine.get_provider.remote(provider_title)
|
provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||||
agent_factory = AgentFactory()
|
agent_factory = AgentFactory()
|
||||||
self.agent = agent_factory.create_agent(provider=provider,
|
self.agent = agent_factory.create_agent(provider=provider,
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
|
@ -155,7 +155,7 @@ class SupervisoryNode:
|
||||||
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
try:
|
try:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
workflow_template_dict = await global_state_machine.get_workflow_template_list.remote()
|
workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list")
|
||||||
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
||||||
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
||||||
deps = SupervisoryNodeDeps(
|
deps = SupervisoryNodeDeps(
|
||||||
|
|
|
||||||
|
|
@ -158,7 +158,7 @@ class WorkflowEngine:
|
||||||
logger.info(f"Supervisory 最终回复:{user_response}")
|
logger.info(f"Supervisory 最终回复:{user_response}")
|
||||||
else:
|
else:
|
||||||
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("生成工作流执行汇报时发生错误")
|
logger.exception("生成工作流执行汇报时发生错误")
|
||||||
|
|
||||||
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
|
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
|
||||||
|
|
@ -207,7 +207,7 @@ class WorkflowEngine:
|
||||||
else:
|
else:
|
||||||
raise WorkflowError(f"未知的节点类型:{step.node}")
|
raise WorkflowError(f"未知的节点类型:{step.node}")
|
||||||
|
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
|
logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
|
||||||
return None, False
|
return None, False
|
||||||
|
|
||||||
|
|
@ -245,7 +245,7 @@ class WorkflowRunningEngine:
|
||||||
self.consciousness_node = consciousness_node
|
self.consciousness_node = consciousness_node
|
||||||
self.control_node = control_node
|
self.control_node = control_node
|
||||||
self.supervisory_node = supervisory_node
|
self.supervisory_node = supervisory_node
|
||||||
self.global_state_machine = ray_actor_hook("global_state_machine")
|
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||||
|
|
||||||
async def run(self):
|
async def run(self):
|
||||||
self.runner_engine = {
|
self.runner_engine = {
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,5 @@ class WorkflowManager:
|
||||||
try:
|
try:
|
||||||
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
|
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
|
||||||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||||
except Exception as e:
|
except Exception:
|
||||||
logger.exception("Failed to generate workflow template")
|
logger.exception("Failed to generate workflow template")
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ from pretor.utils.ray_hook import ray_actor_hook
|
||||||
@lru_cache
|
@lru_cache
|
||||||
async def get_authority(user_id: str) -> UserAuthority:
|
async def get_authority(user_id: str) -> UserAuthority:
|
||||||
postgres_database = ray_actor_hook("postgres_database")
|
postgres_database = ray_actor_hook("postgres_database")
|
||||||
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
|
user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id)
|
||||||
return user_authority
|
return user_authority
|
||||||
|
|
||||||
class RoleChecker:
|
class RoleChecker:
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,7 @@ def del_tool_cache(tool_name: str) -> None:
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
async def get_tool(agent_name: str) -> List[Callable]:
|
async def get_tool(agent_name: str) -> List[Callable]:
|
||||||
global_state_machine = ray_actor_hook("global_state_machine")
|
global_state_machine = ray_actor_hook("global_state_machine")
|
||||||
_tool_list = await global_state_machine.get_tool_list.remote(agent_name)
|
_tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name)
|
||||||
tool_list = []
|
tool_list = []
|
||||||
for tool_name in _tool_list.keys():
|
for tool_name in _tool_list.keys():
|
||||||
tool_func = _get_tool_func(tool_name)
|
tool_func = _get_tool_func(tool_name)
|
||||||
|
|
|
||||||
|
|
@ -65,6 +65,8 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
|
||||||
)
|
)
|
||||||
mock_auth_db.assert_called_once()
|
mock_auth_db.assert_called_once()
|
||||||
mock_provider_db.assert_called_once()
|
mock_provider_db.assert_called_once()
|
||||||
|
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
|
||||||
|
assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"
|
||||||
|
|
||||||
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
||||||
await db.init_db()
|
await db.init_db()
|
||||||
|
|
|
||||||
|
|
@ -2,6 +2,6 @@ import pytest
|
||||||
from pretor.core.database.table.user import User
|
from pretor.core.database.table.user import User
|
||||||
|
|
||||||
def test_user_table():
|
def test_user_table():
|
||||||
user = User(user_name="name", hashed_password="pw")
|
user = User(user_id="id", user_name="name", hashed_password="pw")
|
||||||
assert User.__tablename__ == 'user'
|
assert User.__tablename__ == 'user'
|
||||||
assert user.user_name == "name"
|
assert user.user_name == "name"
|
||||||
|
|
|
||||||
|
|
@ -40,7 +40,6 @@ def mock_postgres():
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def gsm(mock_postgres):
|
def gsm(mock_postgres):
|
||||||
with patch("pretor.core.global_state_machine.global_state_machine.ProviderManager") as mock_pm:
|
|
||||||
manager = GlobalStateMachine(mock_postgres)
|
manager = GlobalStateMachine(mock_postgres)
|
||||||
return manager
|
return manager
|
||||||
|
|
||||||
|
|
@ -103,24 +102,24 @@ async def test_add_provider_success(gsm, mock_postgres):
|
||||||
mock_provider.provider_type = "openai"
|
mock_provider.provider_type = "openai"
|
||||||
mock_provider_class.create_model.return_value = mock_provider
|
mock_provider_class.create_model.return_value = mock_provider
|
||||||
|
|
||||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
gsm.global_provider_manager.provider_register = {}
|
gsm._global_provider_manager.provider_register = {}
|
||||||
|
|
||||||
mock_add_provider = AsyncMock()
|
mock_add_provider = AsyncMock()
|
||||||
mock_postgres.provider_database.add_provider.remote = mock_add_provider
|
mock_postgres.provider_database.remote = mock_add_provider
|
||||||
|
|
||||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||||
|
|
||||||
assert gsm.global_provider_manager.provider_register["title"] == mock_provider
|
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
||||||
mock_add_provider.assert_called_once()
|
mock_add_provider.assert_called_once()
|
||||||
assert mock_provider.provider_owner == 1
|
assert mock_provider.provider_owner == 1
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_provider_unsupported(gsm):
|
async def test_add_provider_unsupported(gsm):
|
||||||
gsm.global_provider_manager.provider_mapper = {}
|
gsm._global_provider_manager.provider_mapper = {}
|
||||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
with patch("loguru.logger") as mock_logger:
|
||||||
await gsm.add_provider("magic", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("magic", "title", "url", "key", 1)
|
||||||
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -129,10 +128,10 @@ async def test_add_provider_request_error(gsm):
|
||||||
from httpx import RequestError
|
from httpx import RequestError
|
||||||
mock_provider_class = AsyncMock()
|
mock_provider_class = AsyncMock()
|
||||||
mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
|
mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
|
||||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
|
|
||||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
with patch("loguru.logger") as mock_logger:
|
||||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
assert "网络请求异常" in mock_logger.warning.call_args[0][0]
|
assert "网络请求异常" in mock_logger.warning.call_args[0][0]
|
||||||
|
|
||||||
|
|
@ -141,18 +140,18 @@ async def test_add_provider_request_error(gsm):
|
||||||
async def test_add_provider_generic_error(gsm):
|
async def test_add_provider_generic_error(gsm):
|
||||||
mock_provider_class = AsyncMock()
|
mock_provider_class = AsyncMock()
|
||||||
mock_provider_class.create_model.side_effect = ValueError("Some Error")
|
mock_provider_class.create_model.side_effect = ValueError("Some Error")
|
||||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||||
|
|
||||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
with patch("loguru.logger") as mock_logger:
|
||||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||||
mock_logger.warning.assert_called_once()
|
mock_logger.warning.assert_called_once()
|
||||||
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
||||||
|
|
||||||
|
|
||||||
def test_get_provider_list_and_get_provider(gsm):
|
def test_get_provider_list_and_get_provider(gsm):
|
||||||
mock_provider = MagicMock()
|
mock_provider = MagicMock()
|
||||||
gsm.global_provider_manager.provider_register = {"p1": mock_provider}
|
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
|
||||||
|
|
||||||
assert gsm.get_provider_list() == {"p1": mock_provider}
|
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
|
||||||
assert gsm.get_provider("p1") == mock_provider
|
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
|
||||||
assert gsm.get_provider("missing") is None
|
assert gsm._global_provider_manager.get_provider("missing") is None
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,8 @@ async def test_provider_manager_init():
|
||||||
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||||
|
|
||||||
manager = ProviderManager(mock_postgres)
|
manager = ProviderManager(mock_postgres)
|
||||||
|
mock_postgres.provider_database = MagicMock()
|
||||||
|
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||||
await manager.init_provider_register(mock_postgres)
|
await manager.init_provider_register(mock_postgres)
|
||||||
|
|
||||||
assert "openai" in manager.provider_mapper
|
assert "openai" in manager.provider_mapper
|
||||||
|
|
|
||||||
|
|
@ -119,20 +119,17 @@ async def test_workflow_running_engine_runner():
|
||||||
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
|
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
|
||||||
|
|
||||||
mock_consciousness = MagicMock()
|
mock_consciousness = MagicMock()
|
||||||
|
|
||||||
mock_wf = MagicMock()
|
mock_wf = MagicMock()
|
||||||
mock_wf.trace_id = "test_trace"
|
mock_wf.trace_id = "test_trace"
|
||||||
mock_wf.title = "test_title"
|
mock_wf.title = "test_title"
|
||||||
|
|
||||||
mock_result = MagicMock(spec=ForWorkflowEngine)
|
mock_result = MagicMock(spec=ForWorkflowEngine)
|
||||||
mock_result.workflow = mock_wf
|
mock_result.workflow = mock_wf
|
||||||
|
|
||||||
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
|
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
|
||||||
|
|
||||||
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
|
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
|
||||||
engine.workflow_queue = asyncio.Queue()
|
engine.workflow_queue = asyncio.Queue()
|
||||||
|
|
||||||
# Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties
|
|
||||||
mock_event = PretorEvent(
|
mock_event = PretorEvent(
|
||||||
platform="test_platform",
|
platform="test_platform",
|
||||||
user_id="test_user",
|
user_id="test_user",
|
||||||
|
|
@ -142,18 +139,19 @@ async def test_workflow_running_engine_runner():
|
||||||
)
|
)
|
||||||
await engine.workflow_queue.put(mock_event)
|
await engine.workflow_queue.put(mock_event)
|
||||||
|
|
||||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \
|
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open:
|
||||||
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
|
|
||||||
mock_gsm = MagicMock()
|
# Instead of patching hook, we inject it directly
|
||||||
mock_gsm.update_workflow.remote = AsyncMock()
|
engine.global_state_machine = AsyncMock()
|
||||||
mock_hook.return_value = mock_gsm
|
|
||||||
|
mock_open.return_value.__enter__.return_value.read.return_value = '{}'
|
||||||
|
|
||||||
mock_engine_instance = MagicMock()
|
mock_engine_instance = MagicMock()
|
||||||
mock_engine_instance.run = AsyncMock()
|
mock_engine_instance.run = AsyncMock()
|
||||||
mock_wf_engine_cls.return_value = mock_engine_instance
|
mock_wf_engine_cls.return_value = mock_engine_instance
|
||||||
|
|
||||||
task = asyncio.create_task(engine.runner(1))
|
task = asyncio.create_task(engine.runner(1))
|
||||||
await asyncio.sleep(0.05) # Give runner time to process one item
|
await asyncio.sleep(0.05)
|
||||||
task.cancel() # Stop the infinite loop
|
task.cancel()
|
||||||
|
|
||||||
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
|
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
|
||||||
|
|
@ -24,7 +24,7 @@ def test_pretor_workflow_validation_success():
|
||||||
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1")
|
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1")
|
||||||
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
||||||
wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
|
wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
|
||||||
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||||
assert wf.title == "wf1"
|
assert wf.title == "wf1"
|
||||||
|
|
||||||
def test_pretor_workflow_validation_error_step_discontinuous():
|
def test_pretor_workflow_validation_error_step_discontinuous():
|
||||||
|
|
@ -32,7 +32,7 @@ def test_pretor_workflow_validation_error_step_discontinuous():
|
||||||
ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2")
|
ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2")
|
||||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||||
with pytest.raises(ValueError, match="工作链步数不连续"):
|
with pytest.raises(ValueError, match="工作链步数不连续"):
|
||||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||||
|
|
||||||
def test_pretor_workflow_validation_error_jump_out_of_bounds():
|
def test_pretor_workflow_validation_error_jump_out_of_bounds():
|
||||||
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
|
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
|
||||||
|
|
@ -40,14 +40,14 @@ def test_pretor_workflow_validation_error_jump_out_of_bounds():
|
||||||
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
||||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||||
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
|
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
|
||||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||||
|
|
||||||
def test_pretor_workflow_validation_error_jump_format_error():
|
def test_pretor_workflow_validation_error_jump_format_error():
|
||||||
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
|
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
|
||||||
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg)
|
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg)
|
||||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||||
with pytest.raises(ValueError, match="LogicGate 格式错误"):
|
with pytest.raises(ValueError, match="LogicGate 格式错误"):
|
||||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "username":"b"})
|
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||||
|
|
||||||
def test_workflow_status():
|
def test_workflow_status():
|
||||||
status = WorkflowStatus()
|
status = WorkflowStatus()
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue