From 6ff59520e61a1771856664cb659c982b73b6660d Mon Sep 17 00:00:00 2001 From: zhaoxi Date: Tue, 28 Apr 2026 05:51:35 +0800 Subject: [PATCH] =?UTF-8?q?feat:=20=E5=A2=9E=E5=8A=A0deepseek?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pretor/adapter/model_adapter/agent_factory.py | 28 ++++++--- .../model_adapter/deepseek_reasoner.py | 7 ++- .../model_provider/deepseek_provider.py | 59 +++++++++++++++++++ .../model_provider/openai_provider.py | 2 +- 4 files changed, 85 insertions(+), 11 deletions(-) create mode 100644 pretor/core/global_state_machine/model_provider/deepseek_provider.py diff --git a/pretor/adapter/model_adapter/agent_factory.py b/pretor/adapter/model_adapter/agent_factory.py index 3cffd6a..328c854 100644 --- a/pretor/adapter/model_adapter/agent_factory.py +++ b/pretor/adapter/model_adapter/agent_factory.py @@ -19,13 +19,17 @@ from pydantic_ai.models.anthropic import AnthropicModel from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.providers.google import GoogleProvider from pydantic_ai.providers.anthropic import AnthropicProvider +from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent from pretor.core.global_state_machine.model_provider import Provider from pretor.utils.agent_model import ResponseModel, DepsModel from pretor.utils.error import ModelNotExistError class AgentFactory: def __init__(self): - self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)} + self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), + "gemini": (GoogleModel, GoogleProvider), + "claude": (AnthropicModel, AnthropicProvider), + "deepseek": (OpenAIChatModel, OpenAIProvider),} def create_agent(self, provider: Provider, @@ -45,6 +49,7 @@ class AgentFactory: system_prompt: 系统提示词 deps_type: 依赖类型,在agent运行时动态输入的格式化消息 agent_name: agent的名字 + tools: 工具列表 Returns: 返回被实例化的pydantic-ai的Agent对象 @@ -55,10 +60,19 @@ class AgentFactory: raise ValueError(f"不支持的协议类型: {provider.provider_type}") model_class, provider_class = self._models_mapping[provider.provider_type] model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url)) - agent = Agent(model=model, - name=agent_name, - system_prompt=system_prompt, - output_type=output_type, - deps_type=deps_type, - tools=tools) + match provider.provider_type: + case "deepseek": + agent = DeepSeekReasonerAgent(model=model, + name=agent_name, + output_type=output_type, + deps_type=deps_type, + system_prompt=system_prompt, + ) + case _: + agent = Agent(model=model, + name=agent_name, + system_prompt=system_prompt, + output_type=output_type, + deps_type=deps_type, + tools=tools) return agent \ No newline at end of file diff --git a/pretor/adapter/model_adapter/deepseek_reasoner.py b/pretor/adapter/model_adapter/deepseek_reasoner.py index e25a320..a20ec60 100644 --- a/pretor/adapter/model_adapter/deepseek_reasoner.py +++ b/pretor/adapter/model_adapter/deepseek_reasoner.py @@ -3,7 +3,6 @@ import os from typing import Type, TypeVar, Any, Generic from pydantic import BaseModel, ValidationError from pydantic_ai import Agent, RunContext -from pydantic_ai.models.openai import OpenAIChatModel T = TypeVar('T', bound=BaseModel) @@ -16,7 +15,8 @@ class DeepSeekReasonerAgent(Generic[T]): def __init__( self, - model_id: str = "deepseek-v4-pro", + model, + name, output_type: Type[T] = str, system_prompt: str = "", deps_type: Type[Any] = None @@ -32,7 +32,8 @@ class DeepSeekReasonerAgent(Generic[T]): ) self.agent = Agent( - model=OpenAIChatModel(model_id), + model=model, + name=name, output_type=str, # 内部通信用字符串 system_prompt=system_prompt + format_instruction, deps_type=deps_type, diff --git a/pretor/core/global_state_machine/model_provider/deepseek_provider.py b/pretor/core/global_state_machine/model_provider/deepseek_provider.py new file mode 100644 index 0000000..032c6d9 --- /dev/null +++ b/pretor/core/global_state_machine/model_provider/deepseek_provider.py @@ -0,0 +1,59 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from pretor.utils.retry import retry_on_retryable_error +from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs +import httpx +from typing import List + +class OpenAIProvider(BaseProvider): + @staticmethod + async def create_provider(provider_args: ProviderArgs) -> Provider: + provider_models: List[str] = await OpenAIProvider._load_models(provider_args) + provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models) + return provider + + @staticmethod + @retry_on_retryable_error() + async def _load_models(provider_args: ProviderArgs) -> List[str]: + headers = { + "Authorization": f"Bearer {provider_args.provider_apikey}", + "Content-Type": "application/json" + } + url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models" + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(url, headers=headers) + if response.status_code != 200: + print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}") + return [] + data = response.json() + raw_models = data.get("data", []) + model_ids = [m["id"] for m in raw_models] + return sorted(model_ids) + except httpx.RequestError as e: + from pretor.utils.error import RetryableError + print(f"[{provider_args.provider_title}] 网络请求异常: {e}") + raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e + except Exception as e: + print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") + return [] + + @staticmethod + def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider: + return Provider(provider_title=provider_args.provider_title, + provider_apikey=provider_args.provider_apikey, + provider_url=provider_args.provider_url, + provider_models=provider_models, + provider_type="deepseek") \ No newline at end of file diff --git a/pretor/core/global_state_machine/model_provider/openai_provider.py b/pretor/core/global_state_machine/model_provider/openai_provider.py index 2e545a4..26ac3c4 100644 --- a/pretor/core/global_state_machine/model_provider/openai_provider.py +++ b/pretor/core/global_state_machine/model_provider/openai_provider.py @@ -1,4 +1,3 @@ -from pretor.utils.retry import retry_on_retryable_error # Copyright 2026 zhaoxi826 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -13,6 +12,7 @@ from pretor.utils.retry import retry_on_retryable_error # See the License for the specific language governing permissions and # limitations under the License. +from pretor.utils.retry import retry_on_retryable_error from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs import httpx from typing import List