Compare commits
6 Commits
e706c3352e
...
65652399e0
| Author | SHA1 | Date |
|---|---|---|
|
|
65652399e0 | |
|
|
fa22999d8c | |
|
|
9d7b980769 | |
|
|
600f7c42ab | |
|
|
b8f0372a7f | |
|
|
4a0679fe2c |
|
|
@ -181,7 +181,6 @@ export function ProvidersSettings() {
|
||||||
className="w-full bg-slate-50 border border-slate-200 text-sm rounded-lg px-3 py-2.5 focus:outline-none focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-all cursor-pointer"
|
className="w-full bg-slate-50 border border-slate-200 text-sm rounded-lg px-3 py-2.5 focus:outline-none focus:ring-2 focus:ring-blue-500/20 focus:border-blue-500 transition-all cursor-pointer"
|
||||||
>
|
>
|
||||||
<option value="openai">OpenAI</option>
|
<option value="openai">OpenAI</option>
|
||||||
<option value="gemini">Gemini</option>
|
|
||||||
<option value="deepseek">DeepSeek</option>
|
<option value="deepseek">DeepSeek</option>
|
||||||
<option value="claude">Claude</option>
|
<option value="claude">Claude</option>
|
||||||
<option value="local">Local</option>
|
<option value="local">Local</option>
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ export interface User {
|
||||||
|
|
||||||
// Provider types
|
// Provider types
|
||||||
export interface Provider {
|
export interface Provider {
|
||||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local' | 'deepseek';
|
provider_type: 'openai' | 'claude' | 'local' | 'deepseek';
|
||||||
provider_title: string;
|
provider_title: string;
|
||||||
provider_url?: string;
|
provider_url?: string;
|
||||||
provider_owner?: string;
|
provider_owner?: string;
|
||||||
|
|
@ -25,7 +25,7 @@ export interface Provider {
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface ProviderRegisterRequest {
|
export interface ProviderRegisterRequest {
|
||||||
provider_type: 'openai' | 'gemini' | 'claude' | 'local' | 'deepseek';
|
provider_type: 'openai' | 'claude' | 'local' | 'deepseek';
|
||||||
provider_title: string;
|
provider_title: string;
|
||||||
provider_url: string;
|
provider_url: string;
|
||||||
provider_apikey: string;
|
provider_apikey: string;
|
||||||
|
|
|
||||||
|
|
@ -14,10 +14,8 @@
|
||||||
|
|
||||||
from pydantic_ai import Agent
|
from pydantic_ai import Agent
|
||||||
from pydantic_ai.models.openai import OpenAIChatModel
|
from pydantic_ai.models.openai import OpenAIChatModel
|
||||||
from pydantic_ai.models.google import GoogleModel
|
|
||||||
from pydantic_ai.models.anthropic import AnthropicModel
|
from pydantic_ai.models.anthropic import AnthropicModel
|
||||||
from pydantic_ai.providers.openai import OpenAIProvider
|
from pydantic_ai.providers.openai import OpenAIProvider
|
||||||
from pydantic_ai.providers.google import GoogleProvider
|
|
||||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||||
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
from pretor.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||||
from pretor.core.global_state_machine.model_provider import Provider
|
from pretor.core.global_state_machine.model_provider import Provider
|
||||||
|
|
@ -27,7 +25,6 @@ from pretor.utils.error import ModelNotExistError
|
||||||
class AgentFactory:
|
class AgentFactory:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
|
self._models_mapping = {"openai": (OpenAIChatModel, OpenAIProvider),
|
||||||
"gemini": (GoogleModel, GoogleProvider),
|
|
||||||
"claude": (AnthropicModel, AnthropicProvider),
|
"claude": (AnthropicModel, AnthropicProvider),
|
||||||
"deepseek": (OpenAIChatModel, OpenAIProvider),}
|
"deepseek": (OpenAIChatModel, OpenAIProvider),}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,17 @@ from pydantic_ai.run import AgentRunResult
|
||||||
|
|
||||||
T = TypeVar('T', bound=BaseModel)
|
T = TypeVar('T', bound=BaseModel)
|
||||||
|
|
||||||
|
class AgentRunResultProxy:
|
||||||
|
def __init__(self, original, parsed):
|
||||||
|
self._original = original
|
||||||
|
self._parsed = parsed
|
||||||
|
def __getattr__(self, name):
|
||||||
|
if name == 'data':
|
||||||
|
return self._parsed
|
||||||
|
if name == 'output':
|
||||||
|
return self._parsed
|
||||||
|
return getattr(self._original, name)
|
||||||
|
|
||||||
class DeepSeekReasonerAgent(Generic[T]):
|
class DeepSeekReasonerAgent(Generic[T]):
|
||||||
"""
|
"""
|
||||||
专为 DeepSeek-V4/R1 设计的适配器。
|
专为 DeepSeek-V4/R1 设计的适配器。
|
||||||
|
|
@ -32,12 +43,9 @@ class DeepSeekReasonerAgent(Generic[T]):
|
||||||
format_instruction = ""
|
format_instruction = ""
|
||||||
if self.has_custom_output:
|
if self.has_custom_output:
|
||||||
try:
|
try:
|
||||||
if hasattr(self.output_schema, 'model_json_schema'):
|
from pydantic import TypeAdapter
|
||||||
schema_str = json.dumps(self.output_schema.model_json_schema(), ensure_ascii=False)
|
schema_dict = TypeAdapter(self.output_schema).json_schema()
|
||||||
else:
|
schema_str = json.dumps(schema_dict, ensure_ascii=False)
|
||||||
# Don't inject <class 'dict'> into prompt
|
|
||||||
schema_str = str(getattr(self.output_schema, '__name__', str(self.output_schema)))
|
|
||||||
|
|
||||||
format_instruction = (
|
format_instruction = (
|
||||||
f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
|
f"\n\nCRITICAL: 你必须输出且只能输出一段纯 JSON 格式的数据,"
|
||||||
f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
|
f"并包裹在 ```json 和 ``` 之间。格式必须符合以下 JSON Schema 结构(或对应数据类型):\n"
|
||||||
|
|
@ -95,15 +103,19 @@ class DeepSeekReasonerAgent(Generic[T]):
|
||||||
raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
|
raise ValueError("未找到有效的 JSON 块。请将结果包装在 ```json 中。")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if hasattr(self.output_schema, 'model_validate_json'):
|
from pydantic import TypeAdapter
|
||||||
return self.output_schema.model_validate_json(json_str)
|
adapter = TypeAdapter(self.output_schema)
|
||||||
else:
|
return adapter.validate_json(json_str)
|
||||||
return json.loads(json_str)
|
|
||||||
except ValidationError as e:
|
except ValidationError as e:
|
||||||
raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}")
|
raise ValueError(f"返回的 JSON 无法匹配所需结构:{e}")
|
||||||
except json.JSONDecodeError as e:
|
except json.JSONDecodeError as e:
|
||||||
raise ValueError(f"返回的不是合法的 JSON:{e}")
|
raise ValueError(f"返回的不是合法的 JSON:{e}")
|
||||||
|
|
||||||
|
|
||||||
|
def __getattr__(self, item):
|
||||||
|
# Delegate any unknown attributes (like .system_prompt, .tool) to the underlying pydantic_ai Agent
|
||||||
|
return getattr(self.agent, item)
|
||||||
|
|
||||||
async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
|
async def run(self, user_prompt: str, deps: Any = None, message_history: list = None, **kwargs) -> Any:
|
||||||
# Custom retry loop
|
# Custom retry loop
|
||||||
current_history = message_history or []
|
current_history = message_history or []
|
||||||
|
|
@ -123,17 +135,6 @@ class DeepSeekReasonerAgent(Generic[T]):
|
||||||
parsed_data = self._parse_output(raw_text)
|
parsed_data = self._parse_output(raw_text)
|
||||||
|
|
||||||
# Proxy the result to inject the parsed data seamlessly
|
# Proxy the result to inject the parsed data seamlessly
|
||||||
class AgentRunResultProxy:
|
|
||||||
def __init__(self, original, parsed):
|
|
||||||
self._original = original
|
|
||||||
self._parsed = parsed
|
|
||||||
def __getattr__(self, name):
|
|
||||||
if name == 'data':
|
|
||||||
return self._parsed
|
|
||||||
if name == 'output':
|
|
||||||
return self._parsed
|
|
||||||
return getattr(self._original, name)
|
|
||||||
|
|
||||||
return AgentRunResultProxy(result, parsed_data)
|
return AgentRunResultProxy(result, parsed_data)
|
||||||
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
|
|
|
||||||
|
|
@ -25,7 +25,7 @@ from pretor.utils.ray_hook import ray_actor_hook
|
||||||
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
|
provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])
|
||||||
|
|
||||||
class ProviderRegister(BaseModel):
|
class ProviderRegister(BaseModel):
|
||||||
provider_type: Literal["openai", "gemini", "claude", "deepseek"]
|
provider_type: Literal["openai", "claude", "deepseek"]
|
||||||
provider_title: str
|
provider_title: str
|
||||||
provider_url: str
|
provider_url: str
|
||||||
provider_apikey: str
|
provider_apikey: str
|
||||||
|
|
|
||||||
|
|
@ -14,6 +14,6 @@
|
||||||
|
|
||||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
|
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs
|
||||||
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
|
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider
|
||||||
from pretor.core.global_state_machine.model_provider.gemini_provider import GeminiProvider
|
|
||||||
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
|
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider
|
||||||
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "GeminiProvider", "ClaudeProvider"]
|
from pretor.core.global_state_machine.model_provider.deepseek_provider import DeepseekProvider
|
||||||
|
__all__ = ["Provider", "ProviderArgs", "OpenAIProvider", "ClaudeProvider", "DeepseekProvider"]
|
||||||
|
|
|
||||||
|
|
@ -1,65 +0,0 @@
|
||||||
from pretor.utils.retry import retry_on_retryable_error
|
|
||||||
# Copyright 2026 zhaoxi826
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
from pretor.core.global_state_machine.model_provider.base_provider import BaseProvider, Provider, ProviderArgs
|
|
||||||
import httpx
|
|
||||||
from typing import List
|
|
||||||
|
|
||||||
class GeminiProvider(BaseProvider):
|
|
||||||
@staticmethod
|
|
||||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
|
||||||
provider_models: List[str] = await GeminiProvider._load_models(provider_args)
|
|
||||||
provider: Provider = GeminiProvider._return_provider(provider_args, provider_models)
|
|
||||||
return provider
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
@retry_on_retryable_error()
|
|
||||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
|
||||||
# Google Gemini 原生鉴权通常使用 x-goog-api-key 或 query parameter
|
|
||||||
headers = {
|
|
||||||
"x-goog-api-key": provider_args.provider_apikey,
|
|
||||||
"Content-Type": "application/json"
|
|
||||||
}
|
|
||||||
# 官方路径通常是 v1beta/models
|
|
||||||
url = f"{provider_args.provider_url.rstrip('/')}/v1beta/models"
|
|
||||||
|
|
||||||
try:
|
|
||||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
|
||||||
response = await client.get(url, headers=headers)
|
|
||||||
if response.status_code != 200:
|
|
||||||
print(f"[{provider_args.provider_title}] 获取 Gemini 模型失败: {response.status_code}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
data = response.json()
|
|
||||||
# Gemini 返回的结构中模型 ID 通常带 "models/" 前缀
|
|
||||||
raw_models = data.get("models", [])
|
|
||||||
model_ids = [m["name"].split("/")[-1] for m in raw_models if
|
|
||||||
"generateContent" in m.get("supportedGenerationMethods", [])]
|
|
||||||
return sorted(list(set(model_ids)))
|
|
||||||
except httpx.RequestError as e:
|
|
||||||
from pretor.utils.error import RetryableError
|
|
||||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
|
||||||
raise RetryableError(f"[{provider_args.provider_title}] 网络请求异常: {e}") from e
|
|
||||||
except Exception as e:
|
|
||||||
print(f"[{provider_args.provider_title}] 获取 Gemini 模型列表错误: {e}")
|
|
||||||
return []
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _return_provider(provider_args: ProviderArgs, provider_models: List[str]) -> Provider:
|
|
||||||
return Provider(provider_title=provider_args.provider_title,
|
|
||||||
provider_apikey=provider_args.provider_apikey,
|
|
||||||
provider_url=provider_args.provider_url,
|
|
||||||
provider_models=provider_models,
|
|
||||||
provider_type="gemini")
|
|
||||||
|
|
@ -12,7 +12,7 @@
|
||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider,GeminiProvider, ClaudeProvider
|
from pretor.core.global_state_machine.model_provider import Provider, OpenAIProvider, ClaudeProvider, DeepseekProvider
|
||||||
from typing import Dict, Type
|
from typing import Dict, Type
|
||||||
|
|
||||||
class ProviderManager:
|
class ProviderManager:
|
||||||
|
|
@ -28,8 +28,8 @@ class ProviderManager:
|
||||||
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
|
"""供应商注册表:键为用户自定义别名,值为已实例化的 Provider 对象。"""
|
||||||
def __init__(self, postgres):
|
def __init__(self, postgres):
|
||||||
self.provider_mapper = {"openai": OpenAIProvider,
|
self.provider_mapper = {"openai": OpenAIProvider,
|
||||||
"gemini": GeminiProvider,
|
"claude": ClaudeProvider,
|
||||||
"claude": ClaudeProvider}
|
"deepseek": DeepseekProvider}
|
||||||
self.provider_register = {}
|
self.provider_register = {}
|
||||||
|
|
||||||
async def init_provider_register(self, postgres) -> None:
|
async def init_provider_register(self, postgres) -> None:
|
||||||
|
|
|
||||||
|
|
@ -111,7 +111,7 @@ class SupervisoryNode:
|
||||||
if isinstance(payload, PretorEvent):
|
if isinstance(payload, PretorEvent):
|
||||||
payload.context["workflow_template"] = result.workflow_template
|
payload.context["workflow_template"] = result.workflow_template
|
||||||
try:
|
try:
|
||||||
workflow_running_engine = ray_actor_hook("workflow_running_engine")
|
workflow_running_engine = ray_actor_hook("workflow_running_engine").workflow_running_engine
|
||||||
await workflow_running_engine.put_event.remote(payload)
|
await workflow_running_engine.put_event.remote(payload)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}")
|
self.logger.error(f"SupervisoryNode: 无法将事件放入 WorkflowRunningEngine: {e}")
|
||||||
|
|
|
||||||
|
|
@ -1,69 +0,0 @@
|
||||||
import pytest
|
|
||||||
from unittest.mock import patch, MagicMock, AsyncMock
|
|
||||||
from pretor.core.global_state_machine.model_provider.gemini_provider import GeminiProvider, ProviderArgs
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def provider_args():
|
|
||||||
return ProviderArgs(
|
|
||||||
provider_title="TestGemini",
|
|
||||||
provider_url="https://generativelanguage.googleapis.com",
|
|
||||||
provider_apikey="testkey",
|
|
||||||
provider_owner="1"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
|
|
||||||
async def test_load_models_success(mock_client, provider_args):
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 200
|
|
||||||
mock_response.json.return_value = {
|
|
||||||
"models": [
|
|
||||||
{"name": "models/gemini-1.5-pro", "supportedGenerationMethods": ["generateContent"]},
|
|
||||||
{"name": "models/gemini-1.5-flash", "supportedGenerationMethods": ["generateContent"]},
|
|
||||||
{"name": "models/other", "supportedGenerationMethods": []}
|
|
||||||
]
|
|
||||||
}
|
|
||||||
|
|
||||||
mock_client_instance = AsyncMock()
|
|
||||||
mock_client_instance.get.return_value = mock_response
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
|
||||||
|
|
||||||
models = await GeminiProvider._load_models(provider_args)
|
|
||||||
assert models == ["gemini-1.5-flash", "gemini-1.5-pro"]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
|
|
||||||
async def test_load_models_status_error(mock_client, provider_args):
|
|
||||||
mock_response = MagicMock()
|
|
||||||
mock_response.status_code = 401
|
|
||||||
|
|
||||||
mock_client_instance = AsyncMock()
|
|
||||||
mock_client_instance.get.return_value = mock_response
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
|
||||||
|
|
||||||
models = await GeminiProvider._load_models(provider_args)
|
|
||||||
assert models == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
|
|
||||||
async def test_load_models_error(mock_client, provider_args):
|
|
||||||
mock_client_instance = AsyncMock()
|
|
||||||
mock_client_instance.get.side_effect = Exception("network error")
|
|
||||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
|
||||||
|
|
||||||
models = await GeminiProvider._load_models(provider_args)
|
|
||||||
assert models == []
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.GeminiProvider._load_models",
|
|
||||||
return_value=["gemini-1"])
|
|
||||||
async def test_create_provider(mock_load, provider_args):
|
|
||||||
provider = await GeminiProvider.create_provider(provider_args)
|
|
||||||
assert provider.provider_title == "TestGemini"
|
|
||||||
assert provider.provider_models == ["gemini-1"]
|
|
||||||
assert provider.provider_type == "gemini"
|
|
||||||
|
|
@ -21,7 +21,6 @@ async def test_provider_manager_init():
|
||||||
await manager.init_provider_register(mock_postgres)
|
await manager.init_provider_register(mock_postgres)
|
||||||
|
|
||||||
assert "openai" in manager.provider_mapper
|
assert "openai" in manager.provider_mapper
|
||||||
assert "gemini" in manager.provider_mapper
|
|
||||||
assert "claude" in manager.provider_mapper
|
assert "claude" in manager.provider_mapper
|
||||||
|
|
||||||
assert manager.provider_register["title1"] == mock_provider1
|
assert manager.provider_register["title1"] == mock_provider1
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue