refactor(core): decouple actors and remove workflow templates (#67)

Removes the deprecated `workflow_template` concept entirely across both backend API routers, internal logic handling within the `supervisory_node` and `consciousness_node`, and front-end components. Enables `consciousness_node` to work autonomously.

Also refactors core package structure to enforce the "one python package, one Ray Actor" architectural rule. `GlobalWorkflowManager`, `WorkflowRunningEngine`, `PostgresDatabase`, and `WorkerCluster` have been moved to their own top-level decoupled package directories with properly exported `__init__.py` modules. Test suites have been relocated and import paths updated across the system.

Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com>
Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>
This commit is contained in:
2026-05-06 15:05:47 +08:00
committed by GitHub
parent b3ea4cd8d9
commit 209ba45477
97 changed files with 1872 additions and 1498 deletions
@@ -12,8 +12,12 @@ def test_create_agent_success_real():
mock_provider.provider_url = "url"
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIProvider") as mock_provider_cls:
with patch(
"pretor.adapter.model_adapter.agent_factory.OpenAIChatModel"
) as mock_model_cls:
with patch(
"pretor.adapter.model_adapter.agent_factory.OpenAIProvider"
) as mock_provider_cls:
factory = AgentFactory()
agent = factory.create_agent(
provider=mock_provider,
@@ -21,17 +25,19 @@ def test_create_agent_success_real():
output_type=str,
system_prompt="You are an AI",
deps_type=dict,
agent_name="myagent"
agent_name="myagent",
)
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url")
mock_model_cls.assert_called_once_with("gpt-4", provider=mock_provider_cls.return_value)
mock_model_cls.assert_called_once_with(
"gpt-4", provider=mock_provider_cls.return_value
)
mock_agent_cls.assert_called_once_with(
model=mock_model_cls.return_value,
name="myagent",
system_prompt="You are an AI",
output_type=str,
deps_type=dict,
tools=None
deps_type=dict,
tools=None,
)
assert agent == mock_agent_cls.return_value
@@ -5,34 +5,42 @@ from pydantic import ValidationError
from pretor.utils.error import UserNotExistError
from pretor.core.database.database_exception import database_exception
@database_exception
async def success_func():
return "success"
@database_exception
async def validation_error_func():
raise ValidationError.from_exception_data(title="Mock", line_errors=[])
@database_exception
async def integrity_error_func():
raise IntegrityError("mock_statement", "mock_params", "mock_orig")
@database_exception
async def operational_error_func():
raise OperationalError("mock_statement", "mock_params", "mock_orig")
@database_exception
async def user_not_exist_error_func():
raise UserNotExistError("mock user")
@database_exception
async def exception_func():
raise Exception("mock generic exception")
@pytest.mark.asyncio
async def test_success_func():
assert await success_func() == "success"
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_validation_error(mock_logger):
@@ -41,6 +49,7 @@ async def test_validation_error(mock_logger):
mock_logger.error.assert_called_once()
assert "对象校验失败" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_integrity_error(mock_logger):
@@ -49,6 +58,7 @@ async def test_integrity_error(mock_logger):
mock_logger.error.assert_called_once()
assert "数据库完整性错误" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_operational_error(mock_logger):
@@ -57,6 +67,7 @@ async def test_operational_error(mock_logger):
mock_logger.error.assert_called_once()
assert "数据库连接异常" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_user_not_exist_error(mock_logger):
@@ -65,6 +76,7 @@ async def test_user_not_exist_error(mock_logger):
mock_logger.error.assert_called_once()
assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0]
@pytest.mark.asyncio
@patch("pretor.core.database.database_exception.logger")
async def test_generic_exception(mock_logger):
+10
View File
@@ -26,6 +26,7 @@ def mock_session_maker():
async def test_add_user(mock_session_maker, mock_dependencies):
mock_user_cls, _ = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -51,6 +52,7 @@ async def test_add_user(mock_session_maker, mock_dependencies):
async def test_change_password_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -79,6 +81,7 @@ async def test_change_password_success(mock_session_maker, mock_dependencies):
async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -94,10 +97,12 @@ async def test_change_password_user_not_exist(mock_session_maker, mock_dependenc
async def test_change_password_wrong_password(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
from pretor.utils.access import Accessor
mock_user = MagicMock()
mock_user.hashed_password = Accessor.hash_password("actual_password")
mock_exec_result = MagicMock()
@@ -105,6 +110,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc
session.execute = AsyncMock(return_value=mock_exec_result)
from pretor.utils.error import UserPasswordError
with pytest.raises(UserPasswordError):
await db.change_password("testuser", "old_password", "new_password")
@@ -113,6 +119,7 @@ async def test_change_password_wrong_password(mock_session_maker, mock_dependenc
async def test_delete_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -134,6 +141,7 @@ async def test_delete_user_success(mock_session_maker, mock_dependencies):
async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -149,6 +157,7 @@ async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
async def test_login_user_success(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -169,6 +178,7 @@ async def test_login_user_success(mock_session_maker, mock_dependencies):
async def test_login_user_not_exist(mock_session_maker, mock_dependencies):
mock_user_cls, mock_select = mock_dependencies
from pretor.core.database.module.user import AuthDatabase
maker, session = mock_session_maker
db = AuthDatabase(maker)
@@ -1,5 +1,6 @@
from pretor.core.database.table.provider import Provider
def test_provider_table():
# Provide required fields
provider = Provider(
@@ -8,7 +9,7 @@ def test_provider_table():
provider_apikey="key",
provider_models=["model_1"],
provider_type="type",
provider_owner=1
provider_owner=1,
)
assert Provider.__tablename__ == 'provider'
assert Provider.__tablename__ == "provider"
assert provider.provider_title == "title"
+2 -1
View File
@@ -1,6 +1,7 @@
from pretor.core.database.table.user import User
def test_user_table():
user = User(user_id="id", user_name="name", hashed_password="pw")
assert User.__tablename__ == 'user'
assert User.__tablename__ == "user"
assert user.user_name == "name"
@@ -7,14 +7,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
if name == "ray":
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
@@ -25,10 +27,10 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod:
if "pretor.core.global_state_machine.global_state_machine" in mod or "ray" in mod:
del sys.modules[mod]
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine # noqa: E402
builtins.__import__ = real_import
@@ -82,13 +84,17 @@ async def test_add_provider_unsupported(gsm):
@pytest.mark.asyncio
async def test_add_provider_request_error(gsm):
from httpx import RequestError
mock_provider_class = AsyncMock()
mock_provider_class.create_provider.side_effect = RequestError("Network Error", request=MagicMock())
mock_provider_class.create_provider.side_effect = RequestError(
"Network Error", request=MagicMock()
)
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
with patch("pretor.utils.logger.global_logger.bind") as mock_bind:
from pretor.utils.error import RetryableError
import pytest
mock_logger = MagicMock()
mock_bind.return_value = mock_logger
with pytest.raises(RetryableError):
@@ -117,3 +123,4 @@ def test_get_provider_list_and_get_provider(gsm):
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
assert gsm._global_provider_manager.get_provider("missing") is None
# noqa: E402
@@ -1,25 +1,32 @@
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus
from pretor.core.global_state_machine.model_provider.base_provider import (
Provider,
ProviderArgs,
ProviderStatus,
)
def test_provider_status():
assert ProviderStatus.UP == "up"
assert ProviderStatus.DOWN == "down"
def test_provider_args():
args = ProviderArgs(
provider_title="title",
provider_url="url",
provider_apikey="key",
provider_owner="1"
provider_owner="1",
)
assert args.provider_title == "title"
def test_provider_model():
p = Provider(
provider_title="title",
provider_url="url",
provider_apikey="key",
provider_models=["model"],
provider_type="openai"
provider_type="openai",
)
assert p.provider_status == ProviderStatus.UP
assert p.provider_owner is None
@@ -1,6 +1,9 @@
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs
from pretor.core.global_state_machine.model_provider.claude_provider import (
ClaudeProvider,
ProviderArgs,
)
@pytest.fixture
@@ -9,12 +12,14 @@ def provider_args():
provider_title="TestClaude",
provider_url="https://api.anthropic.com",
provider_apikey="testkey",
provider_owner="1"
provider_owner="1",
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
)
async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -31,7 +36,9 @@ async def test_load_models_success(mock_client, provider_args):
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
)
async def test_load_models_error(mock_client, provider_args):
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("network error")
@@ -42,8 +49,10 @@ async def test_load_models_error(mock_client, provider_args):
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
return_value=["claude-3"])
@patch(
"pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
return_value=["claude-3"],
)
async def test_create_provider(mock_load, provider_args):
provider = await ClaudeProvider.create_provider(provider_args)
assert provider.provider_title == "TestClaude"
@@ -1,6 +1,9 @@
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs
from pretor.core.global_state_machine.model_provider.openai_provider import (
OpenAIProvider,
ProviderArgs,
)
@pytest.fixture
@@ -9,7 +12,7 @@ def provider_args():
provider_title="TestOpenAI",
provider_url="https://api.openai.com/v1",
provider_apikey="testkey",
provider_owner="1"
provider_owner="1",
)
@@ -19,12 +22,14 @@ def provider_args_no_v1():
provider_title="TestOpenAI",
provider_url="https://api.openai.com",
provider_apikey="testkey",
provider_owner="1"
provider_owner="1",
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_success(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -40,12 +45,14 @@ async def test_load_models_success(mock_client, provider_args):
assert models == ["gpt-3.5-turbo", "gpt-4"]
mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_no_v1(mock_client, provider_args_no_v1):
mock_response = MagicMock()
mock_response.status_code = 200
@@ -59,12 +66,14 @@ async def test_load_models_no_v1(mock_client, provider_args_no_v1):
assert models == []
mock_client_instance.get.assert_called_once_with(
"https://api.openai.com/v1/models",
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_status_error(mock_client, provider_args):
mock_response = MagicMock()
mock_response.status_code = 401
@@ -78,21 +87,29 @@ async def test_load_models_status_error(mock_client, provider_args):
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_request_error(mock_client, provider_args):
import httpx
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = httpx.RequestError("network error", request=MagicMock())
mock_client_instance.get.side_effect = httpx.RequestError(
"network error", request=MagicMock()
)
mock_client.return_value.__aenter__.return_value = mock_client_instance
import pytest
from pretor.utils.error import RetryableError
with pytest.raises(RetryableError):
await OpenAIProvider._load_models(provider_args)
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
)
async def test_load_models_generic_error(mock_client, provider_args):
mock_client_instance = AsyncMock()
mock_client_instance.get.side_effect = Exception("generic error")
@@ -103,8 +120,10 @@ async def test_load_models_generic_error(mock_client, provider_args):
@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
return_value=["gpt-4"])
@patch(
"pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
return_value=["gpt-4"],
)
async def test_create_provider(mock_load, provider_args):
provider = await OpenAIProvider.create_provider(provider_args)
assert provider.provider_title == "TestOpenAI"
@@ -13,11 +13,15 @@ async def test_provider_manager_init():
mock_provider2.provider_title = "title2"
mock_postgres.get_provider = MagicMock()
mock_postgres.get_provider.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
mock_postgres.get_provider.remote = AsyncMock(
return_value=[mock_provider1, mock_provider2]
)
manager = ProviderManager(mock_postgres)
mock_postgres.provider_database = MagicMock()
mock_postgres.provider_database.remote = AsyncMock(return_value=[mock_provider1, mock_provider2])
mock_postgres.provider_database.remote = AsyncMock(
return_value=[mock_provider1, mock_provider2]
)
await manager.init_provider_register(mock_postgres)
assert "openai" in manager.provider_mapper
@@ -1,5 +1,6 @@
from pretor.core.global_state_machine.tool_manager import GlobalToolManager
def test_global_tool_manager_init():
manager = GlobalToolManager()
assert isinstance(manager, GlobalToolManager)
@@ -7,14 +7,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
if name == "ray":
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
@@ -24,28 +26,30 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.database.postgres' in mod or 'ray' in mod:
if "pretor.core.postgres_database.postgres" in mod or "ray" in mod:
del sys.modules[mod]
from pretor.core.database.postgres import PostgresDatabase
from pretor.core.postgres_database.postgres import PostgresDatabase # noqa: E402
builtins.__import__ = real_import
@patch("pretor.core.database.postgres.create_async_engine")
@patch("pretor.core.database.postgres.sessionmaker")
@patch("pretor.core.database.postgres.AuthDatabase")
@patch("pretor.core.database.postgres.ProviderDatabase")
@patch("pretor.core.database.postgres.os.environ.get")
@patch("pretor.core.postgres_database.postgres.create_async_engine")
@patch("pretor.core.postgres_database.postgres.sessionmaker")
@patch("pretor.core.postgres_database.postgres.AuthDatabase")
@patch("pretor.core.postgres_database.postgres.ProviderDatabase")
@patch("pretor.core.postgres_database.postgres.os.environ.get")
@pytest.mark.asyncio
async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine):
async def test_postgres_database(
mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine
):
def env_side_effect(key):
return {
"POSTGRES_USER": "testuser",
"POSTGRES_PASSWORD": "testpassword",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": "5432",
"POSTGRES_DB": "testdb"
"POSTGRES_DB": "testdb",
}.get(key)
mock_env_get.side_effect = env_side_effect
@@ -53,6 +57,7 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
mock_engine = MagicMock()
mock_conn = MagicMock()
from unittest.mock import AsyncMock
mock_conn.run_sync = AsyncMock()
mock_begin_ctx = MagicMock()
@@ -64,15 +69,17 @@ async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, m
db = PostgresDatabase()
mock_create_engine.assert_called_once_with(
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
echo=True
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", echo=True
)
mock_auth_db.assert_called_once()
mock_provider_db.assert_called_once()
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
with patch(
"pretor.core.postgres_database.postgres.SQLModel.metadata.create_all"
) as mock_create_all:
await db.init_db()
mock_conn.run_sync.assert_called_once_with(mock_create_all)
assert await db.get_user_authority(user_id="123") == "test_auth"
# noqa: E402
@@ -1,43 +0,0 @@
from unittest.mock import patch, MagicMock
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator
@patch("pretor.core.workflow.workflow_template_generator.workflow_template_generator.Path")
def test_generate_workflow_template(mock_path):
mock_dir = MagicMock()
mock_dir.exists.return_value = False
mock_file = MagicMock()
mock_dir.__truediv__.return_value = mock_file
mock_open_ctx = MagicMock()
mock_file.open.return_value.__enter__.return_value = mock_open_ctx
mock_path_root = MagicMock()
mock_path_root.__truediv__.return_value = mock_dir
mock_path.return_value = mock_path_root
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
generator = WorkflowTemplateGenerator()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "test_wf"
mock_template.desc = "test_desc"
import json
mock_template.model_dump_json.return_value = json.dumps({
"name": "test_wf",
"desc": "test_desc",
"work_link": [{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}]
})
generator.generate_workflow_template(
workflow_template=mock_template
)
mock_dir.mkdir.assert_called_once_with(parents=True)
mock_file.open.assert_called_once_with("w", encoding="utf-8")
mock_open_ctx.write.assert_called_once()
write_arg = mock_open_ctx.write.call_args[0][0]
written_data = json.loads(write_arg)
assert written_data["name"] == "test_wf"
assert written_data["desc"] == "test_desc"
assert written_data["work_link"][0]["step"] == 1
@@ -1,36 +0,0 @@
import pytest
from pydantic import ValidationError
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplateStep, WorkflowTemplate
def test_workflow_template_step():
step = WorkflowTemplateStep(
step=1,
node="node_type",
action="act",
desc="desc",
input=["in1"],
output=["out1"],
logic_gate={"if_pass": "next"}
)
assert step.step == 1
assert step.node == "node_type"
def test_workflow_template_success():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=2, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
wt = WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
assert wt.name == "temp"
def test_workflow_template_error_duplicate_steps():
step1 = WorkflowTemplateStep(
step=1, node="node1", action="a1", desc="d1", input=[], output=[], logic_gate={}
)
step2 = WorkflowTemplateStep(
step=1, node="node2", action="a2", desc="d2", input=[], output=[], logic_gate={}
)
with pytest.raises(ValidationError, match="Step numbers in work_link must be unique"):
WorkflowTemplate(name="temp", desc="desc", work_link=[step1, step2])
@@ -1,57 +0,0 @@
import json
from unittest.mock import MagicMock, patch, mock_open
from pathlib import Path
from pretor.core.workflow.workflow_template_manager import WorkflowManager
def test_workflow_manager_init_success():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data=json.dumps({"name": "test1", "desc": "desc1"}))
mock_file2 = MagicMock(spec=Path)
mock_file2.open = mock_open(read_data=json.dumps({"name": "test2", "desc": "desc2"}))
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1, mock_file2]):
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {"test1": "desc1", "test2": "desc2"}
def test_workflow_manager_init_json_error():
mock_file1 = MagicMock(spec=Path)
mock_file1.open = mock_open(read_data="{invalid_json}")
with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[mock_file1]):
with patch("pretor.core.workflow.workflow_template_manager.logger") as mock_logger:
with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
manager = WorkflowManager()
assert manager.workflow_templates_registry == {}
mock_logger.warning.assert_called_once()
assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplate
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
mock_template.name = "name"
mock_template.desc = "desc"
mock_generator_cls.return_value.generate_workflow_template.return_value = mock_template
manager.generate_workflow_template(mock_template)
mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(workflow_template=mock_template)
assert manager.workflow_templates_registry["name"] == "desc"
@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
@patch("pretor.core.workflow.workflow_template_manager.logger")
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
manager = WorkflowManager()
mock_template = MagicMock(spec=WorkflowTemplate)
manager.generate_workflow_template(mock_template)
mock_logger.exception.assert_called_once_with("Failed to generate workflow template")
+43 -8
View File
@@ -1,5 +1,11 @@
import pytest
from pretor.core.workflow.workflow import WorkStep, PretorWorkflow, WorkflowStatus, LogicGate
from pretor.core.workflow.workflow import (
WorkStep,
PretorWorkflow,
WorkflowStatus,
LogicGate,
)
def test_work_step():
ws = WorkStep(
@@ -7,7 +13,7 @@ def test_work_step():
name="step1",
node="control_node",
action="coding",
desc="Write some code"
desc="Write some code",
)
assert ws.step == 1
assert ws.name == "step1"
@@ -16,30 +22,59 @@ def test_work_step():
assert ws.desc == "Write some code"
assert ws.status == "waiting"
def test_pretor_workflow_validation_success():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
wf = PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
wf = PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
assert wf.title == "wf1"
def test_pretor_workflow_validation_error_step_discontinuous():
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1")
ws2 = WorkStep(step=3, name="s3", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="工作链步数不连续"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_pretor_workflow_validation_error_jump_out_of_bounds():
lg = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
ws1 = WorkStep(
step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg
)
ws2 = WorkStep(step=2, name="s2", node="supervisory_node", action="a2", desc="d2")
with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
PretorWorkflow(title="wf1", work_link=[ws1, ws2], trace_id="t", event_info={"platform":"a", "user_name":"b"})
PretorWorkflow(
title="wf1",
work_link=[ws1, ws2],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_pretor_workflow_validation_error_jump_format_error():
lg = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
ws1 = WorkStep(step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg)
ws1 = WorkStep(
step=1, name="s1", node="control_node", action="a1", desc="d1", logic_gate=lg
)
with pytest.raises(ValueError, match="LogicGate 格式错误"):
PretorWorkflow(title="wf1", work_link=[ws1], trace_id="t", event_info={"platform":"a", "user_name":"b"})
PretorWorkflow(
title="wf1",
work_link=[ws1],
trace_id="t",
event_info={"platform": "a", "user_name": "b"},
)
def test_workflow_status():
status = WorkflowStatus()
@@ -9,14 +9,16 @@ real_import = builtins.__import__
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
if name == 'ray':
if name == "ray":
mock_ray = MagicMock()
def mock_remote(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
def decorator(cls):
return cls
return decorator
mock_ray.remote = mock_remote
@@ -26,16 +28,19 @@ def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
builtins.__import__ = mock_import
for mod in list(sys.modules.keys()):
if 'pretor.core.workflow.workflow_runner' in mod or 'ray' in mod:
if "pretor.core.workflow_running_engine.workflow_runner" in mod or "ray" in mod:
del sys.modules[mod]
from pretor.core.workflow.workflow_runner import WorkflowEngine, WorkflowRunningEngine
from pretor.core.workflow_running_engine.workflow_runner import ( # noqa: E402
WorkflowEngine,
WorkflowRunningEngine,
)
builtins.__import__ = real_import
@pytest.fixture
def mock_ray():
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray:
with patch("pretor.core.workflow_running_engine.workflow_runner.ray") as mock_ray:
mock_ray.get = lambda x: x
yield mock_ray
@@ -91,7 +96,9 @@ async def test_workflow_engine_run():
engine = WorkflowEngine(mock_wf, mock_conscious, mock_control, mock_supervisor)
with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch:
with patch(
"pretor.core.workflow_running_engine.workflow_runner.ray"
) as mock_ray_patch:
mock_gsm = MagicMock()
mock_ray_patch.get_actor.return_value = mock_gsm
await engine.run()
@@ -141,22 +148,36 @@ async def test_workflow_running_engine_runner():
user_id="test_user",
user_name="test_user",
message="test_message",
context={"workflow_template": "test_template"}
context={},
)
await engine.workflow_queue.put(mock_event)
# Mock the global_state_machine get_skill_list.remote method properly
mock_gsm = MagicMock()
mock_gsm.list_individuals.remote = AsyncMock(return_value={"test_skill": {"agent_type": "skill_individual", "agent_name": "TestSkill", "description": "desc"}})
mock_gsm.list_individuals.remote = AsyncMock(
return_value={
"test_skill": {
"agent_type": "skill_individual",
"agent_name": "TestSkill",
"description": "desc",
}
}
)
engine.global_state_machine = mock_gsm
with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls, patch("builtins.open", new_callable=MagicMock) as mock_open, \
patch("pretor.core.workflow.workflow_runner.ray_actor_hook") as mock_hook:
with (
patch(
"pretor.core.workflow_running_engine.workflow_runner.WorkflowEngine"
) as mock_wf_engine_cls,
patch("builtins.open", new_callable=MagicMock) as mock_open,
patch(
"pretor.core.workflow_running_engine.workflow_runner.ray_actor_hook"
) as mock_hook,
):
# Instead of patching hook, we inject it directly
# engine.global_state_machine = AsyncMock()
mock_open.return_value.__enter__.return_value.read.return_value = '{}'
mock_open.return_value.__enter__.return_value.read.return_value = "{}"
mock_gwm = MagicMock()
mock_gwm.update_workflow.remote = AsyncMock()
@@ -170,4 +191,7 @@ async def test_workflow_running_engine_runner():
await asyncio.sleep(0.05)
task.cancel()
mock_wf_engine_cls.assert_called_with(mock_wf, mock_consciousness, "control", "supervisor")
mock_wf_engine_cls.assert_called_with(
mock_wf, mock_consciousness, "control", "supervisor"
)
# noqa: E402
+7 -4
View File
@@ -28,9 +28,9 @@ sys.modules["passlib"] = MagicMock()
sys.modules["passlib.context"] = MagicMock()
sys.modules["pretor.core.database.table.user"] = MagicMock()
import pytest
import jwt
from pretor.utils.access import Accessor
import pytest # noqa: E402
import jwt # noqa: E402
from pretor.utils.access import Accessor # noqa: E402
def test_decode_token_success():
@@ -55,6 +55,7 @@ def test_decode_token_expired():
token = "expired.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
@@ -69,6 +70,7 @@ def test_decode_token_invalid():
token = "invalid.token.here"
from fastapi import HTTPException
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
with patch("pretor.utils.access.HTTPException", HTTPException):
with pytest.raises(HTTPException) as excinfo:
@@ -93,4 +95,5 @@ def test_decode_token_validation_error():
Accessor._decode_token(token)
assert excinfo.value.status_code == 401
assert excinfo.value.detail == "无效的认证凭证"
assert excinfo.value.detail == "无效的认证凭证"
# noqa: E402