wip: 增加了api的简单权限校验,使用反射对于actor进行重构
This commit is contained in:
parent
81da2e9f81
commit
be0d78c855
|
|
@ -21,6 +21,8 @@ from pretor.utils.access import Accessor, TokenData
|
|||
from pretor.core.database.table.individual import AgentType
|
||||
from fastapi import HTTPException
|
||||
from typing import Optional, List, Dict
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
|
||||
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
||||
|
||||
|
|
@ -35,7 +37,7 @@ class AgentLocalRegister(BaseModel):
|
|||
|
||||
@agent_router.post("")
|
||||
async def load_agent(agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
if isinstance(agent_register, AgentLocalRegister):
|
||||
pass
|
||||
|
|
@ -82,18 +84,18 @@ class WorkerIndividualUpdate(BaseModel):
|
|||
|
||||
@agent_router.post("/worker")
|
||||
async def create_worker_individual(worker_data: WorkerIndividualCreate,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
data_dict = worker_data.model_dump()
|
||||
data_dict["owner_id"] = token_data.user_id
|
||||
worker = await postgres_database.add_worker_individual.remote(**data_dict)
|
||||
worker = await postgres_database.individual_database.remote("add_worker_individual", **data_dict)
|
||||
return {"message": "success", "agent_id": worker.agent_id}
|
||||
|
||||
|
||||
@agent_router.get("/worker")
|
||||
async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
workers = await postgres_database.get_worker_individual_list.remote(owner_id=token_data.user_id)
|
||||
workers = await postgres_database.individual_database.remote("get_worker_individual_list", owner_id=token_data.user_id)
|
||||
return {"workers": workers}
|
||||
|
||||
|
||||
|
|
@ -101,7 +103,7 @@ async def get_worker_individual_list(token_data: TokenData = Depends(Accessor.ge
|
|||
async def get_worker_individual(agent_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
|
|
@ -114,14 +116,14 @@ async def update_worker_individual(agent_id: str,
|
|||
worker_data: WorkerIndividualUpdate,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
|
||||
update_data = worker_data.model_dump(exclude_unset=True)
|
||||
updated_worker = await postgres_database.update_worker_individual.remote(agent_id=agent_id, **update_data)
|
||||
updated_worker = await postgres_database.individual_database.remote("update_worker_individual", agent_id=agent_id, **update_data)
|
||||
return {"message": "success", "worker": updated_worker}
|
||||
|
||||
|
||||
|
|
@ -129,11 +131,10 @@ async def update_worker_individual(agent_id: str,
|
|||
async def delete_worker_individual(agent_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
worker = await postgres_database.get_worker_individual.remote(agent_id=agent_id)
|
||||
worker = await postgres_database.individual_database.remote("get_worker_individual", agent_id=agent_id)
|
||||
if not worker:
|
||||
raise HTTPException(status_code=404, detail="Agent not found")
|
||||
if worker.owner_id != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden: You do not own this agent")
|
||||
|
||||
await postgres_database.delete_worker_individual.remote(agent_id=agent_id)
|
||||
await postgres_database.individual_database.remote("delete_worker_individual", agent_id=agent_id)
|
||||
return {"message": "success"}
|
||||
|
|
@ -28,7 +28,7 @@ class UserRegister(BaseModel):
|
|||
async def create_user(user_register: UserRegister):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
hashed_password = await run_in_threadpool(Accessor.hash_password, user_register.password)
|
||||
user = await postgres_database.add_user.remote(user_register.user_name, hashed_password)
|
||||
user = await postgres_database.auth_database.remote("add_user", user_register.user_name, hashed_password)
|
||||
return {"message": "success", "user_id": user.user_id}
|
||||
|
||||
class UserLogin(BaseModel):
|
||||
|
|
@ -38,7 +38,7 @@ class UserLogin(BaseModel):
|
|||
@auth_router.post("/login")
|
||||
async def login_user(user_login: UserLogin):
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
user = await postgres_database.login_user.remote(user_login.user_name)
|
||||
user = await postgres_database.auth_database.remote("login_user", user_login.user_name)
|
||||
if user.user_name != user_login.user_name:
|
||||
pass
|
||||
token = await run_in_threadpool(Accessor.login_hashed_password, user, user_login.password)
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ from fastapi import APIRouter, Depends
|
|||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
from pretor.utils.access import TokenData, Accessor
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
from typing import Dict
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
|
|
@ -30,17 +32,17 @@ class ProviderRegister(BaseModel):
|
|||
|
||||
@provider_router.post("")
|
||||
async def create_provider(provider_register: ProviderRegister,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user)) -> None:
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))) -> None:
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await global_state_machine.add_provider.remote(provider_type=provider_register.provider_type,
|
||||
provider_title=provider_register.provider_title,
|
||||
provider_url=provider_register.provider_url,
|
||||
provider_apikey=provider_register.provider_apikey,
|
||||
provider_owner=token_data.user_id)
|
||||
await global_state_machine.add_provider_wrap.remote(provider_type=provider_register.provider_type,
|
||||
provider_title=provider_register.provider_title,
|
||||
provider_url=provider_register.provider_url,
|
||||
provider_apikey=provider_register.provider_apikey,
|
||||
provider_owner=token_data.user_id)
|
||||
|
||||
|
||||
@provider_router.get("/list")
|
||||
async def get_provider_list(_: TokenData = Depends(Accessor.get_current_user)) -> Dict[str, Provider]:
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
provider_list: Dict[str, Provider] = await global_state_machine.get_provider_list.remote()
|
||||
provider_list: Dict[str, Provider] = await global_state_machine.provider_manager.remote("get_provider_list")
|
||||
return {"provider_list": provider_list}
|
||||
|
|
@ -17,13 +17,15 @@ import viceroy
|
|||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from pretor.utils.access import TokenData, Accessor
|
||||
from pretor.utils.access import TokenData
|
||||
from pretor.utils.check_user.role_check import RoleChecker
|
||||
from pretor.core.database.table.user import UserAuthority
|
||||
|
||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||
|
||||
@resource_router.post("/workflow_template")
|
||||
async def create_workflow_template(workflow_template: WorkflowTemplate,
|
||||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
await global_state_machine.workflow_template_generate.remote(workflow_template)
|
||||
return {"message": "创建成功"}
|
||||
|
|
@ -36,7 +38,7 @@ class Skill(BaseModel):
|
|||
|
||||
@resource_router.post("/skill")
|
||||
async def install_skill(skill: Skill,
|
||||
_: TokenData = Depends(Accessor.get_current_user)):
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))):
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
# noinspection PyUnresolvedReferences
|
||||
await viceroy.install_skill_async(url = skill.repo_url,
|
||||
|
|
@ -46,6 +48,6 @@ async def install_skill(skill: Skill,
|
|||
skill_name = skill.path.split("/")[-1]
|
||||
else:
|
||||
skill_name = skill.repo_url.split("/")[-1]
|
||||
await global_state_machine.add_skill.remote(skill_name)
|
||||
await global_state_machine.skill_manager.remote("add_skill", skill_name)
|
||||
return {"message": "创建成功"}
|
||||
|
||||
|
|
|
|||
|
|
@ -35,49 +35,22 @@ class PostgresDatabase:
|
|||
self.async_engine = create_async_engine(database_url, echo=True)
|
||||
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||
|
||||
self.auth_database = AuthDatabase(self.async_session_maker)
|
||||
self.provider_database = ProviderDatabase(self.async_session_maker)
|
||||
self.individual_database = IndividualDatabase(self.async_session_maker)
|
||||
self._auth_database = AuthDatabase(self.async_session_maker)
|
||||
self._provider_database = ProviderDatabase(self.async_session_maker)
|
||||
self._individual_database = IndividualDatabase(self.async_session_maker)
|
||||
|
||||
async def init_db(self) -> None:
|
||||
async with self.async_engine.begin() as conn:
|
||||
await conn.run_sync(SQLModel.metadata.create_all)
|
||||
|
||||
# provider_database操作
|
||||
async def get_providers(self):
|
||||
return await self.provider_database.get_provider()
|
||||
async def auth_database(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._auth_database, method_name)
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
async def add_provider(self, **kwargs):
|
||||
return await self.provider_database.add_provider(**kwargs)
|
||||
async def provider_database(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._provider_database, method_name)
|
||||
return await method(*args, **kwargs)
|
||||
|
||||
# auth_database操作
|
||||
async def add_user(self, **kwargs):
|
||||
return await self.auth_database.add_user(**kwargs)
|
||||
|
||||
async def change_password(self, **kwargs):
|
||||
return await self.auth_database.change_password(**kwargs)
|
||||
|
||||
async def delete_user(self, **kwargs):
|
||||
return await self.auth_database.delete_user(**kwargs)
|
||||
|
||||
async def login_user(self, **kwargs):
|
||||
return await self.auth_database.login_user(**kwargs)
|
||||
|
||||
async def get_user_authority(self, **kwargs):
|
||||
return await self.auth_database.get_user_authority(**kwargs)
|
||||
|
||||
##individual_database 操作
|
||||
async def add_worker_individual(self, **kwargs):
|
||||
return await self.individual_database.add_worker_individual(**kwargs)
|
||||
|
||||
async def get_worker_individual(self, agent_id: str):
|
||||
return await self.individual_database.get_worker_individual(agent_id)
|
||||
|
||||
async def get_worker_individual_list(self, owner_id: str):
|
||||
return await self.individual_database.get_worker_individual_list(owner_id)
|
||||
|
||||
async def update_worker_individual(self, agent_id: str, **kwargs):
|
||||
return await self.individual_database.update_worker_individual(agent_id, **kwargs)
|
||||
|
||||
async def delete_worker_individual(self, agent_id: str):
|
||||
return await self.individual_database.delete_worker_individual(agent_id)
|
||||
async def individual_database(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._individual_database, method_name)
|
||||
return await method(*args, **kwargs)
|
||||
|
|
@ -15,13 +15,13 @@
|
|||
from sqlmodel import SQLModel, Field
|
||||
from typing import List
|
||||
from sqlalchemy import Column, JSON
|
||||
from typing import Optional, Literal
|
||||
from typing import Optional
|
||||
|
||||
class Provider(SQLModel, table=True):
|
||||
__tablename__ = "provider"
|
||||
provider_id: str = Field(primary_key=True)
|
||||
provider_title: str = Field(index=True)
|
||||
provider_type: Literal["openai", "vllm"]
|
||||
provider_type: str
|
||||
|
||||
provider_url: Optional[str]
|
||||
provider_apikey: Optional[str]
|
||||
|
|
|
|||
|
|
@ -15,37 +15,63 @@
|
|||
import ray
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from pretor.core.global_state_machine.model_provider import Provider, ProviderArgs
|
||||
import httpx
|
||||
import pathlib
|
||||
import json
|
||||
from loguru import logger
|
||||
from typing import Dict, Literal, List
|
||||
from typing import Dict
|
||||
from pretor.core.database.postgres import PostgresDatabase
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
import asyncio
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
|
||||
from pretor.core.workflow.workflow_template_manager import WorkflowManager
|
||||
from pretor.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from pretor.plugin.tool_plugin.base_tool import BaseToolData
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
def __init__(self, postgres_database: PostgresDatabase):
|
||||
|
||||
self.event_dict: Dict[int, PretorEvent] = {}
|
||||
self.global_provider_manager = ProviderManager(postgres_database)
|
||||
self.global_tool_manager = GlobalToolManager()
|
||||
self.global_workflow_template_manager = WorkflowManager()
|
||||
self.global_skill_manager = GlobalSkillManager()
|
||||
self.event_dict: Dict[str, PretorEvent] = {}
|
||||
|
||||
self._global_provider_manager = ProviderManager(postgres_database)
|
||||
self._global_tool_manager = GlobalToolManager()
|
||||
self._global_workflow_template_manager = WorkflowManager()
|
||||
self._global_skill_manager = GlobalSkillManager()
|
||||
|
||||
self.postgres_database = postgres_database
|
||||
|
||||
async def init_state_machine(self):
|
||||
await self.global_provider_manager.init_provider_register(self.postgres_database)
|
||||
await self._global_provider_manager.init_provider_register(self.postgres_database)
|
||||
|
||||
async def add_provider_wrap(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner):
|
||||
return await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database
|
||||
)
|
||||
|
||||
async def provider_manager(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._global_provider_manager, method_name)
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
return await method(*args, **kwargs)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
async def tool_manager(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._global_tool_manager, method_name)
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
return await method(*args, **kwargs)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
async def workflow_template_manager(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._global_workflow_template_manager, method_name)
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
return await method(*args, **kwargs)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
async def skill_manager(self, method_name: str, *args, **kwargs):
|
||||
method = getattr(self._global_skill_manager, method_name)
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
return await method(*args, **kwargs)
|
||||
return method(*args, **kwargs)
|
||||
|
||||
###以下方法为event_dict方法
|
||||
def add_event(self, event: PretorEvent) -> None:
|
||||
|
|
@ -79,108 +105,3 @@ class GlobalStateMachine:
|
|||
|
||||
async def get_received(self, event_id) -> str:
|
||||
return await self.event_dict[event_id].receive_queue.get()
|
||||
|
||||
|
||||
###以下方法为global_provider_manager方法
|
||||
async def add_provider(self, provider_type: Literal["openai", "gemini", "claude"],
|
||||
provider_title: str,
|
||||
provider_url: str,
|
||||
provider_apikey: str,
|
||||
provider_owner: int) -> None:
|
||||
"""
|
||||
add_provider方法,注册供应商适配器(provider_manager方法)
|
||||
|
||||
Args
|
||||
provider_type: 注册商接口类型,目前只支持openai,gemini和claude接口
|
||||
provider_title: 供应商名称,为供应商提供的别名
|
||||
provider_url: 供应商url
|
||||
provider_apikey: 供应商所需要的apikey
|
||||
Returns:
|
||||
|
||||
"""
|
||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner)
|
||||
try:
|
||||
provider_class = self.global_provider_manager.provider_mapper.get(provider_type, None)
|
||||
if provider_class is None:
|
||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||
return None
|
||||
provider: Provider = await provider_class.create_model(provider_args)
|
||||
|
||||
provider.provider_owner = provider_owner
|
||||
|
||||
self.global_provider_manager.provider_register[provider_title] = provider
|
||||
|
||||
await self.postgres_database.add_provider.remote(provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
|
||||
def get_provider_list(self) -> Dict[str, Provider]:
|
||||
"""
|
||||
get_provider_list方法,获取注册表(provider_manager方法)
|
||||
Returns:
|
||||
返回provider_register属性字典
|
||||
"""
|
||||
return self.global_provider_manager.provider_register
|
||||
|
||||
def get_provider(self, provider_title) -> Provider:
|
||||
"""
|
||||
get_provider方法,获取供应商信息(provider_manager方法)
|
||||
Args:
|
||||
provider_title:provider名称
|
||||
|
||||
Returns:
|
||||
Provider对象,返回注册在self.global_provider_manager.provider_register的供应商
|
||||
"""
|
||||
provider = self.global_provider_manager.provider_register.get(provider_title)
|
||||
return provider
|
||||
|
||||
|
||||
###以下为global_tool_manager方法
|
||||
def get_tool_list(self, agent_name: str) -> Dict[str, BaseToolData]:
|
||||
"""
|
||||
获取工具表方法
|
||||
Args:
|
||||
agent_name: agent的名字
|
||||
|
||||
Returns:
|
||||
返回该agent的tool,类型为dict
|
||||
"""
|
||||
tool_list = self.global_tool_manager.tool_mapper.get(agent_name, {})
|
||||
return tool_list
|
||||
|
||||
|
||||
###以下为workflow_template_manager方法
|
||||
def workflow_template_generate(self, workflow_template: WorkflowTemplate) -> None:
|
||||
self.global_workflow_template_manager.generate_workflow_template(workflow_template)
|
||||
|
||||
def get_workflow_template_list(self) -> List[Dict[str, str]]:
|
||||
return self.global_workflow_template_manager.workflow_templates_registry
|
||||
|
||||
|
||||
###以下为skill_manager方法
|
||||
def add_skill(self, skill_name: str):
|
||||
skill_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "plugin" / "skill_plugin" / skill_name
|
||||
json_path = skill_plugin_dir / "skill.json"
|
||||
try:
|
||||
with open(json_path, "r", encoding="utf-8") as f:
|
||||
skill = json.load(f)
|
||||
name = skill.get("name")
|
||||
if name:
|
||||
self.global_skill_manager.skill_mapper[name] = (
|
||||
skill.get("description", ""),
|
||||
skill.get("instructions", "")
|
||||
)
|
||||
except (json.JSONDecodeError, OSError) as e:
|
||||
print(f"警告: 加载插件 {skill_name} 失败: {e}")
|
||||
|
|
|
|||
|
|
@ -33,6 +33,45 @@ class ProviderManager:
|
|||
self.provider_register = {}
|
||||
|
||||
async def init_provider_register(self, postgres) -> None:
|
||||
providers = await postgres.get_providers.remote()
|
||||
providers = await postgres.provider_database.remote("get_provider")
|
||||
for provider in providers:
|
||||
self.provider_register[provider.provider_title] = provider
|
||||
|
||||
async def add_provider(self, provider_type, provider_title, provider_url, provider_apikey, provider_owner, postgres_database) -> None:
|
||||
from pretor.core.global_state_machine.model_provider import ProviderArgs
|
||||
from loguru import logger
|
||||
import httpx
|
||||
|
||||
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
provider_apikey=provider_apikey,
|
||||
provider_owner=provider_owner)
|
||||
try:
|
||||
provider_class = self.provider_mapper.get(provider_type, None)
|
||||
if provider_class is None:
|
||||
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||
return None
|
||||
provider: Provider = await provider_class.create_model(provider_args)
|
||||
|
||||
provider.provider_owner = provider_owner
|
||||
|
||||
self.provider_register[provider_title] = provider
|
||||
|
||||
await postgres_database.provider_database.remote("add_provider", provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner)
|
||||
|
||||
logger.info(f"已添加适配器{provider_title}")
|
||||
except httpx.RequestError as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
|
||||
def get_provider_list(self):
|
||||
return self.provider_register
|
||||
|
||||
def get_provider(self, provider_title):
|
||||
return self.provider_register.get(provider_title)
|
||||
|
|
@ -55,7 +55,7 @@ class ConsciousnessNode:
|
|||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||
)
|
||||
output_type = Union[ForSupervisoryNode, ForWorkflow, ForWorkflowEngine]
|
||||
provider: Provider = global_state_machine.get_provider.remote(provider_title)
|
||||
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
|
|
@ -85,7 +85,7 @@ class ConsciousnessNode:
|
|||
else:
|
||||
logger.error(f"ConsciousnessNode: 未知或不匹配的返回类型: {type(result)}")
|
||||
return None
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("ConsciousnessNode在执行working时发生严重错误")
|
||||
return None
|
||||
|
||||
|
|
|
|||
|
|
@ -51,7 +51,7 @@ class ControlNode:
|
|||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||
)
|
||||
output_type = ForWorkflow
|
||||
provider: Provider = global_state_machine.get_provider.remote(provider_title)
|
||||
provider: Provider = global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
|
|
|
|||
|
|
@ -54,7 +54,7 @@ class SupervisoryNode:
|
|||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||
)
|
||||
output_type = Union[ForConsciousnessNode, ForUser]
|
||||
provider: Provider = await global_state_machine.get_provider.remote(provider_title)
|
||||
provider: Provider = await global_state_machine.provider_manager.remote("get_provider", provider_title)
|
||||
agent_factory = AgentFactory()
|
||||
self.agent = agent_factory.create_agent(provider=provider,
|
||||
model_id=model_id,
|
||||
|
|
@ -155,7 +155,7 @@ class SupervisoryNode:
|
|||
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||
try:
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
workflow_template_dict = await global_state_machine.get_workflow_template_list.remote()
|
||||
workflow_template_dict = await global_state_machine.workflow_template_manager.remote("get_workflow_template_list")
|
||||
available_templates_str = "\n".join([f"- 名称: {k}, 描述/内容: {v}" for k, v in
|
||||
workflow_template_dict.items()]) if workflow_template_dict else "暂无注册的工作流模板"
|
||||
deps = SupervisoryNodeDeps(
|
||||
|
|
|
|||
|
|
@ -158,7 +158,7 @@ class WorkflowEngine:
|
|||
logger.info(f"Supervisory 最终回复:{user_response}")
|
||||
else:
|
||||
logger.warning("未提供 supervisory_node 句柄,跳过用户反馈生成。")
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("生成工作流执行汇报时发生错误")
|
||||
|
||||
async def _dispatch_to_node(self, step: WorkStep, input_data: Any) -> tuple[Any, bool]:
|
||||
|
|
@ -207,7 +207,7 @@ class WorkflowEngine:
|
|||
else:
|
||||
raise WorkflowError(f"未知的节点类型:{step.node}")
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception(f"节点 {step.node} 执行动作 {step.action} 失败")
|
||||
return None, False
|
||||
|
||||
|
|
@ -245,7 +245,7 @@ class WorkflowRunningEngine:
|
|||
self.consciousness_node = consciousness_node
|
||||
self.control_node = control_node
|
||||
self.supervisory_node = supervisory_node
|
||||
self.global_state_machine = ray_actor_hook("global_state_machine")
|
||||
self.global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
|
||||
async def run(self):
|
||||
self.runner_engine = {
|
||||
|
|
|
|||
|
|
@ -40,6 +40,5 @@ class WorkflowManager:
|
|||
try:
|
||||
workflow_template = self.workflow_template_generator.generate_workflow_template(workflow_template=workflow_template)
|
||||
self.workflow_templates_registry[workflow_template.name] = workflow_template.desc
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
logger.exception("Failed to generate workflow template")
|
||||
|
||||
|
|
|
|||
|
|
@ -21,7 +21,7 @@ from pretor.utils.ray_hook import ray_actor_hook
|
|||
@lru_cache
|
||||
async def get_authority(user_id: str) -> UserAuthority:
|
||||
postgres_database = ray_actor_hook("postgres_database")
|
||||
user_authority = await postgres_database.get_user_authority.remote(user_id=user_id)
|
||||
user_authority = await postgres_database.auth_database.remote("get_user_authority", user_id=user_id)
|
||||
return user_authority
|
||||
|
||||
class RoleChecker:
|
||||
|
|
|
|||
|
|
@ -49,7 +49,7 @@ def del_tool_cache(tool_name: str) -> None:
|
|||
@lru_cache(maxsize=1)
|
||||
async def get_tool(agent_name: str) -> List[Callable]:
|
||||
global_state_machine = ray_actor_hook("global_state_machine")
|
||||
_tool_list = await global_state_machine.get_tool_list.remote(agent_name)
|
||||
_tool_list = await global_state_machine.tool_manager.remote("get_tool_list", agent_name)
|
||||
tool_list = []
|
||||
for tool_name in _tool_list.keys():
|
||||
tool_func = _get_tool_func(tool_name)
|
||||
|
|
|
|||
|
|
@ -65,6 +65,8 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
|
|||
)
|
||||
mock_auth_db.assert_called_once()
|
||||
mock_provider_db.assert_called_once()
|
||||
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
|
||||
assert await db.auth_database("get_user_authority", user_id="123") == "test_auth"
|
||||
|
||||
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
||||
await db.init_db()
|
||||
|
|
|
|||
|
|
@ -2,6 +2,6 @@ import pytest
|
|||
from pretor.core.database.table.user import User
|
||||
|
||||
def test_user_table():
|
||||
user = User(user_name="name", hashed_password="pw")
|
||||
user = User(user_id="id", user_name="name", hashed_password="pw")
|
||||
assert User.__tablename__ == 'user'
|
||||
assert user.user_name == "name"
|
||||
|
|
|
|||
|
|
@ -40,9 +40,8 @@ def mock_postgres():
|
|||
|
||||
@pytest.fixture
|
||||
def gsm(mock_postgres):
|
||||
with patch("pretor.core.global_state_machine.global_state_machine.ProviderManager") as mock_pm:
|
||||
manager = GlobalStateMachine(mock_postgres)
|
||||
return manager
|
||||
manager = GlobalStateMachine(mock_postgres)
|
||||
return manager
|
||||
|
||||
|
||||
def test_add_delete_get_event(gsm):
|
||||
|
|
@ -103,24 +102,24 @@ async def test_add_provider_success(gsm, mock_postgres):
|
|||
mock_provider.provider_type = "openai"
|
||||
mock_provider_class.create_model.return_value = mock_provider
|
||||
|
||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm.global_provider_manager.provider_register = {}
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm._global_provider_manager.provider_register = {}
|
||||
|
||||
mock_add_provider = AsyncMock()
|
||||
mock_postgres.provider_database.add_provider.remote = mock_add_provider
|
||||
mock_postgres.provider_database.remote = mock_add_provider
|
||||
|
||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||
|
||||
assert gsm.global_provider_manager.provider_register["title"] == mock_provider
|
||||
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
||||
mock_add_provider.assert_called_once()
|
||||
assert mock_provider.provider_owner == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_unsupported(gsm):
|
||||
gsm.global_provider_manager.provider_mapper = {}
|
||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
||||
await gsm.add_provider("magic", "title", "url", "key", 1)
|
||||
gsm._global_provider_manager.provider_mapper = {}
|
||||
with patch("loguru.logger") as mock_logger:
|
||||
await gsm.add_provider_wrap("magic", "title", "url", "key", 1)
|
||||
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
||||
|
||||
|
||||
|
|
@ -129,10 +128,10 @@ async def test_add_provider_request_error(gsm):
|
|||
from httpx import RequestError
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_model.side_effect = RequestError("Network Error", request=MagicMock())
|
||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
||||
with patch("loguru.logger") as mock_logger:
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "网络请求异常" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
|
@ -141,18 +140,18 @@ async def test_add_provider_request_error(gsm):
|
|||
async def test_add_provider_generic_error(gsm):
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_model.side_effect = ValueError("Some Error")
|
||||
gsm.global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("pretor.core.global_state_machine.global_state_machine.logger") as mock_logger:
|
||||
await gsm.add_provider("openai", "title", "url", "key", 1)
|
||||
with patch("loguru.logger") as mock_logger:
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", 1)
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
def test_get_provider_list_and_get_provider(gsm):
|
||||
mock_provider = MagicMock()
|
||||
gsm.global_provider_manager.provider_register = {"p1": mock_provider}
|
||||
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
|
||||
|
||||
assert gsm.get_provider_list() == {"p1": mock_provider}
|
||||
assert gsm.get_provider("p1") == mock_provider
|
||||
assert gsm.get_provider("missing") is None
|
||||
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
|
||||
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
|
||||
assert gsm._global_provider_manager.get_provider("missing") is None
|
||||
|
|
|
|||
|
|
@ -16,6 +16,8 @@ async def test_provider_manager_init():
|
|||
mock_postgres.get_providers.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||
|
||||
manager = ProviderManager(mock_postgres)
|
||||
mock_postgres.provider_database = MagicMock()
|
||||
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||
await manager.init_provider_register(mock_postgres)
|
||||
|
||||
assert "openai" in manager.provider_mapper
|
||||
|
|
|
|||
|
|
@ -119,20 +119,17 @@ async def test_workflow_running_engine_runner():
|
|||
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
|
||||
|
||||
mock_consciousness = MagicMock()
|
||||
|
||||
mock_wf = MagicMock()
|
||||
mock_wf.trace_id = "test_trace"
|
||||
mock_wf.title = "test_title"
|
||||
|
||||
mock_result = MagicMock(spec=ForWorkflowEngine)
|
||||
mock_result.workflow = mock_wf
|
||||
|
||||
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
|
||||
|
||||
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
|
||||
engine.workflow_queue = asyncio.Queue()
|
||||
|
||||
# Use real PretorEvent to avoid Pydantic validation errors on MagicMock properties
|
||||
mock_event = PretorEvent(
|
||||
platform="test_platform",
|
||||
user_id="test_user",
|
||||
|
|
@ -142,18 +139,19 @@ async def test_workflow_running_engine_runner():
|
|||
)
|
||||
await engine.workflow_queue.put(mock_event)
|
||||
|
||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, \
|
||||
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
|
||||
mock_gsm = MagicMock()
|
||||
mock_gsm.update_workflow.remote = AsyncMock()
|
||||
mock_hook.return_value = mock_gsm
|
||||
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open:
|
||||
|
||||
# Instead of patching hook, we inject it directly
|
||||
engine.global_state_machine = AsyncMock()
|
||||
|
||||
mock_open.return_value.__enter__.return_value.read.return_value = '{}'
|
||||
|
||||
mock_engine_instance = MagicMock()
|
||||
mock_engine_instance.run = AsyncMock()
|
||||
mock_wf_engine_cls.return_value = mock_engine_instance
|
||||
|
||||
task = asyncio.create_task(engine.runner(1))
|
||||
await asyncio.sleep(0.05) # Give runner time to process one item
|
||||
task.cancel() # Stop the infinite loop
|
||||
await asyncio.sleep(0.05)
|
||||
task.cancel()
|
||||
|
||||
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
|
||||
|
|
@ -24,7 +24,7 @@ def test_pretor_workflow_validation_success():
|
|||
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1")
|
||||
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
||||
wg = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
|
||||
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
||||
wf = PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||
assert wf.title == "wf1"
|
||||
|
||||
def test_pretor_workflow_validation_error_step_discontinuous():
|
||||
|
|
@ -32,7 +32,7 @@ def test_pretor_workflow_validation_error_step_discontinuous():
|
|||
ws2 = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2")
|
||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||
with pytest.raises(ValueError, match="工作链步数不连续"):
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||
|
||||
def test_pretor_workflow_validation_error_jump_out_of_bounds():
|
||||
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
|
||||
|
|
@ -40,14 +40,14 @@ def test_pretor_workflow_validation_error_jump_out_of_bounds():
|
|||
ws2 = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
|
||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "username":"b"})
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||
|
||||
def test_pretor_workflow_validation_error_jump_format_error():
|
||||
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
|
||||
ws1 = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=lg)
|
||||
wg = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
|
||||
with pytest.raises(ValueError, match="LogicGate 格式错误"):
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "username":"b"})
|
||||
PretorWorkflow(title="wf1", workgroup_list=[wg], work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
|
||||
|
||||
def test_workflow_status():
|
||||
status = WorkflowStatus()
|
||||
|
|
|
|||
Loading…
Reference in New Issue