Pretor/pretor/adapter/model_adapter/agent_factory.py

78 lines
3.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.google import GoogleModel
from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.providers.anthropic import AnthropicProvider
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
from pretor.core.global_state_machine.model_provider import Provider
from pretor.utils.agent_model import ResponseModel, DepsModel
from pretor.utils.error import ModelNotExistError
class AgentFactory:
def __init__(self):
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
"gemini": (GoogleModel, GoogleProvider),
"claude": (AnthropicModel, AnthropicProvider),
"deepseek": (OpenAIChatModel, OpenAIProvider),}
def create_agent(self,
provider: Provider,
model_id: str,
output_type: ResponseModel,
system_prompt: str,
deps_type: DepsModel,
agent_name: str,
tools: list = None) -> Agent:
"""
create_agent方法将输入的provider对象实例化为一个pydantic-ai的agent对象
Args:
provider: Provider对象从global_state_machine中获取
model_id: 模型名
output_type: 输出格式
system_prompt: 系统提示词
deps_type: 依赖类型在agent运行时动态输入的格式化消息
agent_name: agent的名字
tools: 工具列表
Returns:
返回被实例化的pydantic-ai的Agent对象
"""
if model_id not in provider.provider_models:
raise ModelNotExistError("模型不存在")
if provider.provider_type not in self._models_mapping:
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
model_class, provider_class = self._models_mapping[provider.provider_type]
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
match provider.provider_type:
case "deepseek":
agent = DeepSeekReasonerAgent(model=model,
name=agent_name,
output_type=output_type,
deps_type=deps_type,
system_prompt=system_prompt,
)
case _:
agent = Agent(model=model,
name=agent_name,
system_prompt=system_prompt,
output_type=output_type,
deps_type=deps_type,
tools=tools)
return agent