feat:使用pytanticAI再次重构,增加了对于api的接口管理

This commit is contained in:
朝夕 2026-03-31 18:15:18 +08:00
parent c672c60af6
commit 46937fbc10
28 changed files with 1984 additions and 242 deletions

View File

@ -0,0 +1 @@
from pretor.adapter.model_adapter.provider_manager import ProviderManager

View File

@ -0,0 +1,20 @@
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.providers.anthropic import AnthropicProvider
from pathlib import Path
class AgentFactory:
def __init__(self):
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)}
def _load_agent_protocol(self):
pass
def create_model(self, protocol_name: str, api_key: str, url: str | None, model_id: str):
if protocol_name not in self._models_mapping:
raise ValueError(f"不支持的协议类型: {protocol_name}")
model_class, provider_class = self._models_mapping[protocol_name]
return model_class(model_id, provider_class(api_key = api_key, url = url))

View File

@ -0,0 +1,4 @@
from pretor.adapter.model_adapter.model_provider.base_provider import Provider, ProviderArgs
from pretor.adapter.model_adapter.model_provider.openai_provider import OpenAIProvider
from pretor.adapter.model_adapter.model_provider.gemini_provider import GeminiProvider
from pretor.adapter.model_adapter.model_provider.claude_provider import ClaudeProvider

View File

@ -0,0 +1,33 @@
from abc import ABC, abstractmethod
from pydantic import BaseModel
from typing import List
class Provider(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
provider_models: List[str]
provider_type: str
class ProviderArgs(BaseModel):
provider_title: str
provider_url: str
provider_apikey: str
class BaseProvider(ABC):
@staticmethod
@abstractmethod
async def create_model(provider_args: ProviderArgs) -> Provider:
pass
@staticmethod
@abstractmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
pass
@staticmethod
@abstractmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
pass

View File

@ -0,0 +1,44 @@
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class ClaudeProvider(BaseProvider):
@staticmethod
async def create_model(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Anthropic 官方需要 version 头
headers = {
"x-api-key": provider_args.provider_apikey,
"anthropic-version": "2023-06-01",
"Content-Type": "application/json"
}
# 如果是官方 API通常使用 /v1/models (如果支持)
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code == 200:
data = response.json()
model_ids = [m["id"] for m in data.get("data", [])]
return sorted(model_ids)
else:
# 如果官方列表接口不可用fallback 到已知常用模型
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
return ["claude-3-5-sonnet-20240620"]
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="claude")

View File

@ -0,0 +1,45 @@
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class GeminiProvider(BaseProvider):
@staticmethod
async def create_model(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await GeminiProvider._load_models(provider_args)
provider: Provider = GeminiProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
headers = {
"x-goog-api-key": provider_args.provider_apikey,
"Content-Type": "application/json"
}
# 官方路径通常是 v1beta/models
url = f"{provider_args.provider_url.rstrip('/')}/v1beta/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取 Gemini 模型失败: {response.status_code}")
return []
data = response.json()
# Gemini 返回的结构中模型 ID 通常带 "models/" 前缀
raw_models = data.get("models", [])
model_ids = [m["name"].split("/")[-1] for m in raw_models if
"generateContent" in m.get("supportedGenerationMethods", [])]
return sorted(list(set(model_ids)))
except Exception as e:
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="gemini")

View File

@ -0,0 +1,42 @@
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class OpenAIProvider(BaseProvider):
@staticmethod
async def create_model(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
return []
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="openai")

View File

@ -0,0 +1,57 @@
from pretor.adapter.model_adapter._agent_factory import AgentFactory
from pretor.adapter.model_adapter.model_provider import Provider, ProviderArgs, OpenAIProvider,GeminiProvider, ClaudeProvider
from pydantic_ai import Agent
import httpx
from pretor.utils.error import ModelNotExistError, ProviderNotExistError
from loguru import logger
from typing import Dict
class ProviderManager:
def __init__(self):
self._provider_mapper = {"openai": OpenAIProvider, "gemini": GeminiProvider, "claude": ClaudeProvider}
self._agent_factory = AgentFactory()
self.provider_register = {}
async def add_provider(self, provider_type: str, provider_title: str, provider_url: str, provider_apikey: str) -> None:
"""
add_provider方法注册供应商适配器
:param provider_type: 注册商接口类型目前只支持openai,gemini和claude接口
:param provider_title: 供应商名称为供应商提供的别名
:param provider_url: 供应商url
:param provider_apikey: 供应商所需要的apikey
:return:
"""
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, provider_url=provider_url, provider_apikey=provider_apikey)
try:
if provider_type not in self._provider_mapper.keys():
logger.warning(f"Provider type {provider_type} is not supported.")
return None
provider_class = self._provider_mapper.get(provider_type)
provider: Provider = await provider_class.create_model(provider_args)
self.provider_register[provider_title] = provider
logger.info(f"已添加适配器{provider_title}")
except httpx.RequestError as e:
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
except Exception as e:
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
def create_agent(self, agent_name: str, system_prompt: str, provider_title: str, model_id: str) -> Agent:
"""
create_agent方法将保存的适配器转化为agent对象并返回
:param agent_name: agent名字代表实例化个体起的名字
:param system_prompt: 系统提示词给llm的系统提示词
:param provider_title: 供应商名称
:param model_id: 模型Id实例化agent所输入的model_id
:return:
"""
if provider_title not in self.provider_register:
raise ProviderNotExistError("提供商不存在")
provider = self.provider_register[provider_title]
if model_id not in provider.provider_models:
raise ModelNotExistError("模型不存在")
model = self._agent_factory.create_model(provider.provider_type, provider.provider_apikey, provider.provider_url, model_id)
agent = Agent(model=model,name=agent_name,system_prompt=system_prompt)
return agent
def get_provider_list(self) -> Dict[str, Provider]:
return self.provider_register

View File

@ -1,76 +0,0 @@
import httpx
import json
from typing import List, Dict, Any, AsyncGenerator
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
class GeminiAdapter(ModelBase):
def __init__(self, base_url: str, adapter_title: str, api_key: str):
self.adapter_title: str = adapter_title
self.base_url = base_url.rstrip('/')
if not self.base_url.endswith('/v1'):
self.base_url += '/v1'
self.api_key = api_key
self.model_list = []
async def get_model(self) -> List[str]:
url = f"{self.base_url}/models"
headers = {"Authorization": f"Bearer {self.api_key}"}
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
data = response.json()
self.model_list = [m.get("id", "") for m in data.get("data", [])]
return self.model_list
async def post_message(
self,
model: str,
messages: List[Dict[str, str]],
stream: bool = False,
temperature: float = 0.7,
max_tokens: int = 4096,
**kwargs
) -> Any:
url = f"{self.base_url}/chat/completions"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json"
}
payload = {
"model": model,
"messages": messages,
"stream": stream,
"temperature": temperature,
"max_tokens": max_tokens,
**kwargs
}
# 144GB 显存或云端长文本建议设置较长超时
timeout = httpx.Timeout(120.0, connect=10.0)
if not stream:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(url, headers=headers, json=payload)
response.raise_for_status()
return response.json()
else:
return self._handle_stream(url, headers, payload)
@staticmethod
async def _handle_stream(self, url: str, headers: dict, payload: dict) -> AsyncGenerator[str, None]:
async with httpx.AsyncClient(timeout=None) as client:
async with client.stream("POST", url, headers=headers, json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line.strip() or line == "data: [DONE]":
continue
if line.startswith("data: "):
try:
chunk = json.loads(line[6:])
delta = chunk["choices"][0]["delta"].get("content", "")
if delta:
yield delta
except (json.JSONDecodeError, KeyError):
continue

View File

@ -1,10 +0,0 @@
from abc import ABC,abstractmethod
class ModelBase(ABC):
@abstractmethod
async def get_model(self):
pass
@abstractmethod
async def post_message(self, model: str, messages: list, stream: bool = False, **kwargs):
pass

View File

@ -1,39 +0,0 @@
import httpx
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
class OpenAIAdapter(ModelBase):
def __init__(self, base_url: str, adapter_title: str, api_key: str = "archon-local"):
self.adapter_title: str = adapter_title
self.base_url = base_url.rstrip('/')
if not self.base_url.endswith('/v1'):
self.base_url += '/v1'
self.api_key = api_key
self.model_list = []
async def get_model(self):
url = "{}/models".format(self.base_url)
headers = {
"Authorization": f"Bearer {self.api_key}"
}
async with httpx.AsyncClient() as client:
response = await client.get(url, headers=headers)
response.raise_for_status()
response = response.json()
self.model_list = [m.get("id", "") for m in response.get("data", [])]
return self.model_list
async def post_message(self,model: str, messages: list, stream: bool = False, **kwargs):
url = f"{self.base_url}/chat/completions"
headers = {"Authorization": f"Bearer {self.api_key}"}
payload = {
"model": model,
"messages": messages,
"stream": stream,
**kwargs
}
async with httpx.AsyncClient(timeout=None) as client:
if not stream:
response = await client.post(url, headers=headers, json=payload)
return response.json()
else:
return client.stream("POST", url, headers=headers, json=payload)

View File

View File

@ -0,0 +1,23 @@
from sqlalchemy.exc import IntegrityError, OperationalError
from pydantic import ValidationError
from loguru import logger
def database_exception(func):
async def wrapper(*args, **kwargs):
try:
return await func(*args, **kwargs)
except ValidationError as e:
logger.error(f"对象校验失败:{e}")
raise e
except IntegrityError as e:
logger.error(f"数据库完整性错误 (如重复记录): {e}")
raise e
except OperationalError as e:
logger.error(f"数据库连接异常: {e}")
raise e
except UserNotExistError as e:
logger.error(f"更改密码失败,用户不存在:{e}")
except Exception as e:
logger.exception(f"未预期的数据库错误: {e}")
raise e
return wrapper

View File

@ -0,0 +1,71 @@
import ray
from pretor.core.database.table import User
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from sqlmodel import SQLModel, select
from pretor.utils.error import UserNotExistError, UserPasswordError
import os
from pretor.core.database.database_exception import database_exception
@ray.remote
class PostgresDatabase:
def __init__(self):
user = os.environ.get('POSTGRES_USER')
password = os.environ.get('POSTGRES_PASSWORD')
host = os.environ.get('POSTGRES_HOST')
port = os.environ.get('POSTGRES_PORT')
database = os.environ.get('POSTGRES_DB')
database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
self.async_engine = create_async_engine(database_url, echo=True)
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
async def init_db(self) -> None:
async with self.async_engine.begin() as conn:
await conn.run_sync(SQLModel.metadata.create_all)
@database_exception
async def add_user(self, user_name: str, hashed_password: str) -> User:
user = User(user_name=user_name, hashed_password=hashed_password)
async with self.async_session_maker as session:
session.add(user)
await session.commit()
await session.refresh(user)
return user
@database_exception
async def change_password(self, user_name, old_password, new_password) -> User:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.exec(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
if old_password != user.hashed_password:
raise UserPasswordError()
user.hashed_password = new_password
session.add(user)
await session.commit()
await session.refresh(user)
return user
@database_exception
async def delete_user(self, user_name: str) -> None:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.exec(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
session.delete(user)
await session.commit()
@database_exception
async def get_user_password(self, user_name: str) -> str:
async with self.async_session_maker() as session:
statement = select(User).where(User.user_name == user_name)
results = await session.exec(statement)
user = results.scalar_one_or_none()
if user is None:
raise UserNotExistError()
return user.hashed_password

View File

@ -0,0 +1 @@
from pretor.core.database.table.user import User

View File

View File

@ -0,0 +1,8 @@
from sqlmodel import SQLModel, Field, Column, String
class User(SQLModel):
__tablename__ = 'user'
user_id: int = Field(default=None, primary_key=True)
user_name: str = Field(index=True)
hashed_password: str

View File

@ -8,7 +8,7 @@ class WorkflowManager:
def __init__(self):
self.workflow_template_generator = WorkflowTemplateGenerator()
self.workflow_templates_registry = {}
self.template_path = Path("pretor/workflow_plugin")
self.template_path = Path("pretor/workflow_template")
self._load_workflow_template()

View File

@ -5,7 +5,7 @@ class WorkflowTemplateGenerator:
@staticmethod
def generate_workflow_template(name: str, desc: str, steps: list) -> None:
workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps)
output_dir = Path("pretor.workflow_plugin")
output_dir = Path("pretor.workflow_template")
if not output_dir.exists():
output_dir.mkdir(parents=True)
output_file = output_dir / f"{name}_workflow_template.json"

View File

@ -1,26 +1,5 @@
from pretor.core.protocol.runnable_object import RunnableObject
from pretor.core.workflow_manager.workflow import PretorWorkflow
from pretor.adapter_plugin.model_adapter.modelbase import ModelBase
import ray
from typing import Any,Dict
from pretor.individual_plugin.control_node.control_register import ControlRegister
from pretor.utils.inspector import inspector
from pydantic_ai import Agent
#control_node 管控节点,掌管系统的全局状态
@ray.remote
class ControlNode(RunnableObject):
def __init__(self, **kwargs: Dict[str: Any]) -> None:
self.model_adapter : ModelBase = kwargs.get("model_adapter")
self.model : str = kwargs.get("model")
self.name : str = kwargs.get("name", "管控节点")
self.control_register = ControlRegister()
def _load_control_register(self) :
class ControlNode:
def __init__(self):
pass
@inspector("individual","control_node")
async def run(self, workflow : PretorWorkflow) -> None:
control_register = self.control_register.model_dump_json()
demand = workflow.status.content.demand.model_dump_json()

View File

@ -1,22 +0,0 @@
from pydantic import BaseModel, Field
from typing import List, Dict, Any, Literal, Union, Optional
class SystemItem(BaseModel):
command_template: str = Field(..., description="底层 shell 命令模板")
args_schema: Dict[str, Any] = Field(default_factory=dict, description="该指令接受的参数约束")
class IndividualItem(BaseModel):
description: str
params: Dict[str: str]
base_prompt: str = Field(..., description="个体的基础人格/背景设定")
class ToolItem(BaseModel):
description: str
plugin_path: str = Field(..., description="插件物理路径或类路径")
class ControlRegister(BaseModel):
# 统一使用 Dict方便通过 name 快速索引:{ "name": ItemObject }
system_registry: Dict[str, SystemItem] = Field(default_factory=dict)
individual_registry: Dict[str, IndividualItem] = Field(default_factory=dict)
tool_registry: Dict[str, ToolItem] = Field(default_factory=dict)
global_information : Dict[str, str] = Field(default_factory=dict)

View File

@ -1,2 +1,17 @@
class DemandError(Exception):
pass
class ModelNotExistError(Exception):
pass
class ProviderNotExistError(Exception):
pass
class UserError(Exception):
pass
class UserNotExistError(UserError):
pass
class UserPasswordError(UserError):
pass

View File

@ -5,10 +5,13 @@ description = "Add your description here"
readme = "README.md"
requires-python = ">=3.13"
dependencies = [
"asyncpg>=0.31.0",
"docker-py>=1.10.6",
"httpx>=0.28.1",
"jinja2>=3.1.6",
"loguru>=0.7.3",
"pydantic-ai>=1.73.0",
"python-ulid>=3.1.0",
"ray[defaule,serve]>=2.54.0",
"ray[default,serve]>=2.54.0",
"sqlmodel>=0.0.37",
]

1677
uv.lock

File diff suppressed because it is too large Load Diff