"""Text-mode adapter that lets DeepSeek reasoner models produce structured output.

This is the post-change content of
``pretor/adapter/model_adapter/deepseek_reasoner.py``.

NOTE(review): the same change set also adds ``tools=tools, retries=3`` to a
constructor call inside ``AgentFactory`` (``agent_factory.py``) -- presumably
the call that builds this adapter; confirm against that file.
"""
import json
import logging
import re
from typing import Any, Generic, Optional, Type, TypeVar

from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent, RunContext  # noqa: F401  (RunContext kept from the original import list)
from pydantic_ai.run import AgentRunResult  # noqa: F401  (kept from the original import list)

logger = logging.getLogger(__name__)

T = TypeVar('T', bound=BaseModel)

# Compiled once: the fenced ```json ... ``` block the model is asked to emit.
_JSON_FENCE_RE = re.compile(r'```json\s*(.*?)\s*```', re.DOTALL)
# First character that can open a JSON object or array.
_JSON_OPENER_RE = re.compile(r'[{\[]')
_JSON_DECODER = json.JSONDecoder()


def _extract_json_text(text: str) -> str:
    """Return the substring of *text* most likely to be the JSON payload.

    Resolution order:

    1. The content of the first ```` ```json ```` fence, or the whole text when
       there is no fence (both stripped).
    2. If that candidate is valid JSON as-is, it is returned unchanged.
    3. Otherwise the first ``{`` or ``[`` is located and the first complete JSON
       value starting there is cut out, so leading prose and trailing chatter
       after the value are dropped.
    4. If no complete value can be decoded, fall back to slicing from that
       opener to the last matching closer; if even that is impossible the
       candidate is returned untouched so the caller reports the real error.
    """
    fenced = _JSON_FENCE_RE.search(text)
    candidate = (fenced.group(1) if fenced else text).strip()
    if not candidate:
        return candidate

    try:
        _JSON_DECODER.decode(candidate)
    except json.JSONDecodeError:
        # Not valid as a whole; try to isolate the JSON value below.
        pass
    else:
        return candidate

    opener = _JSON_OPENER_RE.search(candidate)
    if opener is None:
        return candidate

    start = opener.start()
    try:
        _, end = _JSON_DECODER.raw_decode(candidate, start)
    except json.JSONDecodeError:
        closer = '}' if opener.group() == '{' else ']'
        end = candidate.rfind(closer) + 1
        if end <= start:
            return candidate
    return candidate[start:end]


class AgentRunResultProxy:
    """Wrap a run result so ``data`` and ``output`` return the parsed value.

    Every other attribute is forwarded to the wrapped result object.
    """

    def __init__(self, original: Any, parsed: Any) -> None:
        self._original = original
        self._parsed = parsed

    @property
    def data(self) -> Any:
        """The parsed output (legacy attribute name)."""
        return self._parsed

    @property
    def output(self) -> Any:
        """The parsed output."""
        return self._parsed

    def __getattr__(self, name: str) -> Any:
        # __getattr__ only runs when normal lookup fails. If our own slots are
        # missing (instance created without __init__, e.g. by copy/pickle),
        # delegating would recurse forever -- fail fast instead.
        if name in ('_original', '_parsed'):
            raise AttributeError(name)
        return getattr(self._original, name)


class DeepSeekReasonerAgent(Generic[T]):
    """Adapter designed for DeepSeek-V4/R1.

    Structured output is downgraded to a text-parsing mode: the wrapped
    ``Agent`` is always created with ``output_type=str``, the desired schema is
    described in the system prompt, and the textual reply is parsed here. When
    parsing fails the model is re-prompted with the error, up to ``retries``
    additional times.
    """

    def __init__(
        self,
        model,
        name,
        output_type: Any = str,
        system_prompt: str = "",
        deps_type: Optional[Type[Any]] = None,
        tools: Optional[list] = None,
        retries: int = 3,
        **kwargs
    ):
        """Build the wrapped string-output agent.

        Args:
            model: Model passed through to ``Agent``.
            name: Agent name passed through to ``Agent``.
            output_type: Desired result type. ``str`` or ``None`` means plain
                text; anything else switches on JSON parsing.
            system_prompt: Base system prompt; format and tool instructions are
                appended to it.
            deps_type: Dependency type passed through to ``Agent``.
            tools: Callables/objects that are only *described* in the prompt;
                they are not registered with the wrapped agent.
            retries: Number of extra attempts after a parse failure. Negative
                values are treated as 0 so that at least one attempt is made.
            **kwargs: Forwarded verbatim to ``Agent``.
        """
        self.output_schema = output_type
        self.has_custom_output = output_type is not str and output_type is not None
        self.tools = tools or []
        self.retries = max(0, retries)

        full_prompt = (
            system_prompt
            + self._build_format_instruction()
            + self._build_tool_instruction()
        )

        self.agent = Agent(
            model=model,
            name=name,
            output_type=str,  # Force native agent to return str to disable function calling
            system_prompt=full_prompt,
            deps_type=deps_type,
            **kwargs
        )

    def _describe_schema(self) -> str:
        """Return a JSON Schema string for the output type, or its name as fallback."""
        schema = self.output_schema
        if hasattr(schema, 'model_json_schema'):
            try:
                return json.dumps(schema.model_json_schema(), ensure_ascii=False)
            except Exception:
                # Best effort: schema generation is allowed to fail, but the
                # model must still be told to answer in JSON because
                # _parse_output will insist on it.
                logger.warning(
                    "Could not build JSON schema for %r; using its name instead",
                    schema,
                    exc_info=True,
                )
        # Non-pydantic types: only the type name is injected into the prompt.
        return str(getattr(schema, '__name__', str(schema)))

    def _build_format_instruction(self) -> str:
        """Return the prompt suffix demanding fenced JSON, or "" for plain text."""
        if not self.has_custom_output:
            return ""
        schema_str = self._describe_schema()
        return (
            f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
            f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
            f"{schema_str}"
        )

    def _build_tool_instruction(self) -> str:
        """Return the prompt suffix listing the available tools, or "" if none."""
        if not self.tools:
            return ""
        tool_descs = []
        for tool in self.tools:
            desc = getattr(tool, '__name__', str(tool))
            doc = getattr(tool, '__doc__', None)
            if doc:
                desc += f": {doc.strip()}"
            tool_descs.append(f"- {desc}")
        return (
            "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
            "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n"
            + "\n".join(tool_descs)
        )

    def _parse_output(self, text: str) -> Any:
        """Convert the model's raw text into the configured output type.

        Returns *text* unchanged in plain-text mode. Otherwise the JSON payload
        is extracted and validated with ``model_validate_json`` when the output
        type provides it, or decoded with ``json.loads`` when it does not.

        Raises:
            ValueError: No JSON payload was found, the payload is not valid
                JSON, or it does not match the expected structure.
        """
        if not self.has_custom_output:
            return text

        json_str = _extract_json_text(text)
        if not json_str:
            raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")

        try:
            if hasattr(self.output_schema, 'model_validate_json'):
                return self.output_schema.model_validate_json(json_str)
            return json.loads(json_str)
        except ValidationError as e:
            raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"返回的不是合法的 JSON:{e}") from e

    @staticmethod
    def _raw_text(result: Any) -> Any:
        """Pull the textual reply out of a run result.

        ``output`` is preferred; ``data`` is the older attribute name and is
        only consulted when ``output`` is absent.
        """
        if hasattr(result, 'output'):
            return result.output
        if hasattr(result, 'data'):
            return result.data
        return str(result)

    async def run(
        self,
        user_prompt: str,
        deps: Any = None,
        message_history: Optional[list] = None,
        **kwargs
    ) -> Any:
        """Run the wrapped agent and return a result whose output is parsed.

        On a parse failure the model is asked to correct its answer, with the
        previous exchange supplied as message history so it can see what it
        got wrong.

        Args:
            user_prompt: Prompt for the first attempt.
            deps: Dependencies forwarded to ``Agent.run``.
            message_history: Prior messages forwarded to ``Agent.run``.
            **kwargs: Forwarded verbatim to ``Agent.run``.

        Returns:
            An ``AgentRunResultProxy`` around the underlying run result.

        Raises:
            ValueError: Every attempt produced unparseable output.
        """
        current_history = message_history or []
        last_exception = None

        for _attempt in range(self.retries + 1):
            result = await self.agent.run(
                user_prompt,
                deps=deps,
                message_history=current_history,
                **kwargs
            )

            try:
                parsed_data = self._parse_output(self._raw_text(result))
            except ValueError as e:
                last_exception = e
                # Next attempt: tell the model what was wrong ...
                user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
                # ... and replay the conversation so far so it sees its own reply.
                if hasattr(result, 'all_messages'):
                    current_history = result.all_messages()
            else:
                return AgentRunResultProxy(result, parsed_data)

        raise ValueError(
            f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}"
        ) from last_exception