Pretor/pretor/adapter/model_adapter/provider_manager.py

57 lines
3.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from pretor.adapter.model_adapter._agent_factory import AgentFactory
from pretor.adapter.model_adapter.model_provider import Provider, ProviderArgs, OpenAIProvider,GeminiProvider, ClaudeProvider
from pydantic_ai import Agent
import httpx
from pretor.utils.error import ModelNotExistError, ProviderNotExistError
from loguru import logger
from typing import Dict
class ProviderManager:
    """Registry of model providers and factory for pydantic_ai agents built on them.

    Providers are registered under a caller-chosen title via ``add_provider`` and
    are later turned into ``Agent`` instances via ``create_agent``.
    """

    def __init__(self):
        # Maps each supported provider interface type to its adapter class.
        self._provider_mapper = {"openai": OpenAIProvider, "gemini": GeminiProvider, "claude": ClaudeProvider}
        self._agent_factory = AgentFactory()
        # Registered providers, keyed by provider title (alias).
        self.provider_register: Dict[str, Provider] = {}

    async def add_provider(self, provider_type: str, provider_title: str, provider_url: str, provider_apikey: str) -> None:
        """
        Register a provider adapter under ``provider_title``.

        Registration is best-effort. An unsupported ``provider_type`` is logged as a
        warning, and so is any exception raised while creating the provider. In both
        cases the provider is not registered and nothing is raised. Registering a
        title that already exists replaces the previous provider, and a warning is
        logged.

        :param provider_type: provider interface type; currently only "openai", "gemini" and "claude" are supported
        :param provider_title: alias chosen for this provider; used as the registry key
        :param provider_url: provider URL
        :param provider_apikey: API key required by the provider
        :return: None
        """
        # Built outside the try block: any exception raised here propagates to the caller.
        provider_args: ProviderArgs = ProviderArgs(
            provider_title=provider_title,
            provider_url=provider_url,
            provider_apikey=provider_apikey,
        )
        provider_class = self._provider_mapper.get(provider_type)
        if provider_class is None:
            logger.warning(f"Provider type {provider_type} is not supported.")
            return None
        # Only the provider creation (which may perform network I/O) is guarded.
        try:
            provider: Provider = await provider_class.create_model(provider_args)
        except httpx.RequestError as e:
            logger.warning(f"[{provider_args.provider_title}] 网络请求异常: {e}")
            return None
        except Exception as e:
            # Keep best-effort semantics, but attach the traceback so that
            # unexpected failures remain diagnosable from the logs.
            logger.opt(exception=e).warning(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
            return None
        if provider_title in self.provider_register:
            logger.warning(f"Provider {provider_title} is already registered and will be replaced.")
        self.provider_register[provider_title] = provider
        logger.info(f"已添加适配器{provider_title}")

    def create_agent(self, agent_name: str, system_prompt: str, provider_title: str, model_id: str) -> Agent:
        """
        Build a pydantic_ai ``Agent`` from a registered provider.

        :param agent_name: name given to the agent instance
        :param system_prompt: system prompt passed to the LLM
        :param provider_title: title of a provider previously registered via ``add_provider``
        :param model_id: id of the model to use; must be listed in the provider's ``provider_models``
        :return: a new ``Agent`` wrapping the model produced by the agent factory
        :raises ProviderNotExistError: if no provider is registered under ``provider_title``
        :raises ModelNotExistError: if ``model_id`` is not in the provider's ``provider_models``
        """
        if provider_title not in self.provider_register:
            raise ProviderNotExistError("提供商不存在")
        provider = self.provider_register[provider_title]
        if model_id not in provider.provider_models:
            raise ModelNotExistError("模型不存在")
        model = self._agent_factory.create_model(
            provider.provider_type, provider.provider_apikey, provider.provider_url, model_id
        )
        return Agent(model=model, name=agent_name, system_prompt=system_prompt)

    def get_provider_list(self) -> Dict[str, Provider]:
        """
        Return the registry that maps provider titles to registered providers.

        NOTE: this returns the internal dict itself, not a copy, so mutating
        the result mutates the registry.

        :return: dict of provider title -> provider
        """
        return self.provider_register