feat(provider):增加google,anthropic供应商
1.增加更多的模型供应商
This commit is contained in:
@@ -15,9 +15,11 @@
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.models.anthropic import AnthropicModel
|
||||
from pydantic_ai.models.gemini import GeminiModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from kilostar.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||
from pydantic_ai.providers.deepseek import DeepSeekProvider
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from kilostar.core.global_state_machine.model_provider import Provider
|
||||
from kilostar.utils.agent_model import ResponseModel, DepsModel
|
||||
from kilostar.utils.error import ModelNotExistError
|
||||
@@ -29,9 +31,26 @@ class AgentFactory:
|
||||
|
||||
def __init__(self):
|
||||
self._models_mapping = {
|
||||
"openai": (OpenAIChatModel, OpenAIProvider),
|
||||
"claude": (AnthropicModel, AnthropicProvider),
|
||||
"deepseek": (OpenAIChatModel, OpenAIProvider),
|
||||
"openai": {
|
||||
"model_class": OpenAIChatModel,
|
||||
"provider_class": OpenAIProvider,
|
||||
"provider_kwargs": {"base_url": True, "api_key": True},
|
||||
},
|
||||
"claude": {
|
||||
"model_class": AnthropicModel,
|
||||
"provider_class": AnthropicProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
"deepseek": {
|
||||
"model_class": OpenAIChatModel,
|
||||
"provider_class": DeepSeekProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
"gemini": {
|
||||
"model_class": GeminiModel,
|
||||
"provider_class": GoogleProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
}
|
||||
|
||||
def create_agent(
|
||||
@@ -63,31 +82,37 @@ class AgentFactory:
|
||||
raise ModelNotExistError("模型不存在")
|
||||
if provider.provider_type not in self._models_mapping:
|
||||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||
model = model_class(
|
||||
model_id,
|
||||
provider=provider_class(
|
||||
api_key=provider.provider_apikey, base_url=provider.provider_url
|
||||
),
|
||||
|
||||
config = self._models_mapping[provider.provider_type]
|
||||
model_class = config["model_class"]
|
||||
provider_class = config["provider_class"]
|
||||
provider_kwargs = config["provider_kwargs"]
|
||||
|
||||
# 构建 provider 实例化参数
|
||||
init_kwargs = {}
|
||||
if provider_kwargs.get("api_key"):
|
||||
init_kwargs["api_key"] = provider.provider_apikey
|
||||
if provider_kwargs.get("base_url"):
|
||||
init_kwargs["base_url"] = provider.provider_url
|
||||
|
||||
model_provider = provider_class(**init_kwargs)
|
||||
|
||||
# 对于 Gemini,provider 需要传递给 model
|
||||
if provider.provider_type == "gemini":
|
||||
model = model_class(
|
||||
model_name=model_id,
|
||||
provider=model_provider,
|
||||
)
|
||||
else:
|
||||
model = model_class(model_id, provider=model_provider)
|
||||
|
||||
# 创建 Agent
|
||||
agent = Agent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools,
|
||||
)
|
||||
match provider.provider_type:
|
||||
case "deepseek":
|
||||
agent = DeepSeekReasonerAgent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
retries=3,
|
||||
)
|
||||
case _:
|
||||
agent = Agent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools,
|
||||
)
|
||||
return agent
|
||||
|
||||
Reference in New Issue
Block a user