Files
KiloStar/kilostar/adapter/model_adapter/agent_factory.py
T

122 lines
4.4 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional

from pydantic_ai import Agent
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.models.gemini import GeminiModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.providers.anthropic import AnthropicProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.providers.openai import OpenAIProvider

from kilostar.core.global_state_machine.model_provider import Provider
from kilostar.utils.agent_model import ResponseModel, DepsModel
from kilostar.utils.error import ModelNotExistError


class AgentFactory:
    """Model factory: translate internal ``Provider`` metadata into a pydantic-ai ``Agent``.

    Supports four backends: openai / claude / deepseek / gemini. The
    differences between them are hidden behind the ``model_class`` +
    ``provider_class`` (+ ``provider_kwargs``) entries of ``_models_mapping``.
    """

    # Provider-constructor keyword -> attribute of the internal ``Provider``
    # that supplies its value. A backend's ``provider_kwargs`` flags select
    # which of these keywords are actually passed.
    _PROVIDER_FIELDS = {"api_key": "provider_apikey", "base_url": "provider_url"}

    def __init__(self):
        # provider_type -> how to build the pydantic-ai model for that backend.
        self._models_mapping = {
            "openai": {
                "model_class": OpenAIChatModel,
                "provider_class": OpenAIProvider,
                "provider_kwargs": {"base_url": True, "api_key": True},
            },
            "claude": {
                "model_class": AnthropicModel,
                "provider_class": AnthropicProvider,
                "provider_kwargs": {"api_key": True},
            },
            "deepseek": {
                # DeepSeek is driven through the OpenAI-compatible chat model.
                "model_class": OpenAIChatModel,
                "provider_class": DeepSeekProvider,
                "provider_kwargs": {"api_key": True},
            },
            "gemini": {
                # GoogleProvider wraps the google-genai client, which is what
                # GoogleModel consumes. GeminiModel instead expects an
                # httpx-based provider (GoogleGLAProvider/GoogleVertexProvider),
                # so pairing it with GoogleProvider does not work.
                "model_class": GoogleModel,
                "provider_class": GoogleProvider,
                "provider_kwargs": {"api_key": True},
            },
        }

    @classmethod
    def _build_provider_kwargs(cls, provider, flags):
        """Return the provider-constructor kwargs whose flag is truthy in ``flags``."""
        return {
            kwarg: getattr(provider, attr)
            for kwarg, attr in cls._PROVIDER_FIELDS.items()
            if flags.get(kwarg)
        }

    def create_agent(
        self,
        provider: Provider,
        model_id: str,
        output_type: ResponseModel,
        system_prompt: str,
        deps_type: DepsModel,
        agent_name: str,
        tools: Optional[list] = None,
    ) -> Agent:
        """Instantiate a pydantic-ai ``Agent`` from a ``Provider`` object.

        Args:
            provider: Provider object obtained from global_state_machine.
            model_id: Model name; must be listed in ``provider.provider_models``.
            output_type: Output format of the agent.
            system_prompt: System prompt.
            deps_type: Dependency type, i.e. the formatted message supplied
                dynamically when the agent runs.
            agent_name: Name of the agent.
            tools: Tool list; ``None`` (the default) means no tools.

        Returns:
            The instantiated pydantic-ai ``Agent``.

        Raises:
            ModelNotExistError: ``model_id`` is not in ``provider.provider_models``.
            ValueError: ``provider.provider_type`` is not a supported backend.
        """
        if model_id not in provider.provider_models:
            raise ModelNotExistError("模型不存在")
        config = self._models_mapping.get(provider.provider_type)
        if config is None:
            raise ValueError(f"不支持的协议类型: {provider.provider_type}")

        model_provider = config["provider_class"](
            **self._build_provider_kwargs(provider, config["provider_kwargs"])
        )
        # OpenAIChatModel / AnthropicModel / GoogleModel all take the model
        # name as their first positional argument, so no per-backend branch
        # is needed here.
        model = config["model_class"](model_id, provider=model_provider)

        return Agent(
            model=model,
            name=agent_name,
            system_prompt=system_prompt,
            output_type=output_type,
            deps_type=deps_type,
            # Agent iterates ``tools`` (its own default is ``()``), so a None
            # coming from our default must become an empty sequence.
            tools=tools if tools is not None else (),
        )