feat: 增加deepseek
This commit is contained in:
parent
99bd6f65d7
commit
6ff59520e6
|
|
@ -19,13 +19,17 @@ from pydantic_ai.models.anthropic import AnthropicModel
|
|||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||
from pretor.core.global_state_machine.model_provider import Provider
|
||||
from pretor.utils.agent_model import ResponseModel, DepsModel
|
||||
from pretor.utils.error import ModelNotExistError
|
||||
|
||||
class AgentFactory:
|
||||
def __init__(self):
|
||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)}
|
||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
|
||||
"gemini": (GoogleModel, GoogleProvider),
|
||||
"claude": (AnthropicModel, AnthropicProvider),
|
||||
"deepseek": (OpenAIChatModel, OpenAIProvider),}
|
||||
|
||||
def create_agent(self,
|
||||
provider: Provider,
|
||||
|
|
@ -45,6 +49,7 @@ class AgentFactory:
|
|||
system_prompt: 系统提示词
|
||||
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
|
||||
agent_name: agent的名字
|
||||
tools: 工具列表
|
||||
|
||||
Returns:
|
||||
返回被实例化的pydantic-ai的Agent对象
|
||||
|
|
@ -55,10 +60,19 @@ class AgentFactory:
|
|||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
|
||||
agent = Agent(model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools)
|
||||
match provider.provider_type:
|
||||
case "deepseek":
|
||||
agent = DeepSeekReasonerAgent(model=model,
|
||||
name=agent_name,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
case _:
|
||||
agent = Agent(model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools)
|
||||
return agent
|
||||
|
|
@ -3,7 +3,6 @@ import os
|
|||
from typing import Type, TypeVar, Any, Generic
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from pydantic_ai import Agent, RunContext
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
|
||||
T = TypeVar('T', bound=BaseModel)
|
||||
|
||||
|
|
@ -16,7 +15,8 @@ class DeepSeekReasonerAgent(Generic[T]):
|
|||
|
||||
def __init__(
|
||||
self,
|
||||
model_id: str = "deepseek-v4-pro",
|
||||
model,
|
||||
name,
|
||||
output_type: Type[T] = str,
|
||||
system_prompt: str = "",
|
||||
deps_type: Type[Any] = None
|
||||
|
|
@ -32,7 +32,8 @@ class DeepSeekReasonerAgent(Generic[T]):
|
|||
)
|
||||
|
||||
self.agent = Agent(
|
||||
model=OpenAIChatModel(model_id),
|
||||
model=model,
|
||||
name=name,
|
||||
output_type=str, # 内部通信用字符串
|
||||
system_prompt=system_prompt + format_instruction,
|
||||
deps_type=deps_type,
|
||||
|
|
|
|||
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
class OpenAIProvider(BaseProvider):
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json"
|
||||
}
|
||||
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("data", [])
|
||||
model_ids = [m["id"] for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from pretor.utils.error import RetryableError
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||
return Provider(provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="deepseek")
|
||||
|
|
@ -1,4 +1,3 @@
|
|||
from pretor.utils.retry import retry_on_retryable_error
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
|
|
@ -13,6 +12,7 @@ from pretor.utils.retry import retry_on_retryable_error
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pretor.utils.retry import retry_on_retryable_error
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||
import httpx
|
||||
from typing import List
|
||||
|
|
|
|||
Loading…
Reference in New Issue