Pretor/pretor/adapter/model_adapter/deepseek_reasoner.py

150 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import json
import re
from typing import Type, TypeVar, Any, Generic

from pydantic import BaseModel, TypeAdapter, ValidationError
from pydantic_ai import Agent, RunContext
from pydantic_ai.run import AgentRunResult
# Type variable bound to pydantic BaseModel; used to parameterize DeepSeekReasonerAgent.
T = TypeVar('T', bound=BaseModel)
# A ```json fenced block; the language tag is matched case-insensitively.
_JSON_FENCE_RE = re.compile(r'```json\s*(.*?)\s*```', re.DOTALL | re.IGNORECASE)
# Inline chain-of-thought as emitted by some reasoning-model deployments.
_THINK_BLOCK_RE = re.compile(r'<think>.*?</think>', re.DOTALL)


class _ParsedRunResult:
    """Wrap a pydantic-ai run result so that ``data``/``output`` return the parsed value.

    Every other attribute (``all_messages``, ``usage``, ...) is forwarded to the
    wrapped result object.
    """

    def __init__(self, original: Any, parsed: Any):
        self._original = original
        self._parsed = parsed

    @property
    def data(self) -> Any:
        return self._parsed

    @property
    def output(self) -> Any:
        return self._parsed

    def __getattr__(self, name: str) -> Any:
        # Only reached for attributes not found normally. Guard our own fields so a
        # half-initialised instance (e.g. during copy/pickle) can't recurse forever.
        if name in ('_original', '_parsed'):
            raise AttributeError(name)
        return getattr(self._original, name)

    def __repr__(self) -> str:
        return f"{type(self).__name__}(output={self._parsed!r}, original={self._original!r})"


class DeepSeekReasonerAgent(Generic[T]):
    """Adapter for DeepSeek-V4/R1 style reasoning models.

    Structured output is downgraded to text parsing. The wrapped pydantic-ai
    ``Agent`` always runs with ``output_type=str``, so no native function calling
    is requested. Instead:

    - the required output schema and any tools are described in the system prompt;
    - the model is asked to reply with a fenced ```json block;
    - that text is parsed and validated locally.

    On a parse/validation failure the model is re-prompted with the error, up to
    ``retries`` extra times.
    """

    def __init__(
        self,
        model,
        name,
        output_type: Any = str,
        system_prompt: str = "",
        deps_type: Type[Any] = None,
        tools: list = None,
        retries: int = 3,
        **kwargs
    ):
        """Build the adapter and the underlying text-only pydantic-ai Agent.

        Args:
            model: Passed straight through to ``pydantic_ai.Agent``.
            name: Agent name, passed through to ``pydantic_ai.Agent``.
            output_type: Desired result type. ``str`` or ``None`` means free text.
                Anything else (a pydantic model, ``dict``, ``list[Model]``, ...) is
                requested as JSON and validated after each run.
            system_prompt: Base system prompt. The format and tool instructions are
                appended to it.
            deps_type: Dependency type forwarded to the Agent when given.
            tools: Callables that are only *described* to the model in the system
                prompt. They are not registered with the Agent.
            retries: Number of extra attempts after a parse/validation failure.
            **kwargs: Forwarded to ``pydantic_ai.Agent``.
        """
        self.output_schema = output_type
        self.has_custom_output = output_type is not str and output_type is not None
        self.tools = tools or []
        self.retries = retries
        # Validator for non-BaseModel output types (dict, list[Model], TypedDict, ...).
        self._type_adapter = (
            self._build_type_adapter(output_type) if self.has_custom_output else None
        )

        full_system_prompt = system_prompt + self._format_instruction() + self._tool_instruction()

        agent_kwargs = dict(kwargs)
        if deps_type is not None:
            # Only forward an explicit type; otherwise keep the Agent's own default.
            agent_kwargs['deps_type'] = deps_type
        self.agent = Agent(
            model=model,
            name=name,
            output_type=str,  # Force plain-text output so function calling is disabled.
            system_prompt=full_system_prompt,
            **agent_kwargs,
        )

    @staticmethod
    def _build_type_adapter(output_type: Any) -> "TypeAdapter | None":
        """Return a pydantic TypeAdapter for non-BaseModel output types, else None.

        BaseModel subclasses use their own ``model_validate_json`` instead. Types for
        which pydantic cannot build a schema fall back to plain ``json.loads``.
        """
        if hasattr(output_type, 'model_validate_json'):
            return None
        try:
            return TypeAdapter(output_type)
        except Exception:  # e.g. schema generation errors for arbitrary classes
            return None

    def _schema_text(self) -> str:
        """Render the output schema as JSON Schema text, falling back to the type name."""
        try:
            if hasattr(self.output_schema, 'model_json_schema'):
                schema = self.output_schema.model_json_schema()
            elif self._type_adapter is not None:
                schema = self._type_adapter.json_schema()
            else:
                schema = None
            if schema is not None:
                return json.dumps(schema, ensure_ascii=False)
        except Exception:
            # Best effort: still describe the type by name rather than dropping the
            # whole format instruction.
            pass
        # Use the bare type name so "<class 'dict'>" is never injected into the prompt.
        return str(getattr(self.output_schema, '__name__', str(self.output_schema)))

    def _format_instruction(self) -> str:
        """Return the JSON-format instruction for the system prompt ("" for free text)."""
        if not self.has_custom_output:
            return ""
        schema_str = self._schema_text()
        return (
            f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
            f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
            f"{schema_str}"
        )

    def _tool_instruction(self) -> str:
        """Return the tool-listing text for the system prompt ("" when there are no tools).

        Each tool is listed by ``__name__`` plus its stripped docstring. The prompt
        asks the model to state its intent to use a tool; dispatching is left to
        outer logic.
        """
        if not self.tools:
            return ""
        lines = []
        for tool in self.tools:
            desc = getattr(tool, '__name__', str(tool))
            doc = getattr(tool, '__doc__', None)
            if doc:
                desc += f": {doc.strip()}"
            lines.append(f"- {desc}")
        return (
            "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
            "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n"
            + "\n".join(lines)
        )

    @staticmethod
    def _extract_json(text: str) -> str:
        """Return the candidate JSON substring of ``text``.

        Preference order:

        1. the first ```json fenced block;
        2. otherwise, the span from the first ``{``/``[`` to the last matching
           closer;
        3. otherwise, the stripped text.

        May return "".
        """
        # Braces inside inline reasoning must not be mistaken for the answer.
        text = _THINK_BLOCK_RE.sub('', text)
        match = _JSON_FENCE_RE.search(text)
        json_str = (match.group(1) if match else text).strip()
        if json_str.startswith(('{', '[')):
            return json_str
        start_obj = json_str.find('{')
        start_arr = json_str.find('[')
        if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
            start, end = start_obj, json_str.rfind('}')
        elif start_arr != -1:
            start, end = start_arr, json_str.rfind(']')
        else:
            return json_str
        return json_str[start:end + 1] if end > start else json_str

    def _parse_output(self, text: str) -> Any:
        """Extract and validate the structured output from the model's raw text.

        Free-text agents get ``text`` back unchanged.

        Raises:
            ValueError: No JSON was found, the JSON is invalid, or it does not match
                the output schema. The message is fed back to the model on retry.
        """
        if not self.has_custom_output:
            return text
        json_str = self._extract_json(text)
        if not json_str:
            raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
        try:
            if hasattr(self.output_schema, 'model_validate_json'):
                return self.output_schema.model_validate_json(json_str)
            if self._type_adapter is not None:
                return self._type_adapter.validate_json(json_str)
            return json.loads(json_str)
        except ValidationError as e:
            raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"返回的不是合法的 JSON:{e}") from e

    @staticmethod
    def _raw_text(result: Any) -> str:
        """Return a run's text output, preferring ``output`` over the deprecated ``data``."""
        if hasattr(result, 'output'):
            return result.output
        if hasattr(result, 'data'):
            return result.data
        return str(result)

    async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
        """Run the agent, re-prompting the model until its output parses.

        Args:
            user_prompt: Initial user prompt.
            deps: Dependencies forwarded to ``Agent.run``.
            message_history: Optional prior messages. The list is not mutated.
            **kwargs: Forwarded to every ``Agent.run`` call.

        Returns:
            The successful run's result, wrapped so that ``.output`` and ``.data``
            return the parsed value. Other attributes are forwarded to that result.

        Raises:
            ValueError: All ``1 + retries`` attempts failed to parse or validate.
        """
        current_history = message_history or []
        last_exception = None
        # Always make at least one attempt, even if retries is negative.
        for _ in range(max(self.retries, 0) + 1):
            result = await self.agent.run(
                user_prompt,
                deps=deps,
                message_history=current_history,
                **kwargs,
            )
            try:
                parsed_data = self._parse_output(self._raw_text(result))
            except ValueError as e:
                last_exception = e
                # Feed the error back so the model can correct itself on the next turn.
                user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
                # Continue from the full conversation so the model sees its bad answer.
                if hasattr(result, 'all_messages'):
                    current_history = result.all_messages()
            else:
                return _ParsedRunResult(result, parsed_data)
        raise ValueError(
            f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}"
        ) from last_exception