feat:使用pytanticAI再次重构,增加了对于api的接口管理
This commit is contained in:
parent
c672c60af6
commit
46937fbc10
|
|
@ -0,0 +1 @@
|
||||||
|
from pretor.adapter.model_adapter.provider_manager import ProviderManager
|
||||||
|
|
@ -0,0 +1,20 @@
|
||||||
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
|
from pydantic_ai.models.google import GoogleModel
|
||||||
|
from pydantic_ai.models.anthropic import AnthropicModel
|
||||||
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
|
from pydantic_ai.providers.google import GoogleProvider
|
||||||
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
class AgentFactory:
|
||||||
|
def __init__(self):
|
||||||
|
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)}
|
||||||
|
|
||||||
|
def _load_agent_protocol(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def create_model(self, protocol_name: str, api_key: str, url: str | None, model_id: str):
|
||||||
|
if protocol_name not in self._models_mapping:
|
||||||
|
raise ValueError(f"不支持的协议类型: {protocol_name}")
|
||||||
|
model_class, provider_class = self._models_mapping[protocol_name]
|
||||||
|
return model_class(model_id, provider_class(api_key = api_key, url = url))
|
||||||
|
|
@ -0,0 +1,4 @@
|
||||||
|
from pretor.adapter.model_adapter.model_provider.base_provider import Provider, ProviderArgs
|
||||||
|
from pretor.adapter.model_adapter.model_provider.openai_provider import OpenAIProvider
|
||||||
|
from pretor.adapter.model_adapter.model_provider.gemini_provider import GeminiProvider
|
||||||
|
from pretor.adapter.model_adapter.model_provider.claude_provider import ClaudeProvider
|
||||||
|
|
@ -0,0 +1,33 @@
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class Provider(BaseModel):
|
||||||
|
provider_title: str
|
||||||
|
provider_url: str
|
||||||
|
provider_apikey: str
|
||||||
|
provider_models: List[str]
|
||||||
|
provider_type: str
|
||||||
|
|
||||||
|
class ProviderArgs(BaseModel):
|
||||||
|
provider_title: str
|
||||||
|
provider_url: str
|
||||||
|
provider_apikey: str
|
||||||
|
|
||||||
|
class BaseProvider(ABC):
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def create_model(provider_args: ProviderArgs) -> Provider:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@abstractmethod
|
||||||
|
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -0,0 +1,44 @@
|
||||||
|
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||||
|
import httpx
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class ClaudeProvider(BaseProvider):
|
||||||
|
@staticmethod
|
||||||
|
async def create_model(provider_args: ProviderArgs) -> Provider:
|
||||||
|
provider_models: List[str] = await ClaudeProvider._load_models(provider_args)
|
||||||
|
provider: Provider = ClaudeProvider._return_provider(provider_args, provider_models)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
|
# Anthropic 官方需要 version 头
|
||||||
|
headers = {
|
||||||
|
"x-api-key": provider_args.provider_apikey,
|
||||||
|
"anthropic-version": "2023-06-01",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
# 如果是官方 API,通常使用 /v1/models (如果支持)
|
||||||
|
# 注意:很多时候 Anthropic 并不返回完整列表,如果请求失败,建议返回硬编码的常用模型
|
||||||
|
url = f"{provider_args.provider_url.rstrip('/')}/v1/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
if response.status_code == 200:
|
||||||
|
data = response.json()
|
||||||
|
model_ids = [m["id"] for m in data.get("data", [])]
|
||||||
|
return sorted(model_ids)
|
||||||
|
else:
|
||||||
|
# 如果官方列表接口不可用,fallback 到已知常用模型
|
||||||
|
return ["claude-3-5-sonnet-20240620", "claude-3-opus-20240229", "claude-3-haiku-20240307"]
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{provider_args.provider_title}] 获取 Claude 模型列表错误: {e}")
|
||||||
|
return ["claude-3-5-sonnet-20240620"]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||||
|
return Provider(provider_title=provider_args.provider_title,
|
||||||
|
provider_apikey=provider_args.provider_apikey,
|
||||||
|
provider_url=provider_args.provider_url,
|
||||||
|
provider_models=provider_models,
|
||||||
|
provider_type="claude")
|
||||||
|
|
@ -0,0 +1,45 @@
|
||||||
|
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||||
|
import httpx
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class GeminiProvider(BaseProvider):
|
||||||
|
@staticmethod
|
||||||
|
async def create_model(provider_args: ProviderArgs) -> Provider:
|
||||||
|
provider_models: List[str] = await GeminiProvider._load_models(provider_args)
|
||||||
|
provider: Provider = GeminiProvider._return_provider(provider_args, provider_models)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
|
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
|
||||||
|
headers = {
|
||||||
|
"x-goog-api-key": provider_args.provider_apikey,
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
# 官方路径通常是 v1beta/models
|
||||||
|
url = f"{provider_args.provider_url.rstrip('/')}/v1beta/models"
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"[{provider_args.provider_title}] 获取 Gemini 模型失败: {response.status_code}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
data = response.json()
|
||||||
|
# Gemini 返回的结构中模型 ID 通常带 "models/" 前缀
|
||||||
|
raw_models = data.get("models", [])
|
||||||
|
model_ids = [m["name"].split("/")[-1] for m in raw_models if
|
||||||
|
"generateContent" in m.get("supportedGenerationMethods", [])]
|
||||||
|
return sorted(list(set(model_ids)))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||||
|
return Provider(provider_title=provider_args.provider_title,
|
||||||
|
provider_apikey=provider_args.provider_apikey,
|
||||||
|
provider_url=provider_args.provider_url,
|
||||||
|
provider_models=provider_models,
|
||||||
|
provider_type="gemini")
|
||||||
|
|
@ -0,0 +1,42 @@
|
||||||
|
from pretor.adapter.model_adapter.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||||
|
import httpx
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseProvider):
|
||||||
|
@staticmethod
|
||||||
|
async def create_model(provider_args: ProviderArgs) -> Provider:
|
||||||
|
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||||
|
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||||
|
return []
|
||||||
|
data = response.json()
|
||||||
|
raw_models = data.get("data", [])
|
||||||
|
model_ids = [m["id"] for m in raw_models]
|
||||||
|
return sorted(model_ids)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
return []
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||||
|
return Provider(provider_title=provider_args.provider_title,
|
||||||
|
provider_apikey=provider_args.provider_apikey,
|
||||||
|
provider_url=provider_args.provider_url,
|
||||||
|
provider_models=provider_models,
|
||||||
|
provider_type="openai")
|
||||||
|
|
@ -0,0 +1,57 @@
|
||||||
|
from pretor.adapter.model_adapter._agent_factory import AgentFactory
|
||||||
|
from pretor.adapter.model_adapter.model_provider import Provider, ProviderArgs, OpenAIProvider,GeminiProvider, ClaudeProvider
|
||||||
|
from pydantic_ai import Agent
|
||||||
|
import httpx
|
||||||
|
from pretor.utils.error import ModelNotExistError, ProviderNotExistError
|
||||||
|
from loguru import logger
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
class ProviderManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._provider_mapper = {"openai": OpenAIProvider, "gemini": GeminiProvider, "claude": ClaudeProvider}
|
||||||
|
self._agent_factory = AgentFactory()
|
||||||
|
self.provider_register = {}
|
||||||
|
|
||||||
|
async def add_provider(self, provider_type: str, provider_title: str, provider_url: str, provider_apikey: str) -> None:
|
||||||
|
"""
|
||||||
|
add_provider方法,注册供应商适配器
|
||||||
|
:param provider_type: 注册商接口类型,目前只支持openai,gemini和claude接口
|
||||||
|
:param provider_title: 供应商名称,为供应商提供的别名
|
||||||
|
:param provider_url: 供应商url
|
||||||
|
:param provider_apikey: 供应商所需要的apikey
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, provider_url=provider_url, provider_apikey=provider_apikey)
|
||||||
|
try:
|
||||||
|
if provider_type not in self._provider_mapper.keys():
|
||||||
|
logger.warning(f"Provider type {provider_type} is not supported.")
|
||||||
|
return None
|
||||||
|
provider_class = self._provider_mapper.get(provider_type)
|
||||||
|
provider: Provider = await provider_class.create_model(provider_args)
|
||||||
|
self.provider_register[provider_title] = provider
|
||||||
|
logger.info(f"已添加适配器{provider_title}")
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
|
|
||||||
|
def create_agent(self, agent_name: str, system_prompt: str, provider_title: str, model_id: str) -> Agent:
|
||||||
|
"""
|
||||||
|
create_agent方法,将保存的适配器转化为agent对象并返回
|
||||||
|
:param agent_name: agent名字,代表实例化个体起的名字
|
||||||
|
:param system_prompt: 系统提示词,给llm的系统提示词
|
||||||
|
:param provider_title: 供应商名称
|
||||||
|
:param model_id: 模型Id,实例化agent所输入的model_id
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
if provider_title not in self.provider_register:
|
||||||
|
raise ProviderNotExistError("提供商不存在")
|
||||||
|
provider = self.provider_register[provider_title]
|
||||||
|
if model_id not in provider.provider_models:
|
||||||
|
raise ModelNotExistError("模型不存在")
|
||||||
|
model = self._agent_factory.create_model(provider.provider_type, provider.provider_apikey, provider.provider_url, model_id)
|
||||||
|
agent = Agent(model=model,name=agent_name,system_prompt=system_prompt)
|
||||||
|
return agent
|
||||||
|
|
||||||
|
def get_provider_list(self) -> Dict[str, Provider]:
|
||||||
|
return self.provider_register
|
||||||
|
|
@ -1,76 +0,0 @@
|
||||||
import httpx
|
|
||||||
import json
|
|
||||||
from typing import List, Dict, Any, AsyncGenerator
|
|
||||||
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
|
|
||||||
|
|
||||||
class GeminiAdapter(ModelBase):
|
|
||||||
def __init__(self, base_url: str, adapter_title: str, api_key: str):
|
|
||||||
self.adapter_title: str = adapter_title
|
|
||||||
self.base_url = base_url.rstrip('/')
|
|
||||||
if not self.base_url.endswith('/v1'):
|
|
||||||
self.base_url += '/v1'
|
|
||||||
self.api_key = api_key
|
|
||||||
self.model_list = []
|
|
||||||
|
|
||||||
async def get_model(self) -> List[str]:
|
|
||||||
url = f"{self.base_url}/models"
|
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
||||||
|
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
self.model_list = [m.get("id", "") for m in data.get("data", [])]
|
|
||||||
return self.model_list
|
|
||||||
|
|
||||||
async def post_message(
|
|
||||||
self,
|
|
||||||
model: str,
|
|
||||||
messages: List[Dict[str, str]],
|
|
||||||
stream: bool = False,
|
|
||||||
temperature: float = 0.7,
|
|
||||||
max_tokens: int = 4096,
|
|
||||||
**kwargs
|
|
||||||
) -> Any:
|
|
||||||
url = f"{self.base_url}/chat/completions"
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}",
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": stream,
|
|
||||||
"temperature": temperature,
|
|
||||||
"max_tokens": max_tokens,
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
|
|
||||||
# 144GB 显存或云端长文本建议设置较长超时
|
|
||||||
timeout = httpx.Timeout(120.0, connect=10.0)
|
|
||||||
|
|
||||||
if not stream:
|
|
||||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
else:
|
|
||||||
return self._handle_stream(url, headers, payload)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
async def _handle_stream(self, url: str, headers: dict, payload: dict) -> AsyncGenerator[str, None]:
|
|
||||||
async with httpx.AsyncClient(timeout=None) as client:
|
|
||||||
async with client.stream("POST", url, headers=headers, json=payload) as response:
|
|
||||||
response.raise_for_status()
|
|
||||||
async for line in response.aiter_lines():
|
|
||||||
if not line.strip() or line == "data: [DONE]":
|
|
||||||
continue
|
|
||||||
if line.startswith("data: "):
|
|
||||||
try:
|
|
||||||
chunk = json.loads(line[6:])
|
|
||||||
delta = chunk["choices"][0]["delta"].get("content", "")
|
|
||||||
if delta:
|
|
||||||
yield delta
|
|
||||||
except (json.JSONDecodeError, KeyError):
|
|
||||||
continue
|
|
||||||
|
|
@ -1,10 +0,0 @@
|
||||||
from abc import ABC,abstractmethod
|
|
||||||
|
|
||||||
class ModelBase(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
async def get_model(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def post_message(self, model: str, messages: list, stream: bool = False, **kwargs):
|
|
||||||
pass
|
|
||||||
|
|
@ -1,39 +0,0 @@
|
||||||
import httpx
|
|
||||||
from pretor.protocol_plugin.model_protocol.modelbase import ModelBase
|
|
||||||
|
|
||||||
class OpenAIAdapter(ModelBase):
|
|
||||||
def __init__(self, base_url: str, adapter_title: str, api_key: str = "archon-local"):
|
|
||||||
self.adapter_title: str = adapter_title
|
|
||||||
self.base_url = base_url.rstrip('/')
|
|
||||||
if not self.base_url.endswith('/v1'):
|
|
||||||
self.base_url += '/v1'
|
|
||||||
self.api_key = api_key
|
|
||||||
self.model_list = []
|
|
||||||
|
|
||||||
async def get_model(self):
|
|
||||||
url = "{}/models".format(self.base_url)
|
|
||||||
headers = {
|
|
||||||
"Authorization": f"Bearer {self.api_key}"
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
response.raise_for_status()
|
|
||||||
response = response.json()
|
|
||||||
self.model_list = [m.get("id", "") for m in response.get("data", [])]
|
|
||||||
return self.model_list
|
|
||||||
|
|
||||||
async def post_message(self,model: str, messages: list, stream: bool = False, **kwargs):
|
|
||||||
url = f"{self.base_url}/chat/completions"
|
|
||||||
headers = {"Authorization": f"Bearer {self.api_key}"}
|
|
||||||
payload = {
|
|
||||||
"model": model,
|
|
||||||
"messages": messages,
|
|
||||||
"stream": stream,
|
|
||||||
**kwargs
|
|
||||||
}
|
|
||||||
async with httpx.AsyncClient(timeout=None) as client:
|
|
||||||
if not stream:
|
|
||||||
response = await client.post(url, headers=headers, json=payload)
|
|
||||||
return response.json()
|
|
||||||
else:
|
|
||||||
return client.stream("POST", url, headers=headers, json=payload)
|
|
||||||
|
|
@ -0,0 +1,23 @@
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from loguru import logger
|
||||||
|
|
||||||
|
def database_exception(func):
|
||||||
|
async def wrapper(*args, **kwargs):
|
||||||
|
try:
|
||||||
|
return await func(*args, **kwargs)
|
||||||
|
except ValidationError as e:
|
||||||
|
logger.error(f"对象校验失败:{e}")
|
||||||
|
raise e
|
||||||
|
except IntegrityError as e:
|
||||||
|
logger.error(f"数据库完整性错误 (如重复记录): {e}")
|
||||||
|
raise e
|
||||||
|
except OperationalError as e:
|
||||||
|
logger.error(f"数据库连接异常: {e}")
|
||||||
|
raise e
|
||||||
|
except UserNotExistError as e:
|
||||||
|
logger.error(f"更改密码失败,用户不存在:{e}")
|
||||||
|
except Exception as e:
|
||||||
|
logger.exception(f"未预期的数据库错误: {e}")
|
||||||
|
raise e
|
||||||
|
return wrapper
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
import ray
|
||||||
|
from pretor.core.database.table import User
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
|
||||||
|
from sqlalchemy.orm import sessionmaker
|
||||||
|
from sqlmodel import SQLModel, select
|
||||||
|
from pretor.utils.error import UserNotExistError, UserPasswordError
|
||||||
|
import os
|
||||||
|
from pretor.core.database.database_exception import database_exception
|
||||||
|
|
||||||
|
@ray.remote
|
||||||
|
class PostgresDatabase:
|
||||||
|
def __init__(self):
|
||||||
|
user = os.environ.get('POSTGRES_USER')
|
||||||
|
password = os.environ.get('POSTGRES_PASSWORD')
|
||||||
|
host = os.environ.get('POSTGRES_HOST')
|
||||||
|
port = os.environ.get('POSTGRES_PORT')
|
||||||
|
database = os.environ.get('POSTGRES_DB')
|
||||||
|
database_url = f"postgresql+asyncpg://{user}:{password}@{host}:{port}/{database}"
|
||||||
|
self.async_engine = create_async_engine(database_url, echo=True)
|
||||||
|
self.async_session_maker = sessionmaker(self.async_engine, class_=AsyncSession, expire_on_commit=False)
|
||||||
|
|
||||||
|
async def init_db(self) -> None:
|
||||||
|
async with self.async_engine.begin() as conn:
|
||||||
|
await conn.run_sync(SQLModel.metadata.create_all)
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def add_user(self, user_name: str, hashed_password: str) -> User:
|
||||||
|
user = User(user_name=user_name, hashed_password=hashed_password)
|
||||||
|
async with self.async_session_maker as session:
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def change_password(self, user_name, old_password, new_password) -> User:
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
statement = select(User).where(User.user_name == user_name)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
user = results.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise UserNotExistError()
|
||||||
|
if old_password != user.hashed_password:
|
||||||
|
raise UserPasswordError()
|
||||||
|
user.hashed_password = new_password
|
||||||
|
session.add(user)
|
||||||
|
await session.commit()
|
||||||
|
await session.refresh(user)
|
||||||
|
return user
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def delete_user(self, user_name: str) -> None:
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
statement = select(User).where(User.user_name == user_name)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
user = results.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise UserNotExistError()
|
||||||
|
session.delete(user)
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def get_user_password(self, user_name: str) -> str:
|
||||||
|
async with self.async_session_maker() as session:
|
||||||
|
statement = select(User).where(User.user_name == user_name)
|
||||||
|
results = await session.exec(statement)
|
||||||
|
user = results.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise UserNotExistError()
|
||||||
|
return user.hashed_password
|
||||||
|
|
||||||
|
|
@ -0,0 +1 @@
|
||||||
|
from pretor.core.database.table.user import User
|
||||||
|
|
@ -0,0 +1,8 @@
|
||||||
|
from sqlmodel import SQLModel, Field, Column, String
|
||||||
|
|
||||||
|
|
||||||
|
class User(SQLModel):
|
||||||
|
__tablename__ = 'user'
|
||||||
|
user_id: int = Field(default=None, primary_key=True)
|
||||||
|
user_name: str = Field(index=True)
|
||||||
|
hashed_password: str
|
||||||
|
|
@ -8,7 +8,7 @@ class WorkflowManager:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.workflow_template_generator = WorkflowTemplateGenerator()
|
self.workflow_template_generator = WorkflowTemplateGenerator()
|
||||||
self.workflow_templates_registry = {}
|
self.workflow_templates_registry = {}
|
||||||
self.template_path = Path("pretor/workflow_plugin")
|
self.template_path = Path("pretor/workflow_template")
|
||||||
self._load_workflow_template()
|
self._load_workflow_template()
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,7 @@ class WorkflowTemplateGenerator:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def generate_workflow_template(name: str, desc: str, steps: list) -> None:
|
def generate_workflow_template(name: str, desc: str, steps: list) -> None:
|
||||||
workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps)
|
workflow_template = WorkflowTemplate(name=name, desc=desc, work_link=steps)
|
||||||
output_dir = Path("pretor.workflow_plugin")
|
output_dir = Path("pretor.workflow_template")
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir(parents=True)
|
output_dir.mkdir(parents=True)
|
||||||
output_file = output_dir / f"{name}_workflow_template.json"
|
output_file = output_dir / f"{name}_workflow_template.json"
|
||||||
|
|
|
||||||
|
|
@ -1,26 +1,5 @@
|
||||||
from pretor.core.protocol.runnable_object import RunnableObject
|
from pydantic_ai import Agent
|
||||||
from pretor.core.workflow_manager.workflow import PretorWorkflow
|
|
||||||
from pretor.adapter_plugin.model_adapter.modelbase import ModelBase
|
|
||||||
import ray
|
|
||||||
from typing import Any,Dict
|
|
||||||
from pretor.individual_plugin.control_node.control_register import ControlRegister
|
|
||||||
from pretor.utils.inspector import inspector
|
|
||||||
|
|
||||||
#control_node 管控节点,掌管系统的全局状态
|
|
||||||
@ray.remote
|
|
||||||
class ControlNode(RunnableObject):
|
|
||||||
def __init__(self, **kwargs: Dict[str: Any]) -> None:
|
|
||||||
self.model_adapter : ModelBase = kwargs.get("model_adapter")
|
|
||||||
self.model : str = kwargs.get("model")
|
|
||||||
self.name : str = kwargs.get("name", "管控节点")
|
|
||||||
self.control_register = ControlRegister()
|
|
||||||
|
|
||||||
def _load_control_register(self) :
|
|
||||||
pass
|
|
||||||
|
|
||||||
@inspector("individual","control_node")
|
|
||||||
async def run(self, workflow : PretorWorkflow) -> None:
|
|
||||||
control_register = self.control_register.model_dump_json()
|
|
||||||
demand = workflow.status.content.demand.model_dump_json()
|
|
||||||
|
|
||||||
|
|
||||||
|
class ControlNode:
|
||||||
|
def __init__(self):
|
||||||
|
pass
|
||||||
|
|
@ -1,22 +0,0 @@
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Dict, Any, Literal, Union, Optional
|
|
||||||
|
|
||||||
class SystemItem(BaseModel):
|
|
||||||
command_template: str = Field(..., description="底层 shell 命令模板")
|
|
||||||
args_schema: Dict[str, Any] = Field(default_factory=dict, description="该指令接受的参数约束")
|
|
||||||
|
|
||||||
class IndividualItem(BaseModel):
|
|
||||||
description: str
|
|
||||||
params: Dict[str: str]
|
|
||||||
base_prompt: str = Field(..., description="个体的基础人格/背景设定")
|
|
||||||
|
|
||||||
class ToolItem(BaseModel):
|
|
||||||
description: str
|
|
||||||
plugin_path: str = Field(..., description="插件物理路径或类路径")
|
|
||||||
|
|
||||||
class ControlRegister(BaseModel):
|
|
||||||
# 统一使用 Dict,方便通过 name 快速索引:{ "name": ItemObject }
|
|
||||||
system_registry: Dict[str, SystemItem] = Field(default_factory=dict)
|
|
||||||
individual_registry: Dict[str, IndividualItem] = Field(default_factory=dict)
|
|
||||||
tool_registry: Dict[str, ToolItem] = Field(default_factory=dict)
|
|
||||||
global_information : Dict[str, str] = Field(default_factory=dict)
|
|
||||||
|
|
@ -1,2 +1,17 @@
|
||||||
class DemandError(Exception):
|
class DemandError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ModelNotExistError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class ProviderNotExistError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserError(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserNotExistError(UserError):
|
||||||
|
pass
|
||||||
|
|
||||||
|
class UserPasswordError(UserError):
|
||||||
pass
|
pass
|
||||||
|
|
@ -5,10 +5,13 @@ description = "Add your description here"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.13"
|
requires-python = ">=3.13"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"asyncpg>=0.31.0",
|
||||||
"docker-py>=1.10.6",
|
"docker-py>=1.10.6",
|
||||||
"httpx>=0.28.1",
|
"httpx>=0.28.1",
|
||||||
"jinja2>=3.1.6",
|
"jinja2>=3.1.6",
|
||||||
"loguru>=0.7.3",
|
"loguru>=0.7.3",
|
||||||
|
"pydantic-ai>=1.73.0",
|
||||||
"python-ulid>=3.1.0",
|
"python-ulid>=3.1.0",
|
||||||
"ray[defaule,serve]>=2.54.0",
|
"ray[default,serve]>=2.54.0",
|
||||||
|
"sqlmodel>=0.0.37",
|
||||||
]
|
]
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue