feat: 增加deepseek
This commit is contained in:
parent
99bd6f65d7
commit
6ff59520e6
|
|
@ -19,13 +19,17 @@ from pydantic_ai.models.anthropic import AnthropicModel
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
from pydantic_ai.providers.google import GoogleProvider
|
from pydantic_ai.providers.google import GoogleProvider
|
||||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||||
|
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||||
from pretor.core.global_state_machine.model_provider import Provider
|
from pretor.core.global_state_machine.model_provider import Provider
|
||||||
from pretor.utils.agent_model import ResponseModel, DepsModel
|
from pretor.utils.agent_model import ResponseModel, DepsModel
|
||||||
from pretor.utils.error import ModelNotExistError
|
from pretor.utils.error import ModelNotExistError
|
||||||
|
|
||||||
class AgentFactory:
|
class AgentFactory:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider), "gemini": (GoogleModel, GoogleProvider), "claude": (AnthropicModel, AnthropicProvider)}
|
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
|
||||||
|
"gemini": (GoogleModel, GoogleProvider),
|
||||||
|
"claude": (AnthropicModel, AnthropicProvider),
|
||||||
|
"deepseek": (OpenAIChatModel, OpenAIProvider),}
|
||||||
|
|
||||||
def create_agent(self,
|
def create_agent(self,
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
|
|
@ -45,6 +49,7 @@ class AgentFactory:
|
||||||
system_prompt: 系统提示词
|
system_prompt: 系统提示词
|
||||||
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
|
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
|
||||||
agent_name: agent的名字
|
agent_name: agent的名字
|
||||||
|
tools: 工具列表
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
返回被实例化的pydantic-ai的Agent对象
|
返回被实例化的pydantic-ai的Agent对象
|
||||||
|
|
@ -55,10 +60,19 @@ class AgentFactory:
|
||||||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||||
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
|
model = model_class(model_id, provider=provider_class(api_key=provider.provider_apikey, base_url=provider.provider_url))
|
||||||
agent = Agent(model=model,
|
match provider.provider_type:
|
||||||
name=agent_name,
|
case "deepseek":
|
||||||
system_prompt=system_prompt,
|
agent = DeepSeekReasonerAgent(model=model,
|
||||||
output_type=output_type,
|
name=agent_name,
|
||||||
deps_type=deps_type,
|
output_type=output_type,
|
||||||
tools=tools)
|
deps_type=deps_type,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
case _:
|
||||||
|
agent = Agent(model=model,
|
||||||
|
name=agent_name,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
output_type=output_type,
|
||||||
|
deps_type=deps_type,
|
||||||
|
tools=tools)
|
||||||
return agent
|
return agent
|
||||||
|
|
@ -3,7 +3,6 @@ import os
|
||||||
from typing import Type, TypeVar, Any, Generic
|
from typing import Type, TypeVar, Any, Generic
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from pydantic_ai import Agent, RunContext
|
from pydantic_ai import Agent, RunContext
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
|
||||||
|
|
||||||
T = TypeVar('T', bound=BaseModel)
|
T = TypeVar('T', bound=BaseModel)
|
||||||
|
|
||||||
|
|
@ -16,7 +15,8 @@ class DeepSeekReasonerAgent(Generic[T]):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model_id: str = "deepseek-v4-pro",
|
model,
|
||||||
|
name,
|
||||||
output_type: Type[T] = str,
|
output_type: Type[T] = str,
|
||||||
system_prompt: str = "",
|
system_prompt: str = "",
|
||||||
deps_type: Type[Any] = None
|
deps_type: Type[Any] = None
|
||||||
|
|
@ -32,7 +32,8 @@ class DeepSeekReasonerAgent(Generic[T]):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.agent = Agent(
|
self.agent = Agent(
|
||||||
model=OpenAIChatModel(model_id),
|
model=model,
|
||||||
|
name=name,
|
||||||
output_type=str, # 内部通信用字符串
|
output_type=str, # 内部通信用字符串
|
||||||
system_prompt=system_prompt + format_instruction,
|
system_prompt=system_prompt + format_instruction,
|
||||||
deps_type=deps_type,
|
deps_type=deps_type,
|
||||||
|
|
|
||||||
|
|
@ -0,0 +1,59 @@
|
||||||
|
# Copyright 2026 zhaoxi826
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from pretor.utils.retry import retry_on_retryable_error
|
||||||
|
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||||
|
import httpx
|
||||||
|
from typing import List
|
||||||
|
|
||||||
|
class OpenAIProvider(BaseProvider):
|
||||||
|
@staticmethod
|
||||||
|
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||||
|
provider_models: List[str] = await OpenAIProvider._load_models(provider_args)
|
||||||
|
provider: Provider = OpenAIProvider._return_provider(provider_args, provider_models)
|
||||||
|
return provider
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@retry_on_retryable_error()
|
||||||
|
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||||
|
"Content-Type": "application/json"
|
||||||
|
}
|
||||||
|
url = f"{provider_args.provider_url}/models" if "/v1" in provider_args.provider_url else f"{provider_args.provider_url}/v1/models"
|
||||||
|
try:
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.get(url, headers=headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
print(f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}")
|
||||||
|
return []
|
||||||
|
data = response.json()
|
||||||
|
raw_models = data.get("data", [])
|
||||||
|
model_ids = [m["id"] for m in raw_models]
|
||||||
|
return sorted(model_ids)
|
||||||
|
except httpx.RequestError as e:
|
||||||
|
from pretor.utils.error import RetryableError
|
||||||
|
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||||
|
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||||
|
return []
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
||||||
|
return Provider(provider_title=provider_args.provider_title,
|
||||||
|
provider_apikey=provider_args.provider_apikey,
|
||||||
|
provider_url=provider_args.provider_url,
|
||||||
|
provider_models=provider_models,
|
||||||
|
provider_type="deepseek")
|
||||||
|
|
@ -1,4 +1,3 @@
|
||||||
from pretor.utils.retry import retry_on_retryable_error
|
|
||||||
# Copyright 2026 zhaoxi826
|
# Copyright 2026 zhaoxi826
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
|
@ -13,6 +12,7 @@ from pretor.utils.retry import retry_on_retryable_error
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from pretor.utils.retry import retry_on_retryable_error
|
||||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
||||||
import httpx
|
import httpx
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue