57 lines
3.0 KiB
Python
57 lines
3.0 KiB
Python
from pretor.adapter.model_adapter._agent_factory import AgentFactory
|
||
from pretor.adapter.model_adapter.model_provider import Provider, ProviderArgs, OpenAIProvider,GeminiProvider, ClaudeProvider
|
||
from pydantic_ai import Agent
|
||
import httpx
|
||
from pretor.utils.error import ModelNotExistError, ProviderNotExistError
|
||
from loguru import logger
|
||
from typing import Dict
|
||
|
||
class ProviderManager:
    """Registry of model providers and factory for pydantic-ai ``Agent`` objects.

    Providers are registered under a caller-chosen title with ``add_provider``
    and later turned into agents with ``create_agent``.
    """

    def __init__(self):
        # Supported provider interface types -> adapter class used to build them.
        self._provider_mapper = {"openai": OpenAIProvider, "gemini": GeminiProvider, "claude": ClaudeProvider}
        self._agent_factory = AgentFactory()
        # Registered providers, keyed by provider_title.
        self.provider_register: Dict[str, Provider] = {}

    async def add_provider(self, provider_type: str, provider_title: str, provider_url: str, provider_apikey: str) -> None:
        """Register a provider adapter under ``provider_title``.

        Failures are logged, not raised: an unsupported ``provider_type``, an
        ``httpx.RequestError``, or any other exception from ``create_model``
        leaves the registry unchanged. Registering a title that is already in
        use replaces the previous provider (a warning is logged).

        :param provider_type: provider interface type; currently only "openai", "gemini" and "claude" are supported
        :param provider_title: alias under which the provider is registered
        :param provider_url: provider URL
        :param provider_apikey: API key required by the provider
        :return: None
        """
        # Built outside the try block, so an error constructing ProviderArgs propagates to the caller.
        provider_args: ProviderArgs = ProviderArgs(provider_title=provider_title, provider_url=provider_url, provider_apikey=provider_apikey)

        provider_class = self._provider_mapper.get(provider_type)
        if provider_class is None:
            logger.warning(f"Provider type {provider_type} is not supported.")
            return None

        # Keep the try body to the one call that can fail (it may perform network I/O).
        try:
            provider: Provider = await provider_class.create_model(provider_args)
        except httpx.RequestError as e:
            logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
            return None
        except Exception as e:
            # Registration is deliberately best-effort, but attach the traceback
            # so an unexpected failure is diagnosable from the logs.
            logger.opt(exception=e).warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
            return None

        if provider_title in self.provider_register:
            logger.warning(f"Provider {provider_title} is already registered; replacing it.")
        self.provider_register[provider_title] = provider
        logger.info(f"已添加适配器{provider_title}")
        return None

    def create_agent(self, agent_name: str, system_prompt: str, provider_title: str, model_id: str) -> Agent:
        """Build a pydantic-ai ``Agent`` from a registered provider.

        :param agent_name: name given to the agent instance
        :param system_prompt: system prompt passed to the LLM
        :param provider_title: title the provider was registered under
        :param model_id: model id; must be listed in the provider's ``provider_models``
        :return: the newly constructed ``Agent``
        :raises ProviderNotExistError: if no provider is registered under ``provider_title``
        :raises ModelNotExistError: if ``model_id`` is not in the provider's ``provider_models``
        """
        try:
            provider = self.provider_register[provider_title]
        except KeyError:
            raise ProviderNotExistError("提供商不存在") from None

        if model_id not in provider.provider_models:
            raise ModelNotExistError("模型不存在")

        model = self._agent_factory.create_model(provider.provider_type, provider.provider_apikey, provider.provider_url, model_id)
        return Agent(model=model, name=agent_name, system_prompt=system_prompt)

    def get_provider_list(self) -> Dict[str, Provider]:
        """Return the registered providers keyed by provider title.

        This is the live registry dict, not a copy: mutating it changes this
        manager's registered providers.
        """
        return self.provider_register