179 lines
6.7 KiB
Python
179 lines
6.7 KiB
Python
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
|
||
|
||
import json
import re
import warnings
from typing import Type, TypeVar, Any, Generic

from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent
|
||
|
||
T = TypeVar("T", bound=BaseModel)
|
||
|
||
|
||
class AgentRunResultProxy:
    """Lightweight proxy around an ``Agent.run`` result.

    The already-parsed structured object is exposed as ``.data`` and
    ``.output``; every other attribute lookup is forwarded to the wrapped
    result object.
    """

    def __init__(self, original, parsed):
        """Store the wrapped result and the parsed value.

        Args:
            original: The result object to forward unknown attributes to.
            parsed: The value returned for both ``data`` and ``output``.
        """
        self._original = original
        self._parsed = parsed

    def __getattr__(self, name):
        """Resolve attributes that normal instance/class lookup did not find.

        Args:
            name: The attribute being looked up.

        Returns:
            The parsed value for ``data`` / ``output``; otherwise the
            attribute of the same name on the wrapped result.

        Raises:
            AttributeError: If the proxy's own fields are not set yet, or the
                wrapped result has no such attribute.
        """
        # ``__getattr__`` only runs after normal lookup failed, so being asked
        # for one of our own fields means ``__init__`` never ran (e.g. the
        # instance was created through ``__new__`` by copy/pickle). Reading
        # ``self._original`` below would then re-enter this method forever.
        if name in ("_original", "_parsed"):
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {name!r}"
            )
        if name in ("data", "output"):
            return self._parsed
        return getattr(self._original, name)
|
||
|
||
|
||
class DeepSeekReasonerAgent(Generic[T]):
    """Adapter designed for DeepSeek-V4/R1 models.

    Structured output is downgraded to a text-parsing mode: the wrapped
    ``pydantic_ai`` ``Agent`` is always created with ``output_type=str``, the
    JSON payload is extracted from the reply text and validated against
    ``output_type`` here, and failed validations are retried with the error
    message fed back to the model.
    """

    def __init__(
        self,
        model,
        name,
        output_type: Any = str,
        system_prompt: str = "",
        deps_type: Type[Any] = None,
        tools: list = None,
        retries: int = 3,
        **kwargs,
    ):
        """Build the wrapped agent and the augmented system prompt.

        Args:
            model: Passed through to ``Agent`` as ``model``.
            name: Passed through to ``Agent`` as ``name``.
            output_type: Target type for parsed output. ``str`` or ``None``
                means the raw reply text is returned unparsed.
            system_prompt: Base system prompt; format and tool instructions
                are appended to it.
            deps_type: Passed through to ``Agent`` as ``deps_type``.
            tools: Objects that are only *described* to the model in the
                system prompt; they are not registered on the wrapped agent.
            retries: Number of extra attempts ``run`` makes after a reply
                fails to parse.
            **kwargs: Forwarded unchanged to ``Agent``.
        """
        self.output_schema = output_type
        self.has_custom_output = output_type is not str and output_type is not None
        self.tools = tools or []
        self.retries = retries

        format_instruction = ""
        if self.has_custom_output:
            try:
                from pydantic import TypeAdapter

                schema_dict = TypeAdapter(self.output_schema).json_schema()
                schema_str = json.dumps(schema_dict, ensure_ascii=False)
            except Exception as exc:
                # Schema generation stays best-effort (construction must not
                # fail because of it), but say so: without the schema in the
                # prompt the model is never told what structure to produce.
                warnings.warn(
                    f"Could not build a JSON schema for output_type "
                    f"{self.output_schema!r}; no format instruction was added "
                    f"to the system prompt: {exc}",
                    RuntimeWarning,
                    stacklevel=2,
                )
            else:
                format_instruction = (
                    f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
                    f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
                    f"{schema_str}"
                )

        tool_instruction = ""
        if self.tools:
            tool_descs = []
            for t in self.tools:
                # Describe each tool as "<name>: <docstring>", falling back to
                # str(t) when the object has no __name__.
                desc = getattr(t, "__name__", str(t))
                doc = getattr(t, "__doc__", None)
                if doc:
                    desc += f": {doc.strip()}"
                tool_descs.append(f"- {desc}")
            tool_instruction = (
                "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
                "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n"
                + "\n".join(tool_descs)
            )

        self.agent = Agent(
            model=model,
            name=name,
            output_type=str,  # Force native agent to return str to disable function calling
            system_prompt=system_prompt + format_instruction + tool_instruction,
            deps_type=deps_type,
            **kwargs,
        )

    def _parse_output(self, text: str) -> Any:
        """Extract the JSON payload from free text and validate it.

        The first ```` ```json ```` fenced block is used when present;
        otherwise the whole text is used. If the candidate does not start
        with ``{`` or ``[``, it is narrowed to the span from the first opening
        brace/bracket to the last matching closing character.

        Args:
            text: Raw reply text from the model.

        Returns:
            ``text`` unchanged when no custom output type is configured,
            otherwise the value validated against ``self.output_schema``.

        Raises:
            ValueError: If no JSON candidate is found or it fails validation.
        """
        if not self.has_custom_output:
            return text

        match = re.search(r"```json\s*(.*?)\s*```", text, re.DOTALL)
        json_str = match.group(1).strip() if match else text

        if not json_str.startswith("{") and not json_str.startswith("["):
            start_obj = json_str.find("{")
            start_arr = json_str.find("[")
            start = -1
            end = -1
            # Whichever of "{" / "[" appears first decides whether the payload
            # is treated as an object or an array.
            if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
                start = start_obj
                end = json_str.rfind("}")
            elif start_arr != -1:
                start = start_arr
                end = json_str.rfind("]")

            if start != -1 and end != -1 and end > start:
                json_str = json_str[start : end + 1]

        if not json_str:
            raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")

        try:
            from pydantic import TypeAdapter

            adapter = TypeAdapter(self.output_schema)
            return adapter.validate_json(json_str)
        except ValidationError as e:
            raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}") from e
        except json.JSONDecodeError as e:
            raise ValueError(f"返回的不是合法的 JSON:{e}") from e

    def __getattr__(self, item):
        """Delegate unknown attributes to the wrapped ``pydantic_ai`` agent.

        Raises:
            AttributeError: If the wrapped agent is not set yet, or it has no
                such attribute.
        """
        # ``__getattr__`` only runs after normal lookup failed, so being asked
        # for "agent" means it was never assigned (``Agent(...)`` raised, or
        # the instance was created through ``__new__`` by copy/pickle).
        # Reading ``self.agent`` below would then recurse without bound.
        if item == "agent":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute {item!r}"
            )
        # Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent
        return getattr(self.agent, item)

    async def run(
        self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs
    ) -> Any:
        """Run the wrapped agent and parse its reply, retrying on parse errors.

        When a reply fails to parse, the error text is sent back as the next
        user prompt, together with the previous messages when the result
        exposes ``all_messages``.

        Args:
            user_prompt: Prompt for the first attempt.
            deps: Forwarded to the wrapped agent's ``run``.
            message_history: Initial message history; an empty list is used
                when falsy.
            **kwargs: Forwarded to the wrapped agent's ``run``.

        Returns:
            An ``AgentRunResultProxy`` exposing the parsed value as ``data``
            and ``output``.

        Raises:
            ValueError: If every attempt produced output that failed to parse.
        """
        current_history = message_history or []
        last_exception = None

        # Always make at least one attempt, even if ``retries`` is negative.
        attempts = max(self.retries, 0) + 1
        for _ in range(attempts):
            result = await self.agent.run(
                user_prompt, deps=deps, message_history=current_history, **kwargs
            )

            # Prefer ``output`` and fall back to ``data`` so a result object
            # exposing either name works.
            if hasattr(result, "output"):
                raw_text = result.output
            elif hasattr(result, "data"):
                raw_text = result.data
            else:
                raw_text = str(result)

            try:
                parsed_data = self._parse_output(raw_text)
            except ValueError as e:
                last_exception = e
                # Prepare retry prompt
                user_prompt = (
                    f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
                )
                # Carry the conversation forward so the model can see the
                # reply that failed to parse.
                if hasattr(result, "all_messages"):
                    current_history = result.all_messages()
            else:
                # Proxy the result to inject the parsed data seamlessly
                return AgentRunResultProxy(result, parsed_data)

        raise ValueError(
            f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}"
        ) from last_exception
|