Pretor/pretor/adapter/model_adapter/deepseek_reasoner.py

59 lines
2.1 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import os
import re
from typing import Type, TypeVar, Any, Generic

from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent, RunContext
from pydantic_ai.models.openai import OpenAIChatModel
# Type variable for the structured output model; bound to pydantic's BaseModel
# and used as the generic parameter of DeepSeekReasonerAgent.
T = TypeVar('T', bound=BaseModel)
class DeepSeekReasonerAgent(Generic[T]):
    """Adapter designed for DeepSeek-V4/R1 models.

    Structured output is downgraded to a text-parsing mode: the wrapped
    pydantic_ai ``Agent`` is always created with ``output_type=str``, the
    system prompt is extended with an instruction to answer with a fenced
    JSON block matching the target schema, and the reply is validated
    locally against ``output_type``. The intent is to avoid tool-calling
    compatibility problems.

    When ``output_type`` is ``str`` (the default) no format instruction is
    added and the model's text is returned unchanged.
    """

    def __init__(
        self,
        model_id: str = "deepseek-v4-pro",
        output_type: Type[T] = str,
        system_prompt: str = "",
        deps_type: Type[Any] = None,
    ):
        """Build the underlying text-only agent.

        Args:
            model_id: Model name handed to ``OpenAIChatModel``.
            output_type: Pydantic model class the reply is validated against,
                or ``str`` to receive the raw text.
            system_prompt: Caller-supplied system prompt; the JSON format
                instruction (if any) is appended to it.
            deps_type: Dependency type forwarded to ``Agent``. ``None`` is
                translated to ``NoneType``, the Agent's own default.
        """
        self.output_schema = output_type

        if output_type is str:
            # Plain-text mode: ``str`` has no ``model_json_schema``, and there
            # is no structure to enforce, so the prompt is left untouched.
            format_instruction = ""
        else:
            # Serialize the schema as real JSON. Interpolating the dict
            # directly would emit a Python repr (single quotes, True/None),
            # which contradicts the "pure JSON" instruction below.
            schema_json = json.dumps(
                self.output_schema.model_json_schema(),
                ensure_ascii=False,
                indent=2,
            )
            # Inject the mandatory format instruction into the system prompt.
            format_instruction = (
                "\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
                "并包裹在 ```json 和 ``` 之间。格式必须符合以下 Pydantic 模型结构:\n"
                f"{schema_json}"
            )

        self.agent = Agent(
            model=OpenAIChatModel(model_id),
            # The agent itself always talks in plain strings; this is meant
            # to keep the request free of a ``tools`` field.
            output_type=str,
            system_prompt=system_prompt + format_instruction,
            deps_type=deps_type if deps_type is not None else type(None),
            retries=0,
        )

    async def run(self, user_prompt: str, deps: Any = None) -> T:
        """Run the agent once and parse its reply.

        Args:
            user_prompt: The user message sent to the model.
            deps: Optional dependencies forwarded to ``Agent.run``.

        Returns:
            An instance of ``output_type`` (or the raw text when
            ``output_type`` is ``str``).

        Raises:
            ValueError: If the reply cannot be validated against the schema.
        """
        result = await self.agent.run(user_prompt, deps=deps)
        return self._parse_json(result.output)

    def _parse_json(self, text: str) -> T:
        """Extract the JSON payload from ``text`` and validate it.

        Extraction order: the contents of the first ```json fenced block,
        otherwise the whole text; if that candidate does not start with
        ``{``, the span from the first ``{`` to the last ``}`` in ``text`` is
        tried instead, provided such a span exists.

        Raises:
            ValueError: If validation against ``output_schema`` fails; the
                original ``ValidationError`` is attached as the cause.
        """
        if self.output_schema is str:
            # Plain-text mode: nothing to parse.
            return text

        # Pull the fenced JSON block out with a regex.
        fenced = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
        json_str = fenced.group(1).strip() if fenced else text.strip()

        # If the candidate is not an object literal, look for the outermost
        # braces. Keep the candidate when no brace pair exists, so that valid
        # non-object JSON (e.g. an array) is still handed to the validator
        # instead of being replaced by an empty string.
        if not json_str.startswith('{'):
            start = text.find('{')
            end = text.rfind('}')
            if start != -1 and end > start:
                json_str = text[start:end + 1]

        try:
            return self.output_schema.model_validate_json(json_str)
        except ValidationError as e:
            raise ValueError(f"DeepSeek 返回格式非法: {e}\n原文: {text}") from e