diff --git a/pretor/adapter/model_adapter/deepseek_reasoner.py b/pretor/adapter/model_adapter/deepseek_reasoner.py index ee3193c..131b0f8 100644 --- a/pretor/adapter/model_adapter/deepseek_reasoner.py +++ b/pretor/adapter/model_adapter/deepseek_reasoner.py @@ -7,6 +7,17 @@ from pydantic_ai.run import AgentRunResult T = TypeVar('T', bound=BaseModel) +class AgentRunResultProxy: + def __init__(self, original, parsed): + self._original = original + self._parsed = parsed + def __getattr__(self, name): + if name == 'data': + return self._parsed + if name == 'output': + return self._parsed + return getattr(self._original, name) + class DeepSeekReasonerAgent(Generic[T]): """ 专为 DeepSeek-V4/R1 设计的适配器。 @@ -32,12 +43,9 @@ class DeepSeekReasonerAgent(Generic[T]): format_instruction = "" if self.has_custom_output: try: - if hasattr(self.output_schema, 'model_json_schema'): - schema_str = json.dumps(self.output_schema.model_json_schema(), ensure_ascii=False) - else: - # Don't inject into prompt - schema_str = str(getattr(self.output_schema, '__name__', str(self.output_schema))) - + from pydantic import TypeAdapter + schema_dict = TypeAdapter(self.output_schema).json_schema() + schema_str = json.dumps(schema_dict, ensure_ascii=False) format_instruction = ( f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据," f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n" @@ -95,10 +103,9 @@ class DeepSeekReasonerAgent(Generic[T]): raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。") try: - if hasattr(self.output_schema, 'model_validate_json'): - return self.output_schema.model_validate_json(json_str) - else: - return json.loads(json_str) + from pydantic import TypeAdapter + adapter = TypeAdapter(self.output_schema) + return adapter.validate_json(json_str) except ValidationError as e: raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}") except json.JSONDecodeError as e: @@ -128,17 +135,6 @@ class DeepSeekReasonerAgent(Generic[T]): parsed_data = self._parse_output(raw_text) # Proxy the result to inject the parsed data seamlessly - class AgentRunResultProxy: - def __init__(self, original, parsed): - self._original = original - self._parsed = parsed - def __getattr__(self, name): - if name == 'data': - return self._parsed - if name == 'output': - return self._parsed - return getattr(self._original, name) - return AgentRunResultProxy(result, parsed_data) except ValueError as e: