Compare commits
3 Commits
121d13d7f4
...
e706c3352e
| Author | SHA1 | Date |
|---|---|---|
|
|
e706c3352e | |
|
|
c1af32e604 | |
|
|
5694a30ca8 |
|
|
@ -182,6 +182,7 @@ export function ProvidersSettings() {
|
|||
>
|
||||
<option value="openai">OpenAI</option>
|
||||
<option value="gemini">Gemini</option>
|
||||
<option value="deepseek">DeepSeek</option>
|
||||
<option value="claude">Claude</option>
|
||||
<option value="local">Local</option>
|
||||
</select>
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ export interface User {
|
|||
|
||||
// Provider types
|
||||
export interface Provider {
|
||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local';
|
||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local' | 'deepseek';
|
||||
provider_title: string;
|
||||
provider_url?: string;
|
||||
provider_owner?: string;
|
||||
|
|
@ -25,7 +25,7 @@ export interface Provider {
|
|||
}
|
||||
|
||||
export interface ProviderRegisterRequest {
|
||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local';
|
||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local' | 'deepseek';
|
||||
provider_title: string;
|
||||
provider_url: string;
|
||||
provider_apikey: string;
|
||||
|
|
|
|||
|
|
@ -67,6 +67,8 @@ class AgentFactory:
|
|||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
retries=3,
|
||||
)
|
||||
case _:
|
||||
agent = Agent(model=model,
|
||||
|
|
|
|||
|
|
@ -1,60 +1,149 @@
|
|||
import re
import os
import json
from typing import Type, TypeVar, Any, Generic

from pydantic import BaseModel, ValidationError
from pydantic_ai import Agent, RunContext
from pydantic_ai.run import AgentRunResult

T = TypeVar('T', bound=BaseModel)


class DeepSeekReasonerAgent(Generic[T]):
    """Adapter designed for DeepSeek-V4/R1 models.

    Downgrades structured output to a text-parsing mode — the wrapped native
    agent is forced to return plain ``str`` so the backend never receives a
    tool/function-calling payload — and adds a custom retry loop so that
    malformed responses can be corrected by the model itself.
    """

    def __init__(
        self,
        model,
        name,
        output_type: Any = str,
        system_prompt: str = "",
        deps_type: Type[Any] = None,
        tools: list = None,
        retries: int = 3,
        **kwargs
    ):
        # Remember the requested schema; the inner agent is declared with
        # output_type=str so the provider never sees a `tools` field.
        self.output_schema = output_type
        self.has_custom_output = output_type is not str and output_type is not None
        self.tools = tools or []
        self.retries = retries

        format_instruction = ""
        if self.has_custom_output:
            try:
                if hasattr(self.output_schema, 'model_json_schema'):
                    # Pydantic model: embed its JSON Schema in the prompt.
                    schema_str = json.dumps(self.output_schema.model_json_schema(), ensure_ascii=False)
                else:
                    # Don't inject reprs like <class 'dict'> into the prompt.
                    schema_str = str(getattr(self.output_schema, '__name__', str(self.output_schema)))

                # Inject a mandatory output-format instruction into the system prompt.
                format_instruction = (
                    f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
                    f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
                    f"{schema_str}"
                )
            except Exception:
                # Best effort: a schema that cannot be serialized simply
                # produces no extra instruction.
                pass

        tool_instruction = ""
        if self.tools:
            # Tools cannot be natively invoked in this degraded mode; describe
            # them in the prompt so the model can state its intent instead.
            tool_descs = []
            for t in self.tools:
                desc = getattr(t, '__name__', str(t))
                if hasattr(t, '__doc__') and t.__doc__:
                    desc += f": {t.__doc__.strip()}"
                tool_descs.append(f"- {desc}")
            tool_instruction = (
                "\n\n系统为您提供了以下工具。由于当前处于结构化降级模式,无法原生调用。"
                "但如果您在思考过程中判断必须使用这些工具,请在返回的结构中(或如果是自由文本)注明意图,由外层逻辑进行调度:\n" +
                "\n".join(tool_descs)
            )

        self.agent = Agent(
            model=model,
            name=name,
            output_type=str,  # Force the native agent to return str to disable function calling.
            system_prompt=system_prompt + format_instruction + tool_instruction,
            deps_type=deps_type,
            retries=0,  # Native retries off; the loop in run() handles retries. (Comma was missing before **kwargs.)
            **kwargs
        )

    def _parse_output(self, text: str) -> Any:
        """Parse the raw model text into the requested output type.

        Returns the text unchanged when no custom schema was requested;
        otherwise extracts a JSON payload and validates it against
        ``self.output_schema``.

        Raises:
            ValueError: when no JSON block can be located, the JSON is
                malformed, or it fails schema validation.
        """
        if not self.has_custom_output:
            return text

        # Prefer an explicit ```json ... ``` fenced block.
        match = re.search(r'```json\s*(.*?)\s*```', text, re.DOTALL)
        json_str = match.group(1).strip() if match else text

        # Fallback: slice from the first '{' or '[' to the last matching
        # '}' or ']' so bare (unfenced) JSON is still accepted.
        if not json_str.startswith('{') and not json_str.startswith('['):
            start_obj = json_str.find('{')
            start_arr = json_str.find('[')
            start = -1
            end = -1
            if start_obj != -1 and (start_arr == -1 or start_obj < start_arr):
                start = start_obj
                end = json_str.rfind('}')
            elif start_arr != -1:
                start = start_arr
                end = json_str.rfind(']')

            if start != -1 and end != -1 and end > start:
                json_str = json_str[start:end + 1]

        if not json_str:
            raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")

        try:
            if hasattr(self.output_schema, 'model_validate_json'):
                # Pydantic model: validate straight from the JSON string.
                return self.output_schema.model_validate_json(json_str)
            else:
                # Plain types (dict/list/...): decode without validation.
                return json.loads(json_str)
        except ValidationError as e:
            raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}")
        except json.JSONDecodeError as e:
            raise ValueError(f"返回的不是合法的 JSON:{e}")

    async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
        """Run the wrapped agent with a parse-and-retry loop.

        On a parse/validation failure the error is fed back to the model as a
        new prompt (with the prior conversation history preserved) and the
        call is retried up to ``self.retries`` additional times.

        Returns:
            A proxy around the native run result whose ``data``/``output``
            attributes expose the parsed value.

        Raises:
            ValueError: when every attempt fails output validation.
        """
        current_history = message_history or []
        last_exception = None

        for attempt in range(self.retries + 1):
            result = await self.agent.run(
                user_prompt,
                deps=deps,
                message_history=current_history,
                **kwargs
            )

            # pydantic-ai exposes the text as .data or .output depending on version.
            raw_text = result.data if hasattr(result, 'data') else getattr(result, 'output', str(result))

            try:
                parsed_data = self._parse_output(raw_text)

                # Proxy the native result so .data/.output yield the parsed
                # value while every other attribute is forwarded unchanged.
                class AgentRunResultProxy:
                    def __init__(self, original, parsed):
                        self._original = original
                        self._parsed = parsed

                    def __getattr__(self, name):
                        if name in ('data', 'output'):
                            return self._parsed
                        return getattr(self._original, name)

                return AgentRunResultProxy(result, parsed_data)

            except ValueError as e:
                last_exception = e
                # Feed the failure back so the model can correct its format.
                user_prompt = f"你的上一次输出解析失败,错误原因是: {e}\n请修正格式后重新输出。"
                # Preserve the conversation so the model sees its failed attempt.
                if hasattr(result, 'all_messages'):
                    current_history = result.all_messages()

        raise ValueError(f"Exceeded maximum retries ({self.retries}) for output validation. Last error: {last_exception}")
|
||||
|
|
|
|||
|
|
@ -25,7 +25,7 @@ from pretor.utils.ray_hook import ray_actor_hook
|
|||
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
|
||||
|
||||
class ProviderRegister(BaseModel):
|
||||
provider_type: Literal["openai", "gemini", "claude"]
|
||||
provider_type: Literal["openai", "gemini", "claude", "deepseek"]
|
||||
provider_title: str
|
||||
provider_url: str
|
||||
provider_apikey: str
|
||||
|
|
|
|||
Loading…
Reference in New Issue