feat: 增加deepseek

This commit is contained in:
朝夕 2026-04-28 05:51:35 +08:00
parent 99bd6f65d7
commit 6ff59520e6
4 changed files with 85 additions and 11 deletions

View File

@ -19,13 +19,17 @@ from pydantic_ai.models.anthropic import AnthropicModel
from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.google import GoogleProvider from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.providers.anthropic import AnthropicProvider from pydantic_ai.providers.anthropic import AnthropicProvider
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
from pretor.core.global_state_machine.model_provider import Provider from pretor.core.global_state_machine.model_provider import Provider
from pretor.utils.agent_model import ResponseModel, DepsModel from pretor.utils.agent_model import ResponseModel, DepsModel
from pretor.utils.error import ModelNotExistError from pretor.utils.error import ModelNotExistError
class AgentFactory: class AgentFactory:
def __init__(self): def __init__(self):
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)} self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
"gemini": (GoogleModel, GoogleProvider),
"claude": (AnthropicModel, AnthropicProvider),
"deepseek": (OpenAIChatModel, OpenAIProvider),}
def create_agent(self, def create_agent(self,
provider: Provider, provider: Provider,
@ -45,6 +49,7 @@ class AgentFactory:
system_prompt: 系统提示词 system_prompt: 系统提示词
deps_type: 依赖类型在agent运行时动态输入的格式化消息 deps_type: 依赖类型在agent运行时动态输入的格式化消息
agent_name: agent的名字 agent_name: agent的名字
tools: 工具列表
Returns: Returns:
返回被实例化的pydantic-ai的Agent对象 返回被实例化的pydantic-ai的Agent对象
@ -55,6 +60,15 @@ class AgentFactory:
raise ValueError(f"不支持的协议类型: {provider.provider_type}") raise ValueError(f"不支持的协议类型: {provider.provider_type}")
model_class, provider_class = self._models_mapping[provider.provider_type] model_class, provider_class = self._models_mapping[provider.provider_type]
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url)) model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
match provider.provider_type:
case "deepseek":
agent = DeepSeekReasonerAgent(model=model,
name=agent_name,
output_type=output_type,
deps_type=deps_type,
system_prompt=system_prompt,
)
case _:
agent = Agent(model=model, agent = Agent(model=model,
name=agent_name, name=agent_name,
system_prompt=system_prompt, system_prompt=system_prompt,

View File

@ -3,7 +3,6 @@ import os
from typing import Type, TypeVar, Any, Generic from typing import Type, TypeVar, Any, Generic
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent, RunContext from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
T = TypeVar('T', bound=BaseModel) T = TypeVar('T', bound=BaseModel)
@ -16,7 +15,8 @@ class DeepSeekReasonerAgent(Generic[T]):
def __init__( def __init__(
self, self,
model_id: str = "deepseek-v4-pro", model,
name,
output_type: Type[T] = str, output_type: Type[T] = str,
system_prompt: str = "", system_prompt: str = "",
deps_type: Type[Any] = None deps_type: Type[Any] = None
@ -32,7 +32,8 @@ class DeepSeekReasonerAgent(Generic[T]):
) )
self.agent = Agent( self.agent = Agent(
model=OpenAIChatModel(model_id), model=model,
name=name,
output_type=str, # 内部通信用字符串 output_type=str, # 内部通信用字符串
system_prompt=system_prompt + format_instruction, system_prompt=system_prompt + format_instruction,
deps_type=deps_type, deps_type=deps_type,

View File

@ -0,0 +1,59 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx
from typing import List
class OpenAIProvider(BaseProvider):
@staticmethod
async def create_provider(provider_args: ProviderArgs) -> Provider:
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
return provider
@staticmethod
@retry_on_retryable_error()
async def _load_models(provider_args: ProviderArgs) -> List[str]:
headers = {
"Authorization": f"Bearer {provider_args.provider_apikey}",
"Content-Type": "application/json"
}
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
try:
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(url, headers=headers)
if response.status_code != 200:
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
return []
data = response.json()
raw_models = data.get("data", [])
model_ids = [m["id"] for m in raw_models]
return sorted(model_ids)
except httpx.RequestError as e:
from pretor.utils.error import RetryableError
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
except Exception as e:
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
return []
@staticmethod
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
return Provider(provider_title=provider_args.provider_title,
provider_apikey=provider_args.provider_apikey,
provider_url=provider_args.provider_url,
provider_models=provider_models,
provider_type="deepseek")

View File

@ -1,4 +1,3 @@
from pretor.utils.retry import retry_on_retryable_error
# Copyright 2026 zhaoxi826 # Copyright 2026 zhaoxi826
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
@ -13,6 +12,7 @@ from pretor.utils.retry import retry_on_retryable_error
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from pretor.utils.retry import retry_on_retryable_error
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
import httpx import httpx
from typing import List from typing import List