wip: 增加测试
This commit is contained in:
parent
a04fc08735
commit
2432bc9e3b
4
main.py
4
main.py
|
|
@ -40,7 +40,9 @@ async def start_system():
|
||||||
pretor_gateway = PretorGateway.remote(
|
pretor_gateway = PretorGateway.remote(
|
||||||
postgres_database=postgres_database,
|
postgres_database=postgres_database,
|
||||||
global_state_machine=global_state_machine,
|
global_state_machine=global_state_machine,
|
||||||
supervisory_node=supervisory_node
|
supervisory_node=supervisory_node,
|
||||||
|
consciousness_node = consciousness_node,
|
||||||
|
control_node = control_node
|
||||||
)
|
)
|
||||||
|
|
||||||
# 挂起在网关服务上,暴露 8000 端口
|
# 挂起在网关服务上,暴露 8000 端口
|
||||||
|
|
|
||||||
|
|
@ -13,11 +13,40 @@
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import pathlib
|
import pathlib
|
||||||
|
import importlib
|
||||||
|
import inspect
|
||||||
|
from collections import defaultdict
|
||||||
from pretor.tool_plugin.base_tool import BaseToolData
|
from pretor.tool_plugin.base_tool import BaseToolData
|
||||||
from typing import Dict, Type
|
from typing import Dict, Type
|
||||||
|
|
||||||
|
|
||||||
class GlobalToolManager:
|
class GlobalToolManager:
|
||||||
tool_mapper = Dict[str, Type[BaseToolData]]
|
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||||
|
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
pass
|
self.tool_mapper = defaultdict(dict)
|
||||||
|
|
||||||
|
tool_plugin_dir = pathlib.Path(__file__).parent.parent.parent / "tool_plugin"
|
||||||
|
if not tool_plugin_dir.exists() or not tool_plugin_dir.is_dir():
|
||||||
|
return
|
||||||
|
|
||||||
|
for item in tool_plugin_dir.iterdir():
|
||||||
|
if item.is_dir() and not item.name.startswith("__"):
|
||||||
|
plugin_name = item.name
|
||||||
|
module_name = f"pretor.tool_plugin.{plugin_name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||||
|
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
|
||||||
|
# It's a valid tool class
|
||||||
|
action_scopes = obj.model_fields.get("action_scope").default
|
||||||
|
|
||||||
|
if not action_scopes:
|
||||||
|
self.tool_mapper["default"][plugin_name] = obj
|
||||||
|
else:
|
||||||
|
for scope in action_scopes:
|
||||||
|
self.tool_mapper[scope][plugin_name] = obj
|
||||||
|
except Exception as e:
|
||||||
|
import logging
|
||||||
|
logging.getLogger(__name__).warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||||
|
|
@ -1,14 +0,0 @@
|
||||||
# Copyright 2026 zhaoxi826
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
|
|
@ -1,25 +0,0 @@
|
||||||
|
|
||||||
|
|
||||||
# Copyright 2026 zhaoxi826
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
|
|
||||||
class ToolManager:
|
|
||||||
def __init__(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def _load_tool_registry(self):
|
|
||||||
pass
|
|
||||||
|
|
||||||
def run_tool(self, tool_name, tool_desc):
|
|
||||||
pass
|
|
||||||
|
|
@ -67,7 +67,7 @@ class PretorWorkflow(BaseModel):
|
||||||
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
|
output: Dict[str, Any] = Field(default_factory=dict, description="工作流最终产出结果")
|
||||||
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
|
status: WorkflowStatus = Field(default_factory=WorkflowStatus, description="运行时状态对象")
|
||||||
event_info: EventInfo | None = Field(default_factory=None)
|
event_info: EventInfo | None = Field(default_factory=None)
|
||||||
context_memory: Dict[str, Any] = Field(default=Dict())
|
context_memory: Dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
@model_validator(mode='after')
|
@model_validator(mode='after')
|
||||||
def validate_workflow_integrity(self) -> 'PretorWorkflow':
|
def validate_workflow_integrity(self) -> 'PretorWorkflow':
|
||||||
|
|
|
||||||
|
|
@ -15,12 +15,13 @@ from pydantic import Field
|
||||||
|
|
||||||
from pretor.tool_plugin.base_tool import BaseToolData
|
from pretor.tool_plugin.base_tool import BaseToolData
|
||||||
from pretor.utils.ray_hook import ray_actor_hook
|
from pretor.utils.ray_hook import ray_actor_hook
|
||||||
|
from typing import List, Literal, Dict
|
||||||
|
|
||||||
class ApprovalToolData(BaseToolData):
|
class ApprovalToolData(BaseToolData):
|
||||||
is_system = True
|
is_system: bool = True
|
||||||
action_scope = ["control_node", "consciousness_node",]
|
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = [
|
||||||
config_args = {}
|
"control_node", "consciousness_node"]
|
||||||
|
config_args: Dict[str, str] = {}
|
||||||
|
|
||||||
|
|
||||||
async def approval(message: str, event_id: str) -> str:
|
async def approval(message: str, event_id: str) -> str:
|
||||||
|
|
|
||||||
|
|
@ -14,8 +14,10 @@
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from typing import List, Literal, Dict
|
from typing import List, Literal, Dict
|
||||||
|
from pydantic import ConfigDict
|
||||||
|
|
||||||
class BaseToolData(BaseModel):
|
class BaseToolData(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="allow")
|
||||||
is_system: bool
|
is_system: bool
|
||||||
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = []
|
action_scope: List[Literal["control_node", "consciousness_node", "supervisory_node", "growth_node", "", ""]] = []
|
||||||
config_args: Dict[str, str] = {}
|
config_args: Dict[str, str] = {}
|
||||||
|
|
@ -0,0 +1,54 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
from pretor.adapter.model_adapter.agent_factory import AgentFactory
|
||||||
|
from pretor.utils.error import ModelNotExistError
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_agent_success_real():
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.provider_type = "openai"
|
||||||
|
mock_provider.provider_models = ["gpt-4"]
|
||||||
|
mock_provider.api_key = "key"
|
||||||
|
mock_provider.url = "url"
|
||||||
|
|
||||||
|
with patch("pretor.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
|
||||||
|
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIChatModel") as mock_model_cls:
|
||||||
|
with patch("pretor.adapter.model_adapter.agent_factory.OpenAIProvider") as mock_provider_cls:
|
||||||
|
factory = AgentFactory()
|
||||||
|
agent = factory.create_agent(
|
||||||
|
provider=mock_provider,
|
||||||
|
model_id="gpt-4",
|
||||||
|
output_type=str,
|
||||||
|
system_prompt="You are an AI",
|
||||||
|
deps_type=dict,
|
||||||
|
agent_name="myagent"
|
||||||
|
)
|
||||||
|
mock_provider_cls.assert_called_once_with(api_key="key", url="url")
|
||||||
|
mock_model_cls.assert_called_once_with("gpt-4", mock_provider_cls.return_value)
|
||||||
|
mock_agent_cls.assert_called_once_with(
|
||||||
|
model=mock_model_cls.return_value,
|
||||||
|
name="myagent",
|
||||||
|
system_prompt="You are an AI",
|
||||||
|
output_type=str,
|
||||||
|
deps_type=dict
|
||||||
|
)
|
||||||
|
assert agent == mock_agent_cls.return_value
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_agent_model_not_exist():
|
||||||
|
factory = AgentFactory()
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.provider_models = ["gpt-3"]
|
||||||
|
|
||||||
|
with pytest.raises(ModelNotExistError):
|
||||||
|
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_agent_invalid_provider_type():
|
||||||
|
factory = AgentFactory()
|
||||||
|
mock_provider = MagicMock()
|
||||||
|
mock_provider.provider_type = "unknown"
|
||||||
|
mock_provider.provider_models = ["gpt-4"]
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="不支持的协议类型: unknown"):
|
||||||
|
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
|
||||||
|
|
@ -0,0 +1,74 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch
|
||||||
|
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||||
|
from pydantic import ValidationError
|
||||||
|
from pretor.utils.error import UserNotExistError
|
||||||
|
from pretor.core.database.database_exception import database_exception
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def success_func():
|
||||||
|
return "success"
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def validation_error_func():
|
||||||
|
raise ValidationError.from_exception_data(title="Mock", line_errors=[])
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def integrity_error_func():
|
||||||
|
raise IntegrityError("mock_statement", "mock_params", "mock_orig")
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def operational_error_func():
|
||||||
|
raise OperationalError("mock_statement", "mock_params", "mock_orig")
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def user_not_exist_error_func():
|
||||||
|
raise UserNotExistError("mock user")
|
||||||
|
|
||||||
|
@database_exception
|
||||||
|
async def exception_func():
|
||||||
|
raise Exception("mock generic exception")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_success_func():
|
||||||
|
assert await success_func() == "success"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("pretor.core.database.database_exception.logger")
|
||||||
|
async def test_validation_error(mock_logger):
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
await validation_error_func()
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
assert "对象校验失败" in mock_logger.error.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("pretor.core.database.database_exception.logger")
|
||||||
|
async def test_integrity_error(mock_logger):
|
||||||
|
with pytest.raises(IntegrityError):
|
||||||
|
await integrity_error_func()
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
assert "数据库完整性错误" in mock_logger.error.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("pretor.core.database.database_exception.logger")
|
||||||
|
async def test_operational_error(mock_logger):
|
||||||
|
with pytest.raises(OperationalError):
|
||||||
|
await operational_error_func()
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
assert "数据库连接异常" in mock_logger.error.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("pretor.core.database.database_exception.logger")
|
||||||
|
async def test_user_not_exist_error(mock_logger):
|
||||||
|
result = await user_not_exist_error_func()
|
||||||
|
assert result is None
|
||||||
|
mock_logger.error.assert_called_once()
|
||||||
|
assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@patch("pretor.core.database.database_exception.logger")
|
||||||
|
async def test_generic_exception(mock_logger):
|
||||||
|
with pytest.raises(Exception, match="mock generic exception"):
|
||||||
|
await exception_func()
|
||||||
|
mock_logger.exception.assert_called_once()
|
||||||
|
assert "未预期的数据库错误" in mock_logger.exception.call_args[0][0]
|
||||||
|
|
@ -0,0 +1,145 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
import json
|
||||||
|
import sys
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
|
||||||
|
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
if name == 'sqlmodel':
|
||||||
|
mock_sqlmodel = MagicMock()
|
||||||
|
|
||||||
|
class DummySQLModel:
|
||||||
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
pass
|
||||||
|
|
||||||
|
mock_sqlmodel.SQLModel = DummySQLModel
|
||||||
|
mock_sqlmodel.Field = MagicMock(return_value=None)
|
||||||
|
mock_sqlmodel.select = MagicMock()
|
||||||
|
return mock_sqlmodel
|
||||||
|
return real_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
|
||||||
|
builtins.__import__ = mock_import
|
||||||
|
for mod in list(sys.modules.keys()):
|
||||||
|
if 'pretor.core.database.module.memory' in mod or 'sqlmodel' in mod:
|
||||||
|
del sys.modules[mod]
|
||||||
|
from pretor.core.database.module.memory import MemoryRAG
|
||||||
|
|
||||||
|
builtins.__import__ = real_import
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_dependencies():
|
||||||
|
with patch("pretor.core.database.module.memory.WorkflowRecord") as mock_workflow_record:
|
||||||
|
with patch("pretor.core.database.module.memory.MemoryRecord") as mock_memory_record:
|
||||||
|
with patch("pretor.core.database.module.memory.select") as mock_select:
|
||||||
|
yield mock_workflow_record, mock_memory_record, mock_select
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_maker():
|
||||||
|
maker = MagicMock()
|
||||||
|
session = AsyncMock()
|
||||||
|
session.add = MagicMock()
|
||||||
|
maker.return_value.__aenter__.return_value = session
|
||||||
|
maker.__aenter__.return_value = session
|
||||||
|
maker.__aexit__ = AsyncMock()
|
||||||
|
return maker, session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_save_workflow(mock_session_maker, mock_dependencies):
|
||||||
|
mock_workflow_record, _, _ = mock_dependencies
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
rag = MemoryRAG(maker)
|
||||||
|
|
||||||
|
mock_record = MagicMock()
|
||||||
|
mock_workflow_record.return_value = mock_record
|
||||||
|
|
||||||
|
workflow_data = {"key": "value"}
|
||||||
|
record = await rag.save_workflow("wf_123", workflow_data)
|
||||||
|
|
||||||
|
mock_workflow_record.assert_called_once_with(
|
||||||
|
workflow_id="wf_123",
|
||||||
|
workflow_data_json=json.dumps(workflow_data)
|
||||||
|
)
|
||||||
|
session.add.assert_called_once_with(mock_record)
|
||||||
|
session.commit.assert_called_once()
|
||||||
|
session.refresh.assert_called_once_with(mock_record)
|
||||||
|
assert record == mock_record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_workflow_success(mock_session_maker, mock_dependencies):
|
||||||
|
_, _, mock_select = mock_dependencies
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
rag = MemoryRAG(maker)
|
||||||
|
|
||||||
|
mock_statement = MagicMock()
|
||||||
|
mock_select.return_value.where.return_value = mock_statement
|
||||||
|
|
||||||
|
mock_record = MagicMock()
|
||||||
|
mock_record.workflow_data_json = '{"key": "value"}'
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = mock_record
|
||||||
|
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
data = await rag.get_workflow("wf_123")
|
||||||
|
|
||||||
|
session.execute.assert_called_once_with(mock_statement)
|
||||||
|
assert data == {"key": "value"}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_workflow_not_found(mock_session_maker, mock_dependencies):
|
||||||
|
_, _, mock_select = mock_dependencies
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
rag = MemoryRAG(maker)
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = None
|
||||||
|
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
data = await rag.get_workflow("wf_123")
|
||||||
|
assert data is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_memory(mock_session_maker, mock_dependencies):
|
||||||
|
_, mock_memory_record, _ = mock_dependencies
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
rag = MemoryRAG(maker)
|
||||||
|
|
||||||
|
mock_record = MagicMock()
|
||||||
|
mock_memory_record.return_value = mock_record
|
||||||
|
|
||||||
|
record = await rag.add_memory("text", [0.1, 0.2])
|
||||||
|
|
||||||
|
mock_memory_record.assert_called_once_with(memory_text="text", embedding=[0.1, 0.2])
|
||||||
|
session.add.assert_called_once_with(mock_record)
|
||||||
|
session.commit.assert_called_once()
|
||||||
|
session.refresh.assert_called_once_with(mock_record)
|
||||||
|
assert record == mock_record
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_retrieve_memory(mock_session_maker, mock_dependencies):
|
||||||
|
_, _, mock_select = mock_dependencies
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
rag = MemoryRAG(maker)
|
||||||
|
|
||||||
|
mock_statement = MagicMock()
|
||||||
|
mock_select.return_value.limit.return_value = mock_statement
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.all.return_value = ["res1", "res2"]
|
||||||
|
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
results = await rag.retrieve_memory([0.1, 0.2], 5)
|
||||||
|
|
||||||
|
session.execute.assert_called_once_with(mock_statement)
|
||||||
|
assert results == ["res1", "res2"]
|
||||||
|
|
@ -0,0 +1,173 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def mock_dependencies():
|
||||||
|
with patch("pretor.core.database.module.user.User") as mock_user_cls:
|
||||||
|
mock_user_cls.user_name = MagicMock()
|
||||||
|
with patch("pretor.core.database.module.user.select") as mock_select:
|
||||||
|
yield mock_user_cls, mock_select
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def mock_session_maker():
|
||||||
|
maker = MagicMock()
|
||||||
|
session = AsyncMock()
|
||||||
|
session.add = MagicMock()
|
||||||
|
session.delete = MagicMock()
|
||||||
|
maker.return_value.__aenter__.return_value = session
|
||||||
|
maker.__aenter__.return_value = session
|
||||||
|
maker.__aexit__ = AsyncMock()
|
||||||
|
return maker, session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_user(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, _ = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.user_name = "testuser"
|
||||||
|
mock_user.hashed_password = "hashedpw"
|
||||||
|
mock_user_cls.return_value = mock_user
|
||||||
|
|
||||||
|
user = await db.add_user("testuser", "hashedpw")
|
||||||
|
|
||||||
|
assert user.user_name == "testuser"
|
||||||
|
assert user.hashed_password == "hashedpw"
|
||||||
|
session.add.assert_called_once_with(mock_user)
|
||||||
|
session.commit.assert_called_once()
|
||||||
|
session.refresh.assert_called_once_with(mock_user)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_success(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_statement = MagicMock()
|
||||||
|
mock_select.return_value.where.return_value = mock_statement
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.hashed_password = "old_password"
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
user = await db.change_password("testuser", "old_password", "new_password")
|
||||||
|
|
||||||
|
session.exec.assert_called_once_with(mock_statement)
|
||||||
|
assert user.hashed_password == "new_password"
|
||||||
|
session.add.assert_called_once_with(mock_user)
|
||||||
|
session.commit.assert_called_once()
|
||||||
|
session.refresh.assert_called_once_with(mock_user)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = None
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
result = await db.change_password("testuser", "old_password", "new_password")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_change_password_wrong_password(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_user.hashed_password = "actual_password"
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
from pretor.utils.error import UserPasswordError
|
||||||
|
with pytest.raises(UserPasswordError):
|
||||||
|
await db.change_password("testuser", "old_password", "new_password")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_user_success(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_statement = MagicMock()
|
||||||
|
mock_select.return_value.where.return_value = mock_statement
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
await db.delete_user("testuser")
|
||||||
|
session.exec.assert_called_once_with(mock_statement)
|
||||||
|
session.delete.assert_called_once_with(mock_user)
|
||||||
|
session.commit.assert_called_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = None
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
result = await db.delete_user("testuser")
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_user_success(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_statement = MagicMock()
|
||||||
|
mock_select.return_value.where.return_value = mock_statement
|
||||||
|
|
||||||
|
mock_user = MagicMock()
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
user = await db.login_user("testuser")
|
||||||
|
session.exec.assert_called_once_with(mock_statement)
|
||||||
|
assert user == mock_user
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_login_user_not_exist(mock_session_maker, mock_dependencies):
|
||||||
|
mock_user_cls, mock_select = mock_dependencies
|
||||||
|
from pretor.core.database.module.user import AuthDatabase
|
||||||
|
maker, session = mock_session_maker
|
||||||
|
db = AuthDatabase(maker)
|
||||||
|
|
||||||
|
mock_exec_result = MagicMock()
|
||||||
|
mock_exec_result.scalar_one_or_none.return_value = None
|
||||||
|
session.exec = AsyncMock(return_value=mock_exec_result)
|
||||||
|
|
||||||
|
result = await db.login_user("testuser")
|
||||||
|
assert result is None
|
||||||
|
|
@ -0,0 +1,71 @@
|
||||||
|
import pytest
|
||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
import sys
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
|
||||||
|
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
if name == 'ray':
|
||||||
|
mock_ray = MagicMock()
|
||||||
|
|
||||||
|
def mock_remote(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
mock_ray.remote = mock_remote
|
||||||
|
return mock_ray
|
||||||
|
return real_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
|
||||||
|
builtins.__import__ = mock_import
|
||||||
|
for mod in list(sys.modules.keys()):
|
||||||
|
if 'pretor.core.database.postgres' in mod or 'ray' in mod:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
from pretor.core.database.postgres import PostgresDatabase
|
||||||
|
|
||||||
|
builtins.__import__ = real_import
|
||||||
|
|
||||||
|
|
||||||
|
@patch("pretor.core.database.postgres.create_async_engine")
|
||||||
|
@patch("pretor.core.database.postgres.sessionmaker")
|
||||||
|
@patch("pretor.core.database.postgres.AuthDatabase")
|
||||||
|
@patch("pretor.core.database.postgres.ProviderDatabase")
|
||||||
|
@patch("pretor.core.database.postgres.os.environ.get")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_postgres_database(mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine):
|
||||||
|
def env_side_effect(key):
|
||||||
|
return {
|
||||||
|
"POSTGRES_USER": "testuser",
|
||||||
|
"POSTGRES_PASSWORD": "testpassword",
|
||||||
|
"POSTGRES_HOST": "localhost",
|
||||||
|
"POSTGRES_PORT": "5432",
|
||||||
|
"POSTGRES_DB": "testdb"
|
||||||
|
}.get(key)
|
||||||
|
|
||||||
|
mock_env_get.side_effect = env_side_effect
|
||||||
|
|
||||||
|
mock_engine = MagicMock()
|
||||||
|
mock_conn = MagicMock()
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
mock_conn.run_sync = AsyncMock()
|
||||||
|
|
||||||
|
mock_begin_ctx = MagicMock()
|
||||||
|
mock_begin_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||||
|
mock_begin_ctx.__aexit__ = AsyncMock()
|
||||||
|
mock_engine.begin.return_value = mock_begin_ctx
|
||||||
|
mock_create_engine.return_value = mock_engine
|
||||||
|
|
||||||
|
db = PostgresDatabase()
|
||||||
|
|
||||||
|
mock_create_engine.assert_called_once_with(
|
||||||
|
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb",
|
||||||
|
echo=True
|
||||||
|
)
|
||||||
|
mock_auth_db.assert_called_once()
|
||||||
|
mock_provider_db.assert_called_once()
|
||||||
|
|
||||||
|
with patch("pretor.core.database.postgres.SQLModel.metadata.create_all") as mock_create_all:
|
||||||
|
await db.init_db()
|
||||||
|
mock_conn.run_sync.assert_called_once_with(mock_create_all)
|
||||||
|
|
@ -0,0 +1,15 @@
|
||||||
|
import pytest
|
||||||
|
from pretor.core.database.table.provider import Provider
|
||||||
|
|
||||||
|
def test_provider_table():
|
||||||
|
# Provide required fields
|
||||||
|
provider = Provider(
|
||||||
|
provider_title="title",
|
||||||
|
provider_url="url",
|
||||||
|
provider_apikey="key",
|
||||||
|
provider_models=["model_1"],
|
||||||
|
provider_type="type",
|
||||||
|
provider_owner=1
|
||||||
|
)
|
||||||
|
assert Provider.__tablename__ == 'provider'
|
||||||
|
assert provider.provider_title == "title"
|
||||||
|
|
@ -0,0 +1,7 @@
|
||||||
|
import pytest
|
||||||
|
from pretor.core.database.table.user import User
|
||||||
|
|
||||||
|
def test_user_table():
|
||||||
|
user = User(user_name="name", hashed_password="pw")
|
||||||
|
assert User.__tablename__ == 'user'
|
||||||
|
assert user.user_name == "name"
|
||||||
|
|
@ -0,0 +1,158 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
from unittest.mock import MagicMock, AsyncMock, patch
|
||||||
|
import sys
|
||||||
|
|
||||||
|
import builtins
|
||||||
|
|
||||||
|
real_import = builtins.__import__
|
||||||
|
|
||||||
|
|
||||||
|
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||||
|
if name == 'ray':
|
||||||
|
mock_ray = MagicMock()
|
||||||
|
|
||||||
|
def mock_remote(cls):
|
||||||
|
return cls
|
||||||
|
|
||||||
|
mock_ray.remote = mock_remote
|
||||||
|
return mock_ray
|
||||||
|
return real_import(name, globals, locals, fromlist, level)
|
||||||
|
|
||||||
|
|
||||||
|
builtins.__import__ = mock_import
|
||||||
|
|
||||||
|
for mod in list(sys.modules.keys()):
|
||||||
|
if 'pretor.core.global_state_machine.global_state_machine' in mod or 'ray' in mod:
|
||||||
|
del sys.modules[mod]
|
||||||
|
|
||||||
|
from pretor.core.global_state_machine.global_state_machine import GlobalStateMachine
|
||||||
|
from pretor.api.platform.event import PretorEvent
|
||||||
|
from pretor.core.workflow.workflow import PretorWorkflow
|
||||||
|
|
||||||
|
builtins.__import__ = real_import
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def mock_postgres():
    """A stand-in for the postgres database actor handle."""
    return MagicMock()


@pytest.fixture
def gsm(mock_postgres):
    """A GlobalStateMachine built with the provider manager patched out."""
    with patch("pretor.core.global_state_machine.global_state_machine.ProviderManager"):
        return GlobalStateMachine(mock_postgres)


def test_add_delete_get_event(gsm):
    """Events can be registered, looked up by id, and removed again."""
    evt = MagicMock(spec=PretorEvent)
    evt.event_id = 123

    gsm.add_event(evt)

    # add_event is expected to attach both async queues to the event.
    assert getattr(evt, 'pending_queue', None) is not None
    assert getattr(evt, 'receive_queue', None) is not None

    assert gsm.get_event(123) == evt

    gsm.delete_event(123)
    assert gsm.get_event(123) is None


def test_update_attachment_and_workflow(gsm):
    """Attachment and workflow updates land on the stored event."""
    evt = MagicMock(spec=PretorEvent)
    evt.event_id = "abc"
    gsm.add_event(evt)

    gsm.update_attachment("abc", {"k": "v"})
    assert evt.attachment == {"k": "v"}

    workflow = MagicMock(spec=PretorWorkflow)
    gsm.update_workflow("abc", workflow)
    assert evt.workflow == workflow


@pytest.mark.asyncio
async def test_queues(gsm):
    """put_*/get_* helpers round-trip items through the event's queues."""
    evt = MagicMock(spec=PretorEvent)
    evt.event_id = "q_event"
    # To use await put/get, we must actually use real asyncio queues for the mock event
    evt.pending_queue = asyncio.Queue()
    evt.receive_queue = asyncio.Queue()
    gsm.event_dict["q_event"] = evt

    await gsm.put_pending("q_event", "item1")
    assert await gsm.get_pending("q_event") == "item1"

    await gsm.put_received("q_event", "item2")
    assert await gsm.get_received("q_event") == "item2"


@pytest.mark.asyncio
async def test_add_provider_success(gsm, mock_postgres):
    """A supported provider type is created, registered and persisted."""
    provider_cls = AsyncMock()
    provider = MagicMock()
    provider.provider_title = "MyProvider"
    provider.provider_url = "url"
    provider.provider_apikey = "key"
    provider.provider_models = ["model"]
    provider.provider_type = "openai"
    provider_cls.create_model.return_value = provider

    gsm.global_provider_manager.provider_mapper = {"openai": provider_cls}
    gsm.global_provider_manager.provider_register = {}

    persist = AsyncMock()
    mock_postgres.provider_database.add_provider.remote = persist

    await gsm.add_provider("openai", "title", "url", "key", 1)

    assert gsm.global_provider_manager.provider_register["title"] == provider
    persist.assert_called_once()
    # The owner id passed to add_provider is stamped onto the provider.
    assert provider.provider_owner == 1


@pytest.mark.asyncio
async def test_add_provider_unsupported(gsm):
    """An unknown provider type is rejected with a warning, not an error."""
    gsm.global_provider_manager.provider_mapper = {}
    with patch("pretor.core.global_state_machine.global_state_machine.logger") as log:
        await gsm.add_provider("magic", "title", "url", "key", 1)
        log.warning.assert_called_with("Provider type magic is not supported.")


@pytest.mark.asyncio
async def test_add_provider_request_error(gsm):
    """Network failures while creating the provider are logged, not raised."""
    from httpx import RequestError

    provider_cls = AsyncMock()
    provider_cls.create_model.side_effect = RequestError("Network Error", request=MagicMock())
    gsm.global_provider_manager.provider_mapper = {"openai": provider_cls}

    with patch("pretor.core.global_state_machine.global_state_machine.logger") as log:
        await gsm.add_provider("openai", "title", "url", "key", 1)
        log.warning.assert_called_once()
        assert "网络请求异常" in log.warning.call_args[0][0]


@pytest.mark.asyncio
async def test_add_provider_generic_error(gsm):
    """Unexpected errors while parsing the model list are logged, not raised."""
    provider_cls = AsyncMock()
    provider_cls.create_model.side_effect = ValueError("Some Error")
    gsm.global_provider_manager.provider_mapper = {"openai": provider_cls}

    with patch("pretor.core.global_state_machine.global_state_machine.logger") as log:
        await gsm.add_provider("openai", "title", "url", "key", 1)
        log.warning.assert_called_once()
        assert "解析模型列表时发生错误" in log.warning.call_args[0][0]


def test_get_provider_list_and_get_provider(gsm):
    """Registered providers are exposed via the list/lookup helpers."""
    provider = MagicMock()
    gsm.global_provider_manager.provider_register = {"p1": provider}

    assert gsm.get_provider_list() == {"p1": provider}
    assert gsm.get_provider("p1") == provider
    assert gsm.get_provider("missing") is None
|
||||||
|
|
@ -0,0 +1,26 @@
|
||||||
|
import pytest
from pretor.core.global_state_machine.model_provider.base_provider import Provider, ProviderArgs, ProviderStatus


def test_provider_status():
    """The status enum exposes the expected string values."""
    assert ProviderStatus.UP == "up"
    assert ProviderStatus.DOWN == "down"


def test_provider_args():
    """ProviderArgs keeps the fields it is constructed with."""
    args = ProviderArgs(
        provider_title="title",
        provider_url="url",
        provider_apikey="key",
        provider_owner=1,
    )
    assert args.provider_title == "title"


def test_provider_model():
    """A freshly built Provider defaults to UP status and no owner."""
    provider = Provider(
        provider_title="title",
        provider_url="url",
        provider_apikey="key",
        provider_models=["model"],
        provider_type="openai",
    )
    assert provider.provider_status == ProviderStatus.UP
    assert provider.provider_owner is None
|
||||||
|
|
@ -0,0 +1,51 @@
|
||||||
|
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.claude_provider import ClaudeProvider, ProviderArgs


@pytest.fixture
def provider_args():
    """Arguments for a throwaway Claude provider."""
    return ProviderArgs(
        provider_title="TestClaude",
        provider_url="https://api.anthropic.com",
        provider_apikey="testkey",
        provider_owner=1,
    )


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
async def test_load_models_success(mock_client, provider_args):
    """Model ids are extracted from a successful models response."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {
        "data": [{"id": "claude-3-opus-20240229"}, {"id": "claude-3-sonnet-20240229"}]
    }

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await ClaudeProvider._load_models(provider_args) == [
        "claude-3-opus-20240229",
        "claude-3-sonnet-20240229",
    ]


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient")
async def test_load_models_error(mock_client, provider_args):
    """A transport failure degrades to an empty model list."""
    client = AsyncMock()
    client.get.side_effect = Exception("network error")
    mock_client.return_value.__aenter__.return_value = client

    assert await ClaudeProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
       return_value=["claude-3"])
async def test_create_provider(mock_load, provider_args):
    """create_provider wires args and loaded models into the Provider."""
    provider = await ClaudeProvider.create_provider(provider_args)
    assert provider.provider_title == "TestClaude"
    assert provider.provider_models == ["claude-3"]
    assert provider.provider_type == "claude"
|
||||||
|
|
@ -0,0 +1,69 @@
|
||||||
|
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.gemini_provider import GeminiProvider, ProviderArgs


@pytest.fixture
def provider_args():
    """Arguments for a throwaway Gemini provider."""
    return ProviderArgs(
        provider_title="TestGemini",
        provider_url="https://generativelanguage.googleapis.com",
        provider_apikey="testkey",
        provider_owner=1,
    )


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
async def test_load_models_success(mock_client, provider_args):
    """Only models supporting generateContent are kept, sorted by name."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {
        "models": [
            {"name": "models/gemini-1.5-pro", "supportedGenerationMethods": ["generateContent"]},
            {"name": "models/gemini-1.5-flash", "supportedGenerationMethods": ["generateContent"]},
            {"name": "models/other", "supportedGenerationMethods": []}
        ]
    }

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await GeminiProvider._load_models(provider_args) == ["gemini-1.5-flash", "gemini-1.5-pro"]


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
async def test_load_models_status_error(mock_client, provider_args):
    """A non-200 status degrades to an empty model list."""
    response = MagicMock()
    response.status_code = 401

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await GeminiProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.httpx.AsyncClient")
async def test_load_models_error(mock_client, provider_args):
    """A transport failure degrades to an empty model list."""
    client = AsyncMock()
    client.get.side_effect = Exception("network error")
    mock_client.return_value.__aenter__.return_value = client

    assert await GeminiProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.gemini_provider.GeminiProvider._load_models",
       return_value=["gemini-1"])
async def test_create_provider(mock_load, provider_args):
    """create_provider wires args and loaded models into the Provider."""
    provider = await GeminiProvider.create_provider(provider_args)
    assert provider.provider_title == "TestGemini"
    assert provider.provider_models == ["gemini-1"]
    assert provider.provider_type == "gemini"
|
||||||
|
|
@ -0,0 +1,110 @@
|
||||||
|
import pytest
from unittest.mock import patch, MagicMock, AsyncMock
from pretor.core.global_state_machine.model_provider.openai_provider import OpenAIProvider, ProviderArgs


@pytest.fixture
def provider_args():
    """Arguments whose base URL already ends in /v1."""
    return ProviderArgs(
        provider_title="TestOpenAI",
        provider_url="https://api.openai.com/v1",
        provider_apikey="testkey",
        provider_owner=1,
    )


@pytest.fixture
def provider_args_no_v1():
    """Arguments whose base URL is missing the /v1 suffix."""
    return ProviderArgs(
        provider_title="TestOpenAI",
        provider_url="https://api.openai.com",
        provider_apikey="testkey",
        provider_owner=1,
    )


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_success(mock_client, provider_args):
    """Model ids are extracted and sorted from a successful /models response."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {
        "data": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]
    }

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await OpenAIProvider._load_models(provider_args) == ["gpt-3.5-turbo", "gpt-4"]
    client.get.assert_called_once_with(
        "https://api.openai.com/v1/models",
        headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
    )


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_no_v1(mock_client, provider_args_no_v1):
    """A base URL without /v1 still hits the canonical /v1/models endpoint."""
    response = MagicMock()
    response.status_code = 200
    response.json.return_value = {"data": []}

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await OpenAIProvider._load_models(provider_args_no_v1) == []
    client.get.assert_called_once_with(
        "https://api.openai.com/v1/models",
        headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"}
    )


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_status_error(mock_client, provider_args):
    """A non-200 status degrades to an empty model list."""
    response = MagicMock()
    response.status_code = 401

    client = AsyncMock()
    client.get.return_value = response
    mock_client.return_value.__aenter__.return_value = client

    assert await OpenAIProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_request_error(mock_client, provider_args):
    """An httpx transport error degrades to an empty model list."""
    import httpx
    client = AsyncMock()
    client.get.side_effect = httpx.RequestError("network error", request=MagicMock())
    mock_client.return_value.__aenter__.return_value = client

    assert await OpenAIProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient")
async def test_load_models_generic_error(mock_client, provider_args):
    """Any other exception also degrades to an empty model list."""
    client = AsyncMock()
    client.get.side_effect = Exception("generic error")
    mock_client.return_value.__aenter__.return_value = client

    assert await OpenAIProvider._load_models(provider_args) == []


@pytest.mark.asyncio
@patch("pretor.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
       return_value=["gpt-4"])
async def test_create_provider(mock_load, provider_args):
    """create_provider wires args and loaded models into the Provider."""
    provider = await OpenAIProvider.create_provider(provider_args)
    assert provider.provider_title == "TestOpenAI"
    assert provider.provider_models == ["gpt-4"]
    assert provider.provider_type == "openai"
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
import pytest
from unittest.mock import MagicMock
from pretor.core.global_state_machine.provider_manager import ProviderManager


def test_provider_manager_init():
    """Initialisation registers the built-in provider types and loads persisted providers."""
    mock_postgres = MagicMock()

    provider_a = MagicMock()
    provider_a.title = "title1"

    provider_b = MagicMock()
    provider_b.title = "title2"

    # NOTE(review): _load_provider_register iterates the return value of
    # `postgres.provider_database.get_provider.remote()` directly. A real Ray
    # `.remote()` call returns an ObjectRef, which would need `ray.get()`
    # before iteration — the mock sidesteps that by returning the list itself.
    mock_postgres.provider_database.get_provider.remote.return_value = [provider_a, provider_b]

    manager = ProviderManager(mock_postgres)

    for kind in ("openai", "gemini", "claude"):
        assert kind in manager.provider_mapper

    assert manager.provider_register["title1"] == provider_a
    assert manager.provider_register["title2"] == provider_b
    mock_postgres.provider_database.get_provider.remote.assert_called_once()
|
||||||
|
|
@ -0,0 +1,5 @@
|
||||||
|
from pretor.core.global_state_machine.tool_manager import GlobalToolManager


def test_global_tool_manager_init():
    """The tool manager can be constructed even when no plugins are installed."""
    tool_manager = GlobalToolManager()
    assert isinstance(tool_manager, GlobalToolManager)
|
||||||
|
|
@ -0,0 +1,117 @@
|
||||||
|
import pytest
from unittest.mock import MagicMock, AsyncMock, patch
import asyncio

import sys
import builtins

# Keep a handle on the genuine importer so it can be restored once the
# workflow_runner module has been imported with ray stubbed out.
real_import = builtins.__import__


def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
    """Import hook that substitutes a MagicMock for the `ray` package.

    `ray.remote` is replaced with an identity decorator so the classes under
    test stay plain Python classes instead of becoming Ray actors.
    """
    if name == 'ray':
        mock_ray = MagicMock()

        def mock_remote(cls):
            # Identity decorator: leave the decorated class untouched.
            return cls

        mock_ray.remote = mock_remote
        return mock_ray
    return real_import(name, globals, locals, fromlist, level)


# Install the stubbing importer, force a re-import of the module under test
# (and of ray), then restore the real importer.  The try/finally guarantees
# restoration even when the import itself fails, so a broken import cannot
# leak the patched __import__ into every later test module.
builtins.__import__ = mock_import
try:
    for mod in list(sys.modules.keys()):
        if 'pretor.core.workflow.workflow_runner' in mod or 'ray' in mod:
            del sys.modules[mod]
    from pretor.core.workflow.workflow_runner import WorkflowEngine, WorkflowRunningEngine
finally:
    builtins.__import__ = real_import


@pytest.fixture
def mock_ray():
    """Patch the ray reference inside workflow_runner; ray.get becomes identity."""
    with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray:
        mock_ray.get = lambda x: x
        yield mock_ray


def test_workflow_engine_init():
    """The engine stores the workflow and all three node handles."""
    mock_wf = MagicMock()
    mock_wf.work_link = []
    engine = WorkflowEngine(mock_wf, "conscious", "control", "supervisor")
    assert engine.workflow == mock_wf
    assert engine.consciousness_node == "conscious"
    assert engine.control_node == "control"
    assert engine.supervisory_node == "supervisor"


@pytest.mark.asyncio
async def test_workflow_engine_run():
    """run() completes a control-node step, stores its output and persists the workflow."""
    mock_wf = MagicMock()

    step1 = MagicMock()
    step1.step = 1
    step1.status = "waiting"
    step1.node = "control_node"
    step1.inputs = []
    step1.outputs = ["res"]
    step1.logic_gate = None
    mock_wf.work_link = [step1]

    mock_wf.status.step = 1
    mock_wf.context_memory = {}

    mock_control = MagicMock()
    mock_control.process.remote.return_value = "process_result"

    engine = WorkflowEngine(mock_wf, "conscious", mock_control, "supervisor")

    with patch("pretor.core.workflow.workflow_runner.ray") as mock_ray_patch:
        mock_gsm = MagicMock()
        mock_ray_patch.get_actor.return_value = mock_gsm
        await engine.run()

    # The step ran, its output landed in context memory, and the updated
    # workflow was pushed back to the global state machine actor.
    assert step1.status == "completed"
    assert mock_wf.context_memory["res"] == "process_result"
    mock_gsm.update_workflow.remote.assert_called_with(mock_wf.trace_id, mock_wf)


def test_workflow_running_engine_init():
    """The running engine stores all three node handles."""
    engine = WorkflowRunningEngine("conscious", "control", "supervisor")
    assert engine.consciousness_node == "conscious"
    assert engine.control_node == "control"
    assert engine.supervisory_node == "supervisor"


@pytest.mark.asyncio
async def test_workflow_running_engine_submit():
    """submit_workflow enqueues the workflow for the runner loop."""
    engine = WorkflowRunningEngine("conscious", "control", "supervisor")
    engine.workflow_queue = asyncio.Queue()

    mock_wf = MagicMock()
    await engine.submit_workflow(mock_wf)

    assert await engine.workflow_queue.get() == mock_wf


@pytest.mark.asyncio
async def test_workflow_running_engine_runner():
    """The runner loop dequeues a workflow and drives it through a WorkflowEngine."""
    engine = WorkflowRunningEngine("conscious", "control", "supervisor")
    engine.workflow_queue = asyncio.Queue()

    mock_wf = MagicMock()
    await engine.workflow_queue.put(mock_wf)

    with patch("pretor.core.workflow.workflow_runner.WorkflowEngine") as mock_wf_engine_cls:
        mock_engine_instance = MagicMock()
        mock_engine_instance.run = AsyncMock()
        mock_wf_engine_cls.return_value = mock_engine_instance

        task = asyncio.create_task(engine.runner(1))
        await asyncio.sleep(0.05)  # Give runner time to process one item
        task.cancel()  # Stop the infinite loop

        mock_wf_engine_cls.assert_called_with(mock_wf, "conscious", "control", "supervisor")
        mock_engine_instance.run.assert_called_once()
|
||||||
|
|
@ -0,0 +1,34 @@
|
||||||
|
import json
import pytest
from unittest.mock import patch, MagicMock
from pretor.core.workflow.workflow_template_generator.workflow_template_generator import WorkflowTemplateGenerator


@patch("pretor.core.workflow.workflow_template_generator.workflow_template_generator.Path")
def test_generate_workflow_template(mock_path):
    """A template is serialised to JSON and written into a freshly created directory."""
    template_dir = MagicMock()
    template_dir.exists.return_value = False
    mock_path.return_value = template_dir

    template_file = MagicMock()
    template_dir.__truediv__.return_value = template_file

    handle = MagicMock()
    template_file.open.return_value.__enter__.return_value = handle

    WorkflowTemplateGenerator().generate_workflow_template(
        name="test_wf",
        desc="test_desc",
        steps=[{"step": 1, "node": "n", "action": "a", "desc": "d", "input": [], "output": [], "logic_gate": {}}],
    )

    # The missing directory was created and the file opened for UTF-8 writing.
    template_dir.mkdir.assert_called_once_with(parents=True)
    template_file.open.assert_called_once_with("w", encoding="utf-8")

    handle.write.assert_called_once()
    payload = json.loads(handle.write.call_args[0][0])
    assert payload["name"] == "test_wf"
    assert payload["desc"] == "test_desc"
    assert payload["work_link"][0]["step"] == 1
|
||||||
|
|
@ -0,0 +1,36 @@
|
||||||
|
import pytest
from pydantic import ValidationError
from pretor.core.workflow.workflow_template_generator.workflow_template import WorkflowTemplateStep, WorkflowTemplate


def test_workflow_template_step():
    """A step keeps the fields it is constructed with."""
    step = WorkflowTemplateStep(
        step=1,
        node="node_type",
        action="act",
        desc="desc",
        input=["in1"],
        output=["out1"],
        logic_gate={"if_pass": "next"},
    )
    assert step.step == 1
    assert step.node == "node_type"


def _make_step(number, node):
    """Build a minimal step for template-level validation tests."""
    return WorkflowTemplateStep(
        step=number, node=node, action=f"a{number}", desc=f"d{number}",
        input=[], output=[], logic_gate={},
    )


def test_workflow_template_success():
    """A template with distinct step numbers validates."""
    template = WorkflowTemplate(
        name="temp", desc="desc",
        work_link=[_make_step(1, "node1"), _make_step(2, "node2")],
    )
    assert template.name == "temp"


def test_workflow_template_error_duplicate_steps():
    """Duplicate step numbers are rejected by the model validator."""
    with pytest.raises(ValidationError, match="Step numbers in work_link must be unique"):
        WorkflowTemplate(
            name="temp", desc="desc",
            work_link=[_make_step(1, "node1"), _make_step(1, "node2")],
        )
|
||||||
|
|
@ -0,0 +1,50 @@
|
||||||
|
import pytest
import json
from unittest.mock import MagicMock, patch, mock_open
from pathlib import Path
from pretor.core.workflow.workflow_template_manager import WorkflowManager


def _fake_template_file(payload):
    """Build a Path-like mock whose open() yields *payload*."""
    fake = MagicMock(spec=Path)
    fake.open = mock_open(read_data=payload)
    return fake


def test_workflow_manager_init_success():
    """Valid template files are loaded into the name→desc registry."""
    file_a = _fake_template_file(json.dumps({"name": "test1", "desc": "desc1"}))
    file_b = _fake_template_file(json.dumps({"name": "test2", "desc": "desc2"}))

    with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[file_a, file_b]):
        with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
            manager = WorkflowManager()
            assert manager.workflow_templates_registry == {"test1": "desc1", "test2": "desc2"}


def test_workflow_manager_init_json_error():
    """Malformed template files are skipped with a warning."""
    bad_file = _fake_template_file("{invalid_json}")

    with patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[bad_file]):
        with patch("pretor.core.workflow.workflow_template_manager.logger") as mock_logger:
            with patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator"):
                manager = WorkflowManager()
                assert manager.workflow_templates_registry == {}
                mock_logger.warning.assert_called_once()
                assert "不是json文件或格式错误" in mock_logger.warning.call_args[0][0]


@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
def test_generate_workflow_template_success(mock_glob, mock_generator_cls):
    """generate_workflow_template delegates to the generator with keyword args."""
    manager = WorkflowManager()
    manager.generate_workflow_template("name", "desc", ["step1"])
    mock_generator_cls.return_value.generate_workflow_template.assert_called_once_with(
        name="name", desc="desc", steps=["step1"])


@patch("pretor.core.workflow.workflow_template_manager.WorkflowTemplateGenerator")
@patch("pretor.core.workflow.workflow_template_manager.Path.glob", return_value=[])
@patch("pretor.core.workflow.workflow_template_manager.logger")
def test_generate_workflow_template_exception(mock_logger, mock_glob, mock_generator_cls):
    """Generator failures are logged, not propagated to the caller."""
    mock_generator_cls.return_value.generate_workflow_template.side_effect = Exception("error")
    manager = WorkflowManager()
    manager.generate_workflow_template("name", "desc", ["step1"])
    mock_logger.exception.assert_called_once_with("Failed to generate workflow template")
|
||||||
|
|
@ -0,0 +1,56 @@
|
||||||
|
import pytest
from pretor.core.workflow.workflow import WorkerGroup, WorkStep, PretorWorkflow, WorkflowStatus, LogicGate


def test_worker_group():
    """A worker group keeps its name and both individual maps."""
    group = WorkerGroup(name="group1", primary_individual={"coder": 1}, composite_individual={"tester": 1})
    assert group.name == "group1"
    assert group.primary_individual == {"coder": 1}
    assert group.composite_individual == {"tester": 1}


def test_work_step():
    """A new step keeps its fields and starts in the 'waiting' status."""
    step = WorkStep(
        step=1,
        node="control_node",
        action="coding",
        desc="Write some code",
    )
    assert step.step == 1
    assert step.node == "control_node"
    assert step.action == "coding"
    assert step.desc == "Write some code"
    assert step.status == "waiting"


def test_pretor_workflow_validation_success():
    """A workflow with consecutive step numbers validates."""
    first = WorkStep(step=1, node="control_node", action="a1", desc="d1")
    second = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
    group = WorkerGroup(name="g1", primary_individual={"coder": 1}, composite_individual={})
    workflow = PretorWorkflow(title="wf1", workgroup_list=[group], work_link=[first, second], trace_id="t", event_info={"platform":"a", "username":"b"})
    assert workflow.title == "wf1"


def test_pretor_workflow_validation_error_step_discontinuous():
    """Non-consecutive step numbers are rejected."""
    first = WorkStep(step=1, node="control_node", action="a1", desc="d1")
    second = WorkStep(step=3, node="supervisory_node", action="a2", desc="d2")
    group = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
    with pytest.raises(ValueError, match="工作链步数不连续"):
        PretorWorkflow(title="wf1", workgroup_list=[group], work_link=[first, second], trace_id="t", event_info={"platform":"a", "username":"b"})


def test_pretor_workflow_validation_error_jump_out_of_bounds():
    """A logic-gate jump past the last step is rejected."""
    gate = LogicGate(if_fail="jump_to_step_3", if_pass="continue")
    first = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=gate)
    second = WorkStep(step=2, node="supervisory_node", action="a2", desc="d2")
    group = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
    with pytest.raises(ValueError, match="跳转目标 Step 3 越界了"):
        PretorWorkflow(title="wf1", workgroup_list=[group], work_link=[first, second], trace_id="t", event_info={"platform":"a", "username":"b"})


def test_pretor_workflow_validation_error_jump_format_error():
    """A malformed jump directive is rejected."""
    gate = LogicGate(if_fail="jump_to_step_invalid", if_pass="continue")
    first = WorkStep(step=1, node="control_node", action="a1", desc="d1", logic_gate=gate)
    group = WorkerGroup(name="g1", primary_individual={}, composite_individual={})
    with pytest.raises(ValueError, match="LogicGate 格式错误"):
        PretorWorkflow(title="wf1", workgroup_list=[group], work_link=[first], trace_id="t", event_info={"platform":"a", "username":"b"})


def test_workflow_status():
    """A fresh status starts at step 1, waiting on the LLM, with no demand."""
    status = WorkflowStatus()
    assert status.step == 1
    assert status.status == "waiting_llm_working"
    assert status.demand is None
|
||||||
Loading…
Reference in New Issue