chore: initial commit for Pretor v0.1.0-alpha
正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。 核心功能实现: 1. 建立基于 BaseIndividual 的动态插件加载机制。 2. 实现三类核心 worker_individual 子个体。 3. 集成 Ray 框架支持分布式集群调度。 4. 基于 PostgreSQL 的全量持久化存储方案。 5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
@@ -0,0 +1,170 @@
|
||||
import pytest
|
||||
import asyncio
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import sys
|
||||
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
|
||||
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == 'ray':
|
||||
mock_ray = MagicMock()
|
||||
|
||||
def mock_remote(*args, **kwargs):
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return args[0]
|
||||
def decorator(cls):
|
||||
return cls
|
||||
return decorator
|
||||
|
||||
mock_ray.remote = mock_remote
|
||||
return mock_ray
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
|
||||
builtins.__import__ = mock_import
|
||||
|
||||
for mod in list(sys.modules.keys()):
|
||||
if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod:
|
||||
del sys.modules[mod]
|
||||
|
||||
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||
from pretor.api.platform.event import PretorEvent
|
||||
from pretor.core.workflow.workflow import PretorWorkflow
|
||||
|
||||
builtins.__import__ = real_import
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_postgres():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gsm(mock_postgres):
|
||||
manager = GlobalStateMachine(mock_postgres)
|
||||
return manager
|
||||
|
||||
|
||||
def test_add_delete_get_event(gsm):
|
||||
event = MagicMock(spec=PretorEvent)
|
||||
event.trace_id = "123"
|
||||
|
||||
gsm.add_event(event)
|
||||
|
||||
assert getattr(event, 'pending_queue', None) is not None
|
||||
assert getattr(event, 'receive_queue', None) is not None
|
||||
|
||||
retrieved = gsm.get_event("123")
|
||||
assert retrieved == event
|
||||
|
||||
gsm.delete_event("123")
|
||||
assert gsm.get_event("123") is None
|
||||
|
||||
|
||||
def test_update_attachment_and_workflow(gsm):
|
||||
event = MagicMock(spec=PretorEvent)
|
||||
event.trace_id = "abc"
|
||||
gsm.add_event(event)
|
||||
|
||||
gsm.update_attachment("abc", {"k": "v"})
|
||||
assert event.attachment == {"k": "v"}
|
||||
|
||||
wf = MagicMock(spec=PretorWorkflow)
|
||||
gsm.update_workflow("abc", wf)
|
||||
assert event.workflow == wf
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_queues(gsm):
|
||||
event = MagicMock(spec=PretorEvent)
|
||||
event.trace_id = "q_event"
|
||||
# To use await put/get, we must actually use real asyncio queues for the mock event
|
||||
event.pending_queue = asyncio.Queue()
|
||||
event.receive_queue = asyncio.Queue()
|
||||
gsm.event_dict["q_event"] = event
|
||||
|
||||
await gsm.put_pending("q_event", "item1")
|
||||
res1 = await gsm.get_pending("q_event")
|
||||
assert res1 == "item1"
|
||||
|
||||
await gsm.put_received("q_event", "item2")
|
||||
res2 = await gsm.get_received("q_event")
|
||||
assert res2 == "item2"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_success(gsm, mock_postgres):
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_title = "MyProvider"
|
||||
mock_provider.provider_url = "url"
|
||||
mock_provider.provider_apikey = "key"
|
||||
mock_provider.provider_models = ["model"]
|
||||
mock_provider.provider_type = "openai"
|
||||
mock_provider_class.create_provider.return_value = mock_provider
|
||||
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm._global_provider_manager.provider_register = {}
|
||||
|
||||
mock_add_provider = AsyncMock()
|
||||
mock_postgres.add_provider_db = MagicMock()
|
||||
mock_postgres.add_provider_db.remote = mock_add_provider
|
||||
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
|
||||
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
||||
mock_add_provider.assert_called_once()
|
||||
assert mock_provider.provider_owner == "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_unsupported(gsm):
|
||||
gsm._global_provider_manager.provider_mapper = {}
|
||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
await gsm.add_provider_wrap("magic", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_request_error(gsm):
|
||||
from httpx import RequestError
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock())
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||
from pretor.utils.error import RetryableError
|
||||
import pytest
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
with pytest.raises(RetryableError):
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_generic_error(gsm):
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_provider.side_effect = ValueError("Some Error")
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
def test_get_provider_list_and_get_provider(gsm):
|
||||
mock_provider = MagicMock()
|
||||
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
|
||||
|
||||
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
|
||||
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
|
||||
assert gsm._global_provider_manager.get_provider("missing") is None
|
||||
@@ -0,0 +1,25 @@
|
||||
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus
|
||||
|
||||
def test_provider_status():
|
||||
assert ProviderStatus.UP == "up"
|
||||
assert ProviderStatus.DOWN == "down"
|
||||
|
||||
def test_provider_args():
|
||||
args = ProviderArgs(
|
||||
provider_title="title",
|
||||
provider_url="url",
|
||||
provider_apikey="key",
|
||||
provider_owner="1"
|
||||
)
|
||||
assert args.provider_title == "title"
|
||||
|
||||
def test_provider_model():
|
||||
p = Provider(
|
||||
provider_title="title",
|
||||
provider_url="url",
|
||||
provider_apikey="key",
|
||||
provider_models=["model"],
|
||||
provider_type="openai"
|
||||
)
|
||||
assert p.provider_status == ProviderStatus.UP
|
||||
assert p.provider_owner is None
|
||||
@@ -0,0 +1,51 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args():
|
||||
return ProviderArgs(
|
||||
provider_title="TestClaude",
|
||||
provider_url="https://api.anthropic.com",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
|
||||
async def test_load_models_success(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "claude-3-opus-20240229"}, {"id": "claude-3-sonnet-20240229"}]
|
||||
}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await ClaudeProvider._load_models(provider_args)
|
||||
assert models == ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
|
||||
async def test_load_models_error(mock_client, provider_args):
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = Exception("network error")
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await ClaudeProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
|
||||
return_value=["claude-3"])
|
||||
async def test_create_provider(mock_load, provider_args):
|
||||
provider = await ClaudeProvider.create_provider(provider_args)
|
||||
assert provider.provider_title == "TestClaude"
|
||||
assert provider.provider_models == ["claude-3"]
|
||||
assert provider.provider_type == "claude"
|
||||
@@ -0,0 +1,112 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args():
|
||||
return ProviderArgs(
|
||||
provider_title="TestOpenAI",
|
||||
provider_url="https://api.openai.com/v1",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args_no_v1():
|
||||
return ProviderArgs(
|
||||
provider_title="TestOpenAI",
|
||||
provider_url="https://api.openai.com",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1"
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
|
||||
async def test_load_models_success(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]
|
||||
}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
|
||||
async def test_load_models_no_v1(mock_client, provider_args_no_v1):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"data": []}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args_no_v1)
|
||||
assert models == []
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
|
||||
async def test_load_models_status_error(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
|
||||
async def test_load_models_request_error(mock_client, provider_args):
|
||||
import httpx
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
import pytest
|
||||
from pretor.utils.error import RetryableError
|
||||
with pytest.raises(RetryableError):
|
||||
await OpenAIProvider._load_models(provider_args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
|
||||
async def test_load_models_generic_error(mock_client, provider_args):
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = Exception("generic error")
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
|
||||
return_value=["gpt-4"])
|
||||
async def test_create_provider(mock_load, provider_args):
|
||||
provider = await OpenAIProvider.create_provider(provider_args)
|
||||
assert provider.provider_title == "TestOpenAI"
|
||||
assert provider.provider_models == ["gpt-4"]
|
||||
assert provider.provider_type == "openai"
|
||||
@@ -0,0 +1,26 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from pretor.core.global_state_machine.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_manager_init():
|
||||
mock_postgres = MagicMock()
|
||||
mock_provider1 = MagicMock()
|
||||
mock_provider1.provider_title = "title1"
|
||||
|
||||
mock_provider2 = MagicMock()
|
||||
mock_provider2.provider_title = "title2"
|
||||
|
||||
mock_postgres.get_provider = MagicMock()
|
||||
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||
|
||||
manager = ProviderManager(mock_postgres)
|
||||
mock_postgres.provider_database = MagicMock()
|
||||
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
|
||||
await manager.init_provider_register(mock_postgres)
|
||||
|
||||
assert "openai" in manager.provider_mapper
|
||||
assert "claude" in manager.provider_mapper
|
||||
|
||||
assert manager.provider_register["title1"] == mock_provider1
|
||||
@@ -0,0 +1,5 @@
|
||||
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
|
||||
def test_global_tool_manager_init():
|
||||
manager = GlobalToolManager()
|
||||
assert isinstance(manager, GlobalToolManager)
|
||||
Reference in New Issue
Block a user