feat: 使用 pydantic-ai 再次重构,增加了对于 API 的接口管理
This commit is contained in:
parent
c672c60af6
commit
46937fbc10
|
|
@ -0,0 +1 @@
|
|||
from pretor.adapter.model_adapter.provider_manager import ProviderManager
|
||||
|
|
@ -0,0 +1,20 @@
|
|||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.models.google import GoogleModel
|
||||
from pydantic_ai.models.anthropic import AnthropicModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from pathlib import Path
|
||||
|
||||
class AgentFactory:
|
||||
def __init__(self):
|
||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)}
|
||||
|
||||
def _load_agent_protocol(self):
|
||||
pass
|
||||
|
||||
def create_model(self, protocol_name: str, api_key: str, url: str | None, model_id: str):
|
||||
if protocol_name not in self._models_mapping:
|
||||
raise ValueError(f"不支持的协议类型: {protocol_name}")
|
||||
model_class, provider_class = self._models_mapping[protocol_name]
|
||||
return model_class(model_id, provider_class(api_key = api_key, url = url))
|
||||
|
|
@ -0,0 +1,4 @@
|
|||
from pretor.adapter.model_adapter.model_provider.base_provider import Provider, ProviderArgs
|
||||
from pretor.adapter.model_adapter.model_provider.openai_provider import OpenAIProvider
|
||||
from pretor.adapter.model_adapter.model_provider.gemini_provider import GeminiProvider
|
||||
from pretor.adapter.model_adapter.model_provider.claude_provider import ClaudeProvider
|
||||
|
|
@ -0,0 +1,33 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from pydantic import BaseModel
|
||||
from typing import List
|
||||
|
||||
class Provider(BaseModel):
    """Fully-resolved provider record, including the discovered model list."""
    # Display name / alias chosen at registration time.
    provider_title: str
    # Base URL of the provider's API endpoint.
    provider_url: str
    # API key used to authenticate against the endpoint.
    provider_apikey: str
    # Model ids discovered from the provider's model-list endpoint.
    provider_models: List[str]
    # API family: "openai", "gemini" or "claude".
    provider_type: str


class ProviderArgs(BaseModel):
    """User-supplied registration arguments, before the model list is fetched."""
    provider_title: str
    provider_url: str
    provider_apikey: str
|
||||
|
||||
class BaseProvider(ABC):
    """Abstract interface every concrete model-provider adapter must implement."""

    @staticmethod
    @abstractmethod
    async def create_model(provider_args: ProviderArgs) -> Provider:
        """Fetch the provider's model list and return a fully-populated Provider."""
        pass

    @staticmethod
    @abstractmethod
    async def _load_models(provider_args: ProviderArgs) -> List[str]:
        """Query the remote API and return the ids of the available models."""
        pass

    @staticmethod
    @abstractmethod
    def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
        """Combine the registration args and discovered model list into a Provider."""
        pass
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class ClaudeProvider(BaseProvider):
    """Provider adapter for Anthropic's Claude API."""

    @staticmethod
    async def create_model(provider_args: ProviderArgs) -> Provider:
        """Discover the available Claude models and build a Provider record."""
        provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
        provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
        return provider

    @staticmethod
    async def _load_models(provider_args: ProviderArgs) -> List[str]:
        # The official Anthropic API requires the anthropic-version header.
        headers = {
            "x-api-key": provider_args.provider_apikey,
            "anthropic-version": "2023-06-01",
            "Content-Type": "application/json"
        }
        # Official deployments usually expose /v1/models (when supported).
        # NOTE: Anthropic often does not return a complete list; on failure we
        # deliberately fall back to a hard-coded set of common models.
        url = f"{provider_args.provider_url.rstrip('/')}/v1/models"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, headers=headers)
                if response.status_code == 200:
                    data = response.json()
                    model_ids = [m["id"] for m in data.get("data", [])]
                    return sorted(model_ids)
                else:
                    # Model-list endpoint unavailable: fall back to known common models.
                    return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
        except Exception as e:
            print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
            return ["claude-3-5-sonnet-20240620"]

    @staticmethod
    def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
        # Assemble the Provider record for an Anthropic endpoint.
        return Provider(provider_title=provider_args.provider_title,
                        provider_apikey=provider_args.provider_apikey,
                        provider_url=provider_args.provider_url,
                        provider_models=provider_models,
                        provider_type="claude")
|
||||
|
|
@ -0,0 +1,45 @@
|
|||
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class GeminiProvider(BaseProvider):
    """Provider adapter for Google's native Gemini API."""

    @staticmethod
    async def create_model(provider_args: ProviderArgs) -> Provider:
        """Discover the available Gemini models and build a Provider record."""
        provider_models: List[str] = await GeminiProvider._load_models(provider_args)
        provider: Provider = GeminiProvider._return_provider(provider_args, provider_models)
        return provider

    @staticmethod
    async def _load_models(provider_args: ProviderArgs) -> List[str]:
        # Native Gemini auth typically uses the x-goog-api-key header
        # (or a query parameter); this adapter uses the header form.
        headers = {
            "x-goog-api-key": provider_args.provider_apikey,
            "Content-Type": "application/json"
        }
        # The official list endpoint lives under v1beta/models.
        url = f"{provider_args.provider_url.rstrip('/')}/v1beta/models"

        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, headers=headers)
                if response.status_code != 200:
                    print(f"[{provider_args.provider_title}] 获取 Gemini 模型失败: {response.status_code}")
                    return []

                data = response.json()
                # Gemini model ids come back prefixed with "models/"; strip the
                # prefix and keep only models that support generateContent.
                raw_models = data.get("models", [])
                model_ids = [m["name"].split("/")[-1] for m in raw_models if
                             "generateContent" in m.get("supportedGenerationMethods", [])]
                # De-duplicate and return in a stable order.
                return sorted(list(set(model_ids)))
        except Exception as e:
            print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
            return []

    @staticmethod
    def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
        # Assemble the Provider record for a Gemini endpoint.
        return Provider(provider_title=provider_args.provider_title,
                        provider_apikey=provider_args.provider_apikey,
                        provider_url=provider_args.provider_url,
                        provider_models=provider_models,
                        provider_type="gemini")
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class OpenAIProvider(BaseProvider):
    """Provider adapter for OpenAI-compatible APIs."""

    @staticmethod
    async def create_model(provider_args: ProviderArgs) -> Provider:
        """Discover the available models and build a Provider record."""
        provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
        provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
        return provider

    @staticmethod
    def _models_url(base_url: str) -> str:
        """Build the model-list endpoint URL, appending /v1 only when missing.

        BUG FIX: the previous check (``"/v1" in url``) produced a double slash
        for URLs with a trailing slash (".../v1/" -> ".../v1//models") and
        misclassified URLs that merely contain "/v1" elsewhere.
        """
        base = base_url.rstrip('/')
        return f"{base}/models" if base.endswith("/v1") else f"{base}/v1/models"

    @staticmethod
    async def _load_models(provider_args: ProviderArgs) -> List[str]:
        """Query the /models endpoint; return [] (best-effort) on any failure."""
        headers = {
            "Authorization": f"Bearer {provider_args.provider_apikey}",
            "Content-Type": "application/json"
        }
        url = OpenAIProvider._models_url(provider_args.provider_url)
        try:
            async with httpx.AsyncClient(timeout=10.0) as client:
                response = await client.get(url, headers=headers)
                if response.status_code != 200:
                    print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
                    return []
                data = response.json()
                raw_models = data.get("data", [])
                model_ids = [m["id"] for m in raw_models]
                return sorted(model_ids)
        except httpx.RequestError as e:
            print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
            return []
        except Exception as e:
            print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
            return []

    @staticmethod
    def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
        """Assemble the Provider record for an OpenAI-type endpoint."""
        return Provider(provider_title=provider_args.provider_title,
                        provider_apikey=provider_args.provider_apikey,
                        provider_url=provider_args.provider_url,
                        provider_models=provider_models,
                        provider_type="openai")
|
||||
|
|
@ -0,0 +1,57 @@
|
|||
from pretor.adapter.model_adapter._agent_factory import AgentFactory
|
||||
from pretor.adapter.model_adapter.model_provider import Provider, ProviderArgs, OpenAIProvider,GeminiProvider, ClaudeProvider
|
||||
from pydantic_ai import Agent
|
||||
import httpx
|
||||
from pretor.utils.error import ModelNotExistError, ProviderNotExistError
|
||||
from loguru import logger
|
||||
from typing import Dict
|
||||
|
||||
class ProviderManager:
    """Registry of model providers and factory entry point for pydantic-ai agents."""

    def __init__(self):
        # Maps provider_type -> adapter class used to probe that API family.
        self._provider_mapper = {"openai": OpenAIProvider, "gemini": GeminiProvider, "claude": ClaudeProvider}
        self._agent_factory = AgentFactory()
        # Maps provider_title -> resolved Provider record.
        self.provider_register = {}

    async def add_provider(self, provider_type: str, provider_title: str, provider_url: str, provider_apikey: str) -> None:
        """
        Register a provider adapter.

        :param provider_type: API family; only "openai", "gemini" and "claude" are supported
        :param provider_title: display name / alias for the provider
        :param provider_url: provider base URL
        :param provider_apikey: API key required by the provider
        :return: None (failures are logged, not raised)
        """
        provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, provider_url=provider_url, provider_apikey=provider_apikey)
        try:
            if provider_type not in self._provider_mapper.keys():
                logger.warning(f"Provider type {provider_type} is not supported.")
                return None
            provider_class = self._provider_mapper.get(provider_type)
            # create_model performs the network probe for the provider's model list.
            provider: Provider = await provider_class.create_model(provider_args)
            self.provider_register[provider_title] = provider
            logger.info(f"已添加适配器{provider_title}")
        except httpx.RequestError as e:
            logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
        except Exception as e:
            logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")

    def create_agent(self, agent_name: str, system_prompt: str, provider_title: str, model_id: str) -> Agent:
        """
        Turn a registered provider into a pydantic-ai Agent.

        :param agent_name: name given to the agent instance
        :param system_prompt: system prompt passed to the LLM
        :param provider_title: provider alias used at registration time
        :param model_id: id of the model the agent should use
        :raises ProviderNotExistError: if the provider alias is unknown
        :raises ModelNotExistError: if the model id is not offered by the provider
        :return: the constructed Agent
        """
        if provider_title not in self.provider_register:
            raise ProviderNotExistError("提供商不存在")
        provider = self.provider_register[provider_title]
        if model_id not in provider.provider_models:
            raise ModelNotExistError("模型不存在")
        model = self._agent_factory.create_model(provider.provider_type, provider.provider_apikey, provider.provider_url, model_id)
        agent = Agent(model=model,name=agent_name,system_prompt=system_prompt)
        return agent

    def get_provider_list(self) -> Dict[str, Provider]:
        """Return the registry mapping provider titles to Provider records."""
        # NOTE(review): this returns the internal dict itself, so callers can
        # mutate the registry — confirm that exposure is intended.
        return self.provider_register
|
||||
|
|
@ -1,76 +0,0 @@
|
|||
import httpx
|
||||
import json
|
||||
from typing import List, Dict, Any, AsyncGenerator
|
||||
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
|
||||
|
||||
class GeminiAdapter(ModelBase):
    """HTTP adapter for a Gemini endpoint exposed through an OpenAI-compatible
    REST surface (/models, /chat/completions).
    """

    def __init__(self, base_url: str, adapter_title: str, api_key: str):
        self.adapter_title: str = adapter_title
        # Normalize so base_url always ends with exactly one "/v1".
        self.base_url = base_url.rstrip('/')
        if not self.base_url.endswith('/v1'):
            self.base_url += '/v1'
        self.api_key = api_key
        self.model_list = []

    async def get_model(self) -> List[str]:
        """Fetch and cache the endpoint's model id list."""
        url = f"{self.base_url}/models"
        headers = {"Authorization": f"Bearer {self.api_key}"}

        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()
            self.model_list = [m.get("id", "") for m in data.get("data", [])]
            return self.model_list

    async def post_message(
        self,
        model: str,
        messages: List[Dict[str, str]],
        stream: bool = False,
        temperature: float = 0.7,
        max_tokens: int = 4096,
        **kwargs
    ) -> Any:
        """Send a chat completion request.

        :return: parsed JSON when stream is False, otherwise an async generator
                 yielding content deltas.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json"
        }

        payload = {
            "model": model,
            "messages": messages,
            "stream": stream,
            "temperature": temperature,
            "max_tokens": max_tokens,
            **kwargs
        }

        # Long generations need a generous read timeout; keep connect short.
        timeout = httpx.Timeout(120.0, connect=10.0)

        if not stream:
            async with httpx.AsyncClient(timeout=timeout) as client:
                response = await client.post(url, headers=headers, json=payload)
                response.raise_for_status()
                return response.json()
        else:
            return self._handle_stream(url, headers, payload)

    # BUG FIX: this was decorated @staticmethod while still declaring `self`,
    # so the bound call in post_message shifted every argument by one position
    # (url became self, headers became url, ...) and crashed with a
    # missing-argument TypeError. It is now a regular instance method.
    async def _handle_stream(self, url: str, headers: dict, payload: dict) -> AsyncGenerator[str, None]:
        """Yield content deltas from an SSE chat-completion stream."""
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    # Skip keep-alive blanks and the terminal [DONE] marker.
                    if not line.strip() or line == "data: [DONE]":
                        continue
                    if line.startswith("data: "):
                        try:
                            chunk = json.loads(line[6:])
                            delta = chunk["choices"][0]["delta"].get("content", "")
                            if delta:
                                yield delta
                        except (json.JSONDecodeError, KeyError):
                            continue
|
||||
|
|
@ -1,10 +0,0 @@
|
|||
from abc import ABC,abstractmethod
|
||||
|
||||
class ModelBase(ABC):
    """Abstract contract every chat-model HTTP adapter must satisfy."""

    @abstractmethod
    async def get_model(self):
        """Return the list of model ids available from the backend."""

    @abstractmethod
    async def post_message(self, model: str, messages: list, stream: bool = False, **kwargs):
        """Send a chat request; return a response object or a stream handle."""
|
||||
|
|
@ -1,39 +0,0 @@
|
|||
import httpx
|
||||
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
|
||||
|
||||
class OpenAIAdapter(ModelBase):
    """HTTP adapter for OpenAI-compatible chat endpoints."""

    def __init__(self, base_url: str, adapter_title: str, api_key: str = "archon-local"):
        self.adapter_title: str = adapter_title
        # Normalize so base_url always ends with exactly one "/v1".
        self.base_url = base_url.rstrip('/')
        if not self.base_url.endswith('/v1'):
            self.base_url += '/v1'
        self.api_key = api_key
        self.model_list = []

    async def get_model(self):
        """Fetch and cache the endpoint's model id list."""
        url = "{}/models".format(self.base_url)
        headers = {
            "Authorization": f"Bearer {self.api_key}"
        }
        async with httpx.AsyncClient() as client:
            response = await client.get(url, headers=headers)
            response.raise_for_status()
            data = response.json()
            self.model_list = [m.get("id", "") for m in data.get("data", [])]
            return self.model_list

    async def post_message(self, model: str, messages: list, stream: bool = False, **kwargs):
        """Send a chat completion request.

        :return: parsed JSON when stream is False, otherwise an async generator
                 yielding raw response lines.
        """
        url = f"{self.base_url}/chat/completions"
        headers = {"Authorization": f"Bearer {self.api_key}"}
        payload = {
            "model": model,
            "messages": messages,
            "stream": stream,
            **kwargs
        }
        if not stream:
            async with httpx.AsyncClient(timeout=None) as client:
                response = await client.post(url, headers=headers, json=payload)
                return response.json()
        # BUG FIX: the original returned client.stream(...) created inside an
        # `async with httpx.AsyncClient()` block, so the client was already
        # closed by the time the caller consumed the stream. Delegate to an
        # async generator that owns its client for the stream's whole lifetime.
        return self._stream_response(url, headers, payload)

    async def _stream_response(self, url: str, headers: dict, payload: dict):
        """Yield raw response lines from a streaming chat completion."""
        async with httpx.AsyncClient(timeout=None) as client:
            async with client.stream("POST", url, headers=headers, json=payload) as response:
                response.raise_for_status()
                async for line in response.aiter_lines():
                    yield line
|
||||
|
|
@ -0,0 +1,23 @@
|
|||
from functools import wraps

from sqlalchemy.exc import IntegrityError, OperationalError
from pydantic import ValidationError
from loguru import logger

from pretor.utils.error import UserNotExistError
|
||||
|
||||
def database_exception(func):
    """Decorator for async DB methods: log known failure modes, then re-raise.

    :param func: async callable performing database work
    :return: wrapped async callable with logging around known exception types
    """
    @wraps(func)  # preserve the wrapped function's name/docstring
    async def wrapper(*args, **kwargs):
        try:
            return await func(*args, **kwargs)
        except ValidationError as e:
            logger.error(f"对象校验失败:{e}")
            raise e
        except IntegrityError as e:
            logger.error(f"数据库完整性错误 (如重复记录): {e}")
            raise e
        except OperationalError as e:
            logger.error(f"数据库连接异常: {e}")
            raise e
        except UserNotExistError as e:
            logger.error(f"更改密码失败,用户不存在:{e}")
            # BUG FIX: this branch previously swallowed the exception, so the
            # caller silently received None even though the method promised a
            # User; re-raise like every other branch. (UserNotExistError was
            # also never imported — fixed in this module's imports.)
            raise e
        except Exception as e:
            logger.exception(f"未预期的数据库错误: {e}")
            raise e
    return wrapper
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
import ray
|
||||
from pretor.core.database.table import User
|
||||
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlmodel import SQLModel, select
|
||||
from pretor.utils.error import UserNotExistError, UserPasswordError
|
||||
import os
|
||||
from pretor.core.database.database_exception import database_exception
|
||||
|
||||
@ray.remote
class PostgresDatabase:
    """Ray actor wrapping async access to the Postgres `user` table.

    Connection settings come from the POSTGRES_* environment variables.
    """

    def __init__(self):
        user = os.environ.get('POSTGRES_USER')
        password = os.environ.get('POSTGRES_PASSWORD')
        host = os.environ.get('POSTGRES_HOST')
        port = os.environ.get('POSTGRES_PORT')
        database = os.environ.get('POSTGRES_DB')
        database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
        self.async_engine = create_async_engine(database_url, echo=True)
        self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)

    async def init_db(self) -> None:
        """Create all SQLModel tables that do not yet exist."""
        async with self.async_engine.begin() as conn:
            await conn.run_sync(SQLModel.metadata.create_all)

    @database_exception
    async def add_user(self, user_name: str, hashed_password: str) -> User:
        """Insert a new user row and return it."""
        user = User(user_name=user_name, hashed_password=hashed_password)
        # BUG FIX: the session maker must be called; `async with
        # self.async_session_maker` (no parentheses) raised at runtime. Every
        # other method here already used the called form.
        async with self.async_session_maker() as session:
            session.add(user)
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def change_password(self, user_name: str, old_password: str, new_password: str) -> User:
        """Verify the old password and store the new one.

        :raises UserNotExistError: if no row matches user_name
        :raises UserPasswordError: if old_password does not match the stored value
        """
        async with self.async_session_maker() as session:
            statement = select(User).where(User.user_name == user_name)
            # BUG FIX: SQLAlchemy's AsyncSession has execute(), not exec();
            # exec() only exists on sqlmodel's AsyncSession, which this
            # sessionmaker (class_=AsyncSession from sqlalchemy) does not produce.
            results = await session.execute(statement)
            user = results.scalar_one_or_none()
            if user is None:
                raise UserNotExistError()
            # NOTE(review): this compares old_password directly against the
            # stored hashed_password — confirm callers pre-hash before calling.
            if old_password != user.hashed_password:
                raise UserPasswordError()
            user.hashed_password = new_password
            session.add(user)
            await session.commit()
            await session.refresh(user)
            return user

    @database_exception
    async def delete_user(self, user_name: str) -> None:
        """Delete the user row; raise UserNotExistError when absent."""
        async with self.async_session_maker() as session:
            statement = select(User).where(User.user_name == user_name)
            results = await session.execute(statement)
            user = results.scalar_one_or_none()
            if user is None:
                raise UserNotExistError()
            # BUG FIX: AsyncSession.delete is a coroutine and must be awaited.
            await session.delete(user)
            await session.commit()

    @database_exception
    async def get_user_password(self, user_name: str) -> str:
        """Return the stored password hash for the user.

        :raises UserNotExistError: if no row matches user_name
        """
        async with self.async_session_maker() as session:
            statement = select(User).where(User.user_name == user_name)
            results = await session.execute(statement)
            user = results.scalar_one_or_none()
            if user is None:
                raise UserNotExistError()
            return user.hashed_password
|
||||
|
||||
|
|
@ -0,0 +1 @@
|
|||
from pretor.core.database.table.user import User
|
||||
|
|
@ -0,0 +1,8 @@
|
|||
from sqlmodel import SQLModel, Field, Column, String
|
||||
|
||||
|
||||
class User(SQLModel, table=True):
    """ORM row for the `user` table.

    BUG FIX: `table=True` was missing, so SQLModel treated this class as a
    plain pydantic model and `SQLModel.metadata.create_all` never created the
    table that PostgresDatabase queries.
    """
    __tablename__ = 'user'
    # Primary key is assigned by the database, so it is None until insert;
    # the annotation must therefore allow None.
    user_id: int | None = Field(default=None, primary_key=True)
    user_name: str = Field(index=True)
    hashed_password: str
|
||||
|
|
@ -8,7 +8,7 @@ class WorkflowManager:
|
|||
def __init__(self):
|
||||
self.workflow_template_generator = WorkflowTemplateGenerator()
|
||||
self.workflow_templates_registry = {}
|
||||
self.template_path = Path("pretor/workflow_plugin")
|
||||
self.template_path = Path("pretor/workflow_template")
|
||||
self._load_workflow_template()
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -5,7 +5,7 @@ class WorkflowTemplateGenerator:
|
|||
@staticmethod
|
||||
def generate_workflow_template(name: str, desc: str, steps: list) -> None:
|
||||
workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps)
|
||||
output_dir = Path("pretor.workflow_plugin")
|
||||
output_dir = Path("pretor.workflow_template")
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
output_file = output_dir / f"{name}_workflow_template.json"
|
||||
|
|
|
|||
|
|
@ -1,26 +1,5 @@
|
|||
from pretor.core.protocol.runnable_object import RunnableObject
|
||||
from pretor.core.workflow_manager.workflow import PretorWorkflow
|
||||
from pretor.adapter_plugin.model_adapter.modelbase import ModelBase
|
||||
import ray
|
||||
from typing import Any,Dict
|
||||
from pretor.individual_plugin.control_node.control_register import ControlRegister
|
||||
from pretor.utils.inspector import inspector
|
||||
|
||||
#control_node 管控节点,掌管系统的全局状态
|
||||
@ray.remote
|
||||
class ControlNode(RunnableObject):
|
||||
def __init__(self, **kwargs: Dict[str: Any]) -> None:
|
||||
self.model_adapter : ModelBase = kwargs.get("model_adapter")
|
||||
self.model : str = kwargs.get("model")
|
||||
self.name : str = kwargs.get("name", "管控节点")
|
||||
self.control_register = ControlRegister()
|
||||
|
||||
def _load_control_register(self) :
|
||||
pass
|
||||
|
||||
@inspector("individual","control_node")
|
||||
async def run(self, workflow : PretorWorkflow) -> None:
|
||||
control_register = self.control_register.model_dump_json()
|
||||
demand = workflow.status.content.demand.model_dump_json()
|
||||
|
||||
from pydantic_ai import Agent
|
||||
|
||||
class ControlNode:
    """Control node stub.

    NOTE(review): the previous ray-actor implementation was removed in this
    rewrite; this empty class keeps the name importable until the new
    pydantic-ai based implementation lands.
    """
    def __init__(self):
        pass
|
||||
|
|
@ -1,22 +0,0 @@
|
|||
from pydantic import BaseModel, Field
|
||||
from typing import List, Dict, Any, Literal, Union, Optional
|
||||
|
||||
class SystemItem(BaseModel):
    """Registry entry describing a shell-level capability."""
    command_template: str = Field(..., description="底层 shell 命令模板")
    args_schema: Dict[str, Any] = Field(default_factory=dict, description="该指令接受的参数约束")


class IndividualItem(BaseModel):
    """Registry entry describing an agent individual."""
    description: str
    # BUG FIX: `Dict[str: str]` is a slice expression, which raises TypeError
    # the moment this class body is evaluated; the generic takes a comma.
    params: Dict[str, str]
    base_prompt: str = Field(..., description="个体的基础人格/背景设定")


class ToolItem(BaseModel):
    """Registry entry describing a tool plugin."""
    description: str
    plugin_path: str = Field(..., description="插件物理路径或类路径")


class ControlRegister(BaseModel):
    """Global registries, each indexed by name: { "name": ItemObject }."""
    system_registry: Dict[str, SystemItem] = Field(default_factory=dict)
    individual_registry: Dict[str, IndividualItem] = Field(default_factory=dict)
    tool_registry: Dict[str, ToolItem] = Field(default_factory=dict)
    global_information : Dict[str, str] = Field(default_factory=dict)
|
||||
|
|
@ -1,2 +1,17 @@
|
|||
class DemandError(Exception):
    """Error relating to workflow demands."""


class ModelNotExistError(Exception):
    """Raised when a requested model id is not available from a provider."""


class ProviderNotExistError(Exception):
    """Raised when a provider alias has not been registered."""


class UserError(Exception):
    """Base class for user-account related errors."""


class UserNotExistError(UserError):
    """Raised when the referenced user does not exist."""


class UserPasswordError(UserError):
    """Raised when the supplied password does not match the stored one."""
|
||||
|
|
@ -5,10 +5,13 @@ description = "Add your description here"
|
|||
readme = "README.md"
|
||||
requires-python = ">=3.13"
|
||||
dependencies = [
|
||||
"asyncpg>=0.31.0",
|
||||
"docker-py>=1.10.6",
|
||||
"httpx>=0.28.1",
|
||||
"jinja2>=3.1.6",
|
||||
"loguru>=0.7.3",
|
||||
"pydantic-ai>=1.73.0",
|
||||
"python-ulid>=3.1.0",
|
||||
"ray[defaule,serve]>=2.54.0",
|
||||
"ray[default,serve]>=2.54.0",
|
||||
"sqlmodel>=0.0.37",
|
||||
]
|
||||
|
|
|
|||
Loading…
Reference in New Issue