chore: initial commit for Pretor v0.1.0-alpha

正式发布 Pretor 平台的首个 alpha 版本。本项目旨在构建一个基于分布式架构的多智能体协同工作流水线。

核心功能实现:
1. 建立基于 BaseIndividual 的动态插件加载机制。
2. 实现三类核心 worker_individual 子个体。
3. 集成 Ray 框架支持分布式集群调度。
4. 基于 PostgreSQL 的全量持久化存储方案。
5. 提供完整的 FastAPI 后端与 React 前端交互界面。
This commit is contained in:
2026-04-29 10:09:07 +08:00
commit d84212f780
163 changed files with 19251 additions and 0 deletions
@@ -0,0 +1,55 @@
import pytest
from unittest.mock import MagicMock, patch
from pretor.adapter.model_adapter.agent_factory import AgentFactory
from pretor.utils.error import ModelNotExistError
def test_create_agent_success_real():
mock_provider = MagicMock()
mock_provider.provider_type = "openai"
mock_provider.provider_models = ["gpt-4"]
mock_provider.provider_apikey = "key"
mock_provider.provider_url = "url"
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIProvider") as mock_provider_cls:
factory = AgentFactory()
agent = factory.create_agent(
provider=mock_provider,
model_id="gpt-4",
output_type=str,
system_prompt="You are an AI",
deps_type=dict,
agent_name="myagent"
)
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url")
mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value)
mock_agent_cls.assert_called_once_with(
model=mock_model_cls.return_value,
name="myagent",
system_prompt="You are an AI",
output_type=str,
deps_type=dict,
tools=None
)
assert agent == mock_agent_cls.return_value
def test_create_agent_model_not_exist():
factory = AgentFactory()
mock_provider = MagicMock()
mock_provider.provider_models = ["gpt-3"]
with pytest.raises(ModelNotExistError):
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
def test_create_agent_invalid_provider_type():
factory = AgentFactory()
mock_provider = MagicMock()
mock_provider.provider_type = "unknown"
mock_provider.provider_models = ["gpt-4"]
with pytest.raises(ValueError, match="不支持的协议类型: unknown"):
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
@@ -0,0 +1,74 @@
import pytest
from unittest.mock import patch
from sqlalchemy.exc import IntegrityError, OperationalError
from pydantic import ValidationError
from pretor.utils.error import UserNotExistError
from pretor.core.database.database_exception import database_exception
@database_exception
async def success_func():
return "success"
@database_exception
async def validation_error_func():
raise ValidationError.from_exception_data(title="Mock", line_errors=[])
@database_exception
async def integrity_error_func():
raise IntegrityError("mock_statement", "mock_params", "mock_orig")
@database_exception
async def operational_error_func():
raise OperationalError("mock_statement", "mock_params", "mock_orig")
@database_exception
async def user_not_exist_error_func():
raise UserNotExistError("mock user")
@database_exception
async def exception_func():
raise Exception("mock generic exception")
@pytest.mark.asyncio
async def test_success_func():
assert await success_func() == "success"
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_validation_error(mock_logger):
with pytest.raises(ValidationError):
await validation_error_func()
mock_logger.error.assert_called_once()
assert "对象校验失败" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_integrity_error(mock_logger):
with pytest.raises(IntegrityError):
await integrity_error_func()
mock_logger.error.assert_called_once()
assert "数据库完整性错误" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_operational_error(mock_logger):
with pytest.raises(OperationalError):
await operational_error_func()
mock_logger.error.assert_called_once()
assert "数据库连接异常" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_user_not_exist_error(mock_logger):
result = await user_not_exist_error_func()
assert result is None
mock_logger.error.assert_called_once()
assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_generic_exception(mock_logger):
with pytest.raises(Exception, match="mock generic exception"):
await exception_func()
mock_logger.exception.assert_called_once()
assert "未预期的数据库错误" in mock_logger.exception.call_args[0][0]
+145
View File
@@ -0,0 +1,145 @@
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import json
import sys
import builtins
real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'sqlmodel':
mock_sqlmodel = MagicMock()
class DummySQLModel:
def __init_subclass__(cls, **kwargs):
pass
mock_sqlmodel.SQLModel = DummySQLModel
mock_sqlmodel.Field = MagicMock(return_value=None)
mock_sqlmodel.select = MagicMock()
return mock_sqlmodel
return real_import(name, globals, locals, fromlist, level)
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.database.module.memory' in mod or 'sqlmodel' in mod:
del sys.modules[mod]
from pretor.core.database.module.memory import MemoryRAG
builtins.__import__ = real_import
@pytest.fixture(autouse=True)
def mock_dependencies():
with patch("pretor.core.database.module.memory.WorkflowRecord") as mock_workflow_record:
with patch("pretor.core.database.module.memory.MemoryRecord") as mock_memory_record:
with patch("pretor.core.database.module.memory.select") as mock_select:
yield mock_workflow_record, mock_memory_record, mock_select
@pytest.fixture
def mock_session_maker():
maker = MagicMock()
session = AsyncMock()
session.add = MagicMock()
maker.return_value.__aenter__.return_value = session
maker.__aenter__.return_value = session
maker.__aexit__ = AsyncMock()
return maker, session
@pytest.mark.asyncio
async def test_save_workflow(mock_session_maker, mock_dependencies):
mock_workflow_record, _, _ = mock_dependencies
maker, session = mock_session_maker
rag = MemoryRAG(maker)
mock_record = MagicMock()
mock_workflow_record.return_value = mock_record
workflow_data = {"key": "value"}
record = await rag.save_workflow("wf_123", workflow_data)
mock_workflow_record.assert_called_once_with(
workflow_id="wf_123",
workflow_data_json=json.dumps(workflow_data)
)
session.add.assert_called_once_with(mock_record)
session.commit.assert_called_once()
session.refresh.assert_called_once_with(mock_record)
assert record == mock_record
@pytest.mark.asyncio
async def test_get_workflow_success(mock_session_maker, mock_dependencies):
_, _, mock_select = mock_dependencies
maker, session = mock_session_maker
rag = MemoryRAG(maker)
mock_statement = MagicMock()
mock_select.return_value.where.return_value = mock_statement
mock_record = MagicMock()
mock_record.workflow_data_json = '{"key": "value"}'
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = mock_record
session.execute = AsyncMock(return_value=mock_exec_result)
data = await rag.get_workflow("wf_123")
session.execute.assert_called_once_with(mock_statement)
assert data == {"key": "value"}
@pytest.mark.asyncio
async def test_get_workflow_not_found(mock_session_maker, mock_dependencies):
_, _, mock_select = mock_dependencies
maker, session = mock_session_maker
rag = MemoryRAG(maker)
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=mock_exec_result)
data = await rag.get_workflow("wf_123")
assert data is None
@pytest.mark.asyncio
async def test_add_memory(mock_session_maker, mock_dependencies):
_, mock_memory_record, _ = mock_dependencies
maker, session = mock_session_maker
rag = MemoryRAG(maker)
mock_record = MagicMock()
mock_memory_record.return_value = mock_record
record = await rag.add_memory("text", [0.1, 0.2])
mock_memory_record.assert_called_once_with(memory_text="text", embedding=[0.1, 0.2])
session.add.assert_called_once_with(mock_record)
session.commit.assert_called_once()
session.refresh.assert_called_once_with(mock_record)
assert record == mock_record
@pytest.mark.asyncio
async def test_retrieve_memory(mock_session_maker, mock_dependencies):
_, _, mock_select = mock_dependencies
maker, session = mock_session_maker
rag = MemoryRAG(maker)
mock_statement = MagicMock()
mock_select.return_value.limit.return_value = mock_statement
mock_exec_result = MagicMock()
mock_exec_result.all.return_value = ["res1", "res2"]
session.execute = AsyncMock(return_value=mock_exec_result)
results = await rag.retrieve_memory([0.1, 0.2], 5)
session.execute.assert_called_once_with(mock_statement)
assert results == ["res1", "res2"]
+180
View File
@@ -0,0 +1,180 @@
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
@pytest.fixture(autouse=True)
def mock_dependencies():
with patch("pretor.core.database.module.user.User") as mock_user_cls:
mock_user_cls.user_name = MagicMock()
with patch("pretor.core.database.module.user.select") as mock_select:
yield mock_user_cls, mock_select
@pytest.fixture
def mock_session_maker():
maker = MagicMock()
session = AsyncMock()
session.add = MagicMock()
session.delete = MagicMock()
maker.return_value.__aenter__.return_value = session
maker.__aenter__.return_value = session
maker.__aexit__ = AsyncMock()
return maker, session
@pytest.mark.asyncio
async def test_add_user(mock_session_maker, mock_dependencies):
mock_user_cls, _ = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_user = MagicMock()
mock_user.user_name = "testuser"
mock_user.hashed_password = "hashedpw"
mock_user_cls.return_value = mock_user
mock_exec_result = MagicMock()
mock_exec_result.first.return_value = None
session.execute = AsyncMock(return_value=mock_exec_result)
user = await db.add_user("testuser", "hashedpw")
assert user.user_name == "testuser"
assert user.hashed_password == "hashedpw"
session.add.assert_called_once_with(mock_user)
session.commit.assert_called_once()
session.refresh.assert_called_once_with(mock_user)
@pytest.mark.asyncio
async def test_change_password_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_statement = MagicMock()
mock_select.return_value.where.return_value = mock_statement
from pretor.utils.access import Accessor
mock_user = MagicMock()
mock_user.hashed_password = Accessor.hash_password("old_password")
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = mock_user
session.execute = AsyncMock(return_value=mock_exec_result)
user = await db.change_password("testuser", "old_password", "new_password")
session.execute.assert_called_once_with(mock_statement)
assert user.hashed_password == "new_password"
session.add.assert_called_once_with(mock_user)
session.commit.assert_called_once()
session.refresh.assert_called_once_with(mock_user)
@pytest.mark.asyncio
async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=mock_exec_result)
result = await db.change_password("testuser", "old_password", "new_password")
assert result is None
@pytest.mark.asyncio
async def test_change_password_wrong_password(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
from pretor.utils.access import Accessor
mock_user = MagicMock()
mock_user.hashed_password = Accessor.hash_password("actual_password")
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = mock_user
session.execute = AsyncMock(return_value=mock_exec_result)
from pretor.utils.error import UserPasswordError
with pytest.raises(UserPasswordError):
await db.change_password("testuser", "old_password", "new_password")
@pytest.mark.asyncio
async def test_delete_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_statement = MagicMock()
mock_select.return_value.where.return_value = mock_statement
mock_user = MagicMock()
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = mock_user
session.execute = AsyncMock(return_value=mock_exec_result)
await db.delete_user("testuser")
session.execute.assert_called_once_with(mock_statement)
session.delete.assert_called_once_with(mock_user)
session.commit.assert_called_once()
@pytest.mark.asyncio
async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=mock_exec_result)
result = await db.delete_user("testuser")
assert result is None
@pytest.mark.asyncio
async def test_login_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_statement = MagicMock()
mock_select.return_value.where.return_value = mock_statement
mock_user = MagicMock()
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = mock_user
session.execute = AsyncMock(return_value=mock_exec_result)
user = await db.login_user("testuser")
session.execute.assert_called_once_with(mock_statement)
assert user == mock_user
@pytest.mark.asyncio
async def test_login_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
mock_exec_result = MagicMock()
mock_exec_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=mock_exec_result)
result = await db.login_user("testuser")
assert result is None
+78
View File
@@ -0,0 +1,78 @@
import pytest
from unittest.mock import patch, MagicMock
import sys
import builtins
real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
return mock_ray
return real_import(name, globals, locals, fromlist, level)
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.database.postgres' in mod or 'ray' in mod:
del sys.modules[mod]
from pretor.core.database.postgres import PostgresDatabase
builtins.__import__ = real_import
@patch("pretor.core.database.postgres.create_async_engine")
@patch("pretor.core.database.postgres.sessionmaker")
@patch("pretor.core.database.postgres.AuthDatabase")
@patch("pretor.core.database.postgres.ProviderDatabase")
@patch("pretor.core.database.postgres.os.environ.get")
@pytest.mark.asyncio
async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine):
def env_side_effect(key):
return {
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpassword",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_DB": "testdb"
}.get(key)
mock_env_get.side_effect = env_side_effect
mock_engine = MagicMock()
mock_conn = MagicMock()
from unittest.mock import AsyncMock
mock_conn.run_sync = AsyncMock()
mock_begin_ctx = MagicMock()
mock_begin_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
mock_begin_ctx.__aexit__ = AsyncMock()
mock_engine.begin.return_value = mock_begin_ctx
mock_create_engine.return_value = mock_engine
db = PostgresDatabase()
mock_create_engine.assert_called_once_with(
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
echo=True
)
mock_auth_db.assert_called_once()
mock_provider_db.assert_called_once()
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
await db.init_db()
mock_conn.run_sync.assert_called_once_with(mock_create_all)
assert await db.get_user_authority(user_id="123") == "test_auth"
@@ -0,0 +1,14 @@
from pretor.core.database.table.provider import Provider
def test_provider_table():
# Provide required fields
provider = Provider(
provider_title="title",
provider_url="url",
provider_apikey="key",
provider_models=["model_1"],
provider_type="type",
provider_owner=1
)
assert Provider.__tablename__ == 'provider'
assert provider.provider_title == "title"
@@ -0,0 +1,6 @@
from pretor.core.database.table.user import User
def test_user_table():
user = User(user_id="id", user_name="name", hashed_password="pw")
assert User.__tablename__ == 'user'
assert user.user_name == "name"
@@ -0,0 +1,170 @@
import pytest
import asyncio
from unittest.mock import MagicMock, AsyncMock, patch
import sys
import builtins
real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
return mock_ray
return real_import(name, globals, locals, fromlist, level)
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod:
del sys.modules[mod]
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.api.platform.event import PretorEvent
from pretor.core.workflow.workflow import PretorWorkflow
builtins.__import__ = real_import
@pytest.fixture
def mock_postgres():
return MagicMock()
@pytest.fixture
def gsm(mock_postgres):
manager = GlobalStateMachine(mock_postgres)
return manager
def test_add_delete_get_event(gsm):
event = MagicMock(spec=PretorEvent)
event.trace_id = "123"
gsm.add_event(event)
assert getattr(event, 'pending_queue', None) is not None
assert getattr(event, 'receive_queue', None) is not None
retrieved = gsm.get_event("123")
assert retrieved == event
gsm.delete_event("123")
assert gsm.get_event("123") is None
def test_update_attachment_and_workflow(gsm):
event = MagicMock(spec=PretorEvent)
event.trace_id = "abc"
gsm.add_event(event)
gsm.update_attachment("abc", {"k": "v"})
assert event.attachment == {"k": "v"}
wf = MagicMock(spec=PretorWorkflow)
gsm.update_workflow("abc", wf)
assert event.workflow == wf
@pytest.mark.asyncio
async def test_queues(gsm):
event = MagicMock(spec=PretorEvent)
event.trace_id = "q_event"
# To use await put/get, we must actually use real asyncio queues for the mock event
event.pending_queue = asyncio.Queue()
event.receive_queue = asyncio.Queue()
gsm.event_dict["q_event"] = event
await gsm.put_pending("q_event", "item1")
res1 = await gsm.get_pending("q_event")
assert res1 == "item1"
await gsm.put_received("q_event", "item2")
res2 = await gsm.get_received("q_event")
assert res2 == "item2"
@pytest.mark.asyncio
async def test_add_provider_success(gsm, mock_postgres):
mock_provider_class = AsyncMock()
mock_provider = MagicMock()
mock_provider.provider_title = "MyProvider"
mock_provider.provider_url = "url"
mock_provider.provider_apikey = "key"
mock_provider.provider_models = ["model"]
mock_provider.provider_type = "openai"
mock_provider_class.create_provider.return_value = mock_provider
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
gsm._global_provider_manager.provider_register = {}
mock_add_provider = AsyncMock()
mock_postgres.add_provider_db = MagicMock()
mock_postgres.add_provider_db.remote = mock_add_provider
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
mock_add_provider.assert_called_once()
assert mock_provider.provider_owner == "1"
@pytest.mark.asyncio
async def test_add_provider_unsupported(gsm):
gsm._global_provider_manager.provider_mapper = {}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
mock_logger = MagicMock()
mock_bind.return_value = mock_logger
await gsm.add_provider_wrap("magic", "title", "url", "key", "1")
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
@pytest.mark.asyncio
async def test_add_provider_request_error(gsm):
from httpx import RequestError
mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock())
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
from pretor.utils.error import RetryableError
import pytest
mock_logger = MagicMock()
mock_bind.return_value = mock_logger
with pytest.raises(RetryableError):
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
mock_logger.warning.assert_called()
@pytest.mark.asyncio
async def test_add_provider_generic_error(gsm):
mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = ValueError("Some Error")
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
mock_logger = MagicMock()
mock_bind.return_value = mock_logger
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
mock_logger.warning.assert_called_once()
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
def test_get_provider_list_and_get_provider(gsm):
mock_provider = MagicMock()
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
assert gsm._global_provider_manager.get_provider("missing") is None
@@ -0,0 +1,25 @@
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus
def test_provider_status():
assert ProviderStatus.UP == "up"
assert ProviderStatus.DOWN == "down"
def test_provider_args():
args = ProviderArgs(
provider_title="title",
provider_url="url",
provider_apikey="key",
provider_owner="1"
)
assert args.provider_title == "title"
def test_provider_model():
p = Provider(
provider_title="title",
provider_url="url",
provider_apikey="key",
provider_models=["model"],
provider_type="openai"
)
assert p.provider_status == ProviderStatus.UP
assert p.provider_owner is None
@@ -0,0 +1,51 @@
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs
@pytest.fixture
def provider_args():
return ProviderArgs(
provider_title="TestClaude",
provider_url="https://api.anthropic.com",
provider_apikey="testkey",
provider_owner="1"
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [{"id": "claude-3-opus-20240229"}, {"id": "claude-3-sonnet-20240229"}]
}
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await ClaudeProvider._load_models(provider_args)
assert models == ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
async def test_load_models_error(mock_client, provider_args):
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("network error")
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await ClaudeProvider._load_models(provider_args)
assert models == []
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
return_value=["claude-3"])
async def test_create_provider(mock_load, provider_args):
provider = await ClaudeProvider.create_provider(provider_args)
assert provider.provider_title == "TestClaude"
assert provider.provider_models == ["claude-3"]
assert provider.provider_type == "claude"
@@ -0,0 +1,112 @@
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs
@pytest.fixture
def provider_args():
return ProviderArgs(
provider_title="TestOpenAI",
provider_url="https://api.openai.com/v1",
provider_apikey="testkey",
provider_owner="1"
)
@pytest.fixture
def provider_args_no_v1():
return ProviderArgs(
provider_title="TestOpenAI",
provider_url="https://api.openai.com",
provider_apikey="testkey",
provider_owner="1"
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {
"data": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]
}
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await OpenAIProvider._load_models(provider_args)
assert models == ["gpt-3.5-turbo", "gpt-4"]
mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_no_v1(mock_client, provider_args_no_v1):
mock_response = MagicMock()
mock_response.status_code = 200
mock_response.json.return_value = {"data": []}
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await OpenAIProvider._load_models(provider_args_no_v1)
assert models == []
mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_status_error(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 401
mock_client_instance = AsyncMock()
mock_client_instance.get.return_value = mock_response
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await OpenAIProvider._load_models(provider_args)
assert models == []
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_request_error(mock_client, provider_args):
import httpx
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
mock_client.return_value.__aenter__.return_value = mock_client_instance
import pytest
from pretor.utils.error import RetryableError
with pytest.raises(RetryableError):
await OpenAIProvider._load_models(provider_args)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_generic_error(mock_client, provider_args):
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("generic error")
mock_client.return_value.__aenter__.return_value = mock_client_instance
models = await OpenAIProvider._load_models(provider_args)
assert models == []
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
return_value=["gpt-4"])
async def test_create_provider(mock_load, provider_args):
provider = await OpenAIProvider.create_provider(provider_args)
assert provider.provider_title == "TestOpenAI"
assert provider.provider_models == ["gpt-4"]
assert provider.provider_type == "openai"
@@ -0,0 +1,26 @@
import pytest
from unittest.mock import MagicMock, AsyncMock
from pretor.core.global_state_machine.provider_manager import ProviderManager
@pytest.mark.asyncio
async def test_provider_manager_init():
mock_postgres = MagicMock()
mock_provider1 = MagicMock()
mock_provider1.provider_title = "title1"
mock_provider2 = MagicMock()
mock_provider2.provider_title = "title2"
mock_postgres.get_provider = MagicMock()
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
manager = ProviderManager(mock_postgres)
mock_postgres.provider_database = MagicMock()
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
await manager.init_provider_register(mock_postgres)
assert "openai" in manager.provider_mapper
assert "claude" in manager.provider_mapper
assert manager.provider_register["title1"] == mock_provider1
@@ -0,0 +1,5 @@
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
def test_global_tool_manager_init():
manager = GlobalToolManager()
assert isinstance(manager, GlobalToolManager)
+168
View File
@@ -0,0 +1,168 @@
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio
import sys
import builtins
real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
return mock_ray
return real_import(name, globals, locals, fromlist, level)
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.workflow.workflow_runner' in mod or 'ray' in mod:
del sys.modules[mod]
from pretor.core.workflow.workflow_runner import WorkflowEngine, WorkflowRunningEngine
builtins.__import__ = real_import
@pytest.fixture
def mock_ray():
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray:
mock_ray.get = lambda x: x
yield mock_ray
def test_workflow_engine_init():
mock_wf = MagicMock()
mock_wf.work_link = []
engine = WorkflowEngine(mock_wf, "conscious", "control", "supervisor")
assert engine.workflow == mock_wf
assert engine.consciousness_node == "conscious"
assert engine.control_node == "control"
assert engine.supervisory_node == "supervisor"
@pytest.mark.asyncio
async def test_workflow_engine_run():
from pretor.core.workflow.workflow import PretorWorkflow, WorkStep, WorkflowStatus
mock_wf = MagicMock(spec=PretorWorkflow)
step1 = MagicMock(spec=WorkStep)
step1.step = 1
step1.status = "waiting"
step1.node = "control_node"
step1.name = "mock_name"
step1.desc = "mock_desc"
step1.action = "mock_action"
step1.inputs = []
step1.outputs = "res"
step1.logic_gate = None
mock_wf.work_link = [step1]
mock_status = MagicMock(spec=WorkflowStatus)
mock_status.step = 1
mock_status.status = "running"
mock_wf.status = mock_status
mock_wf.context_memory = {}
mock_wf.title = "mock_title"
mock_wf.trace_id = "mock_trace_id"
mock_wf.command = "mock_command"
mock_wf.event_info = MagicMock()
mock_wf.event_info.platform = "test"
mock_wf.event_info.user_name = "test_user"
mock_control = MagicMock()
mock_control.working.remote = AsyncMock(return_value="process_result")
mock_conscious = MagicMock()
mock_conscious.working.remote = AsyncMock(return_value="report")
mock_supervisor = MagicMock()
mock_supervisor.working.remote = AsyncMock(return_value="response")
engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor)
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch:
mock_gsm = MagicMock()
mock_ray_patch.get_actor.return_value = mock_gsm
await engine.run()
assert step1.status == "completed"
assert mock_wf.context_memory["res"] == "process_result"
def test_workflow_running_engine_init():
engine = WorkflowRunningEngine("conscious", "control", "supervisor")
assert engine.consciousness_node == "conscious"
assert engine.control_node == "control"
assert engine.supervisory_node == "supervisor"
@pytest.mark.asyncio
async def test_workflow_running_engine_submit():
engine = WorkflowRunningEngine("conscious", "control", "supervisor")
engine.workflow_queue = asyncio.Queue()
mock_wf = MagicMock()
await engine.workflow_queue.put(mock_wf)
item = await engine.workflow_queue.get()
assert item == mock_wf
@pytest.mark.asyncio
async def test_workflow_running_engine_runner():
from pretor.api.platform.event import PretorEvent
from pretor.core.individual.consciousness_node.template import ForWorkflowEngine
mock_consciousness = MagicMock()
mock_wf = MagicMock()
mock_wf.trace_id = "test_trace"
mock_wf.title = "test_title"
mock_result = MagicMock(spec=ForWorkflowEngine)
mock_result.workflow = mock_wf
mock_consciousness.working.remote = AsyncMock(return_value=mock_result)
engine = WorkflowRunningEngine(mock_consciousness, "control", "supervisor")
engine.workflow_queue = asyncio.Queue()
mock_event = PretorEvent(
platform="test_platform",
user_id="test_user",
user_name="test_user",
message="test_message",
context={"workflow_template": "test_template"}
)
await engine.workflow_queue.put(mock_event)
# Mock the global_state_machine get_skill_list.remote method properly
mock_gsm = MagicMock()
mock_gsm.get_skill_list.remote = AsyncMock(return_value={"test_skill": ("description", "instructions")})
engine.global_state_machine = mock_gsm
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open:
# Instead of patching hook, we inject it directly
engine.global_state_machine = AsyncMock()
mock_open.return_value.__enter__.return_value.read.return_value = '{}'
mock_engine_instance = MagicMock()
mock_engine_instance.run = AsyncMock()
mock_wf_engine_cls.return_value = mock_engine_instance
task = asyncio.create_task(engine.runner(1))
await asyncio.sleep(0.05)
task.cancel()
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
@@ -0,0 +1,43 @@
from unittest.mock import patch, MagicMock
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
@patch("pretor.core.workflow.workflow_template_generator.workflow_template_generator.Path")
def test_generate_workflow_template(mock_path):
mock_dir = MagicMock()
mock_dir.exists.return_value = False
mock_file = MagicMock()
mock_dir.__truediv__.return_value = mock_file
mock_open_ctx = MagicMock()
mock_file.open.return_value.__enter__.return_value = mock_open_ctx
mock_path_root = MagicMock()
mock_path_root.__truediv__.return_value = mock_dir
mock_path.return_value = mock_path_root
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
generator = WorkflowTemplateGenerator()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "test_wf"
mock_template.desc = "test_desc"
import json
mock_template.model_dump_json.return_value = json.dumps({
"name": "test_wf",
"desc": "test_desc",
"work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
})
generator.generate_workflow_template(
workflow_template=mock_template
)
mock_dir.mkdir.assert_called_once_with(parents=True)
mock_file.open.assert_called_once_with("w", encoding="utf-8")
mock_open_ctx.write.assert_called_once()
write_arg = mock_open_ctx.write.call_args[0][0]
written_data = json.loads(write_arg)
assert written_data["name"] == "test_wf"
assert written_data["desc"] == "test_desc"
assert written_data["work_link"][0]["step"] == 1
@@ -0,0 +1,36 @@
import pytest
from pydantic import ValidationError
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplateStep, WorkflowTemplate
def test_workflow_template_step():
step = WorkflowTemplateStep(
step=1,
node="node_type",
action="act",
desc="desc",
input=["in1"],
output=["out1"],
logic_gate={"if_pass": "next"}
)
assert step.step == 1
assert step.node == "node_type"
def test_workflow_template_success():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=2, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
wt = WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
assert wt.name == "temp"
def test_workflow_template_error_duplicate_steps():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=1, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
with pytest.raises(ValidationError, match="Step numbers in work_link must be unique"):
WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
@@ -0,0 +1,57 @@
import json
from unittest.mock import MagicMock, patch, mock_open
from pathlib import Path
from pretor.core.workflow.workflow_template_manager import WorkflowManager
def test_workflow_manager_init_success():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data=json.dumps({"name": "test1", "desc": "desc1"}))
mock_file2 = MagicMock(spec=Path)
mock_file2.open = mock_open(read_data=json.dumps({"name": "test2", "desc": "desc2"}))
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1, mock_file2]):
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {"test1": "desc1", "test2": "desc2"}
def test_workflow_manager_init_json_error():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data="{invalid_json}")
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1]):
with patch("pretor.core.workflow.workflow_template_manager.logger") as mock_logger:
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {}
mock_logger.warning.assert_called_once()
assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "name"
mock_template.desc = "desc"
mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template
manager.generate_workflow_template(mock_template)
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template)
assert manager.workflow_templates_registry["name"] == "desc"
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
@patch("pretor.core.workflow.workflow_template_manager.logger")
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
manager.generate_workflow_template(mock_template)
mock_logger.exception.assert_called_once_with("Failed to generate workflow template")
+47
View File
@@ -0,0 +1,47 @@
import pytest
from pretor.core.workflow.workflow import WorkStep, PretorWorkflow, WorkflowStatus, LogicGate
def test_work_step():
ws = WorkStep(
step=1,
name="step1",
node="control_node",
action="coding",
desc="Write some code"
)
assert ws.step == 1
assert ws.name == "step1"
assert ws.node == "control_node"
assert ws.action == "coding"
assert ws.desc == "Write some code"
assert ws.status == "waiting"
def test_pretor_workflow_validation_success():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
wf = PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
assert wf.title == "wf1"
def test_pretor_workflow_validation_error_step_discontinuous():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="工作链步数不连续"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_out_of_bounds():
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_pretor_workflow_validation_error_jump_format_error():
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
with pytest.raises(ValueError, match="LogicGate 格式错误"):
PretorWorkflow(title="wf1", work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
def test_workflow_status():
status = WorkflowStatus()
assert status.step == 1
assert status.status == "waiting_llm_working"
+96
View File
@@ -0,0 +1,96 @@
import sys
from unittest.mock import MagicMock, patch
# Mock dependencies before importing the module under test
class MockHTTPException(Exception):
def __init__(self, status_code, detail=None, headers=None):
self.status_code = status_code
self.detail = detail
self.headers = headers
class MockValidationError(Exception):
pass
mock_fastapi = MagicMock()
mock_fastapi.HTTPException = MockHTTPException
mock_fastapi.status.HTTP_401_UNAUTHORIZED = 401
mock_pydantic = MagicMock()
mock_pydantic.ValidationError = MockValidationError
sys.modules["fastapi"] = mock_fastapi
sys.modules["pydantic"] = mock_pydantic
sys.modules["sqlmodel"] = MagicMock()
sys.modules["passlib"] = MagicMock()
sys.modules["passlib.context"] = MagicMock()
sys.modules["pretor.core.database.table.user"] = MagicMock()
import pytest
import jwt
from pretor.utils.access import Accessor
def test_decode_token_success():
"""Test successful token decoding."""
token = "valid.token.here"
payload = {"user_id": "123", "username": "testuser", "exp": 1234567890}
with patch("jwt.decode", return_value=payload) as mock_decode:
with patch("pretor.utils.access.TokenData") as mock_token_data_cls:
mock_token_data_instance = MagicMock()
mock_token_data_cls.return_value = mock_token_data_instance
result = Accessor._decode_token(token)
mock_decode.assert_called_once()
mock_token_data_cls.assert_called_once_with(**payload)
assert result == mock_token_data_instance
def test_decode_token_expired():
"""Test token decoding with an expired token."""
token = "expired.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "Token 已过期"
def test_decode_token_invalid():
"""Test token decoding with an invalid token."""
token = "invalid.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
def test_decode_token_validation_error():
"""Test token decoding with a payload that fails validation."""
token = "valid.jwt.invalid.payload"
payload = {"wrong": "payload"}
from fastapi import HTTPException
with patch("jwt.decode", return_value=payload):
with patch("pretor.utils.access.TokenData", side_effect=MockValidationError):
with patch("pretor.utils.access.ValidationError", MockValidationError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"