feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -1,61 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from kilostar.utils.error import ModelNotExistError
|
||||
|
||||
|
||||
def test_create_agent_success_real():
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_type = "openai"
|
||||
mock_provider.provider_models = ["gpt-4"]
|
||||
mock_provider.provider_apikey = "key"
|
||||
mock_provider.provider_url = "url"
|
||||
|
||||
with patch("kilostar.adapter.model_adapter.agent_factory.Agent") as mock_agent_cls:
|
||||
with patch(
|
||||
"kilostar.adapter.model_adapter.agent_factory.OpenAIChatModel"
|
||||
) as mock_model_cls:
|
||||
with patch(
|
||||
"kilostar.adapter.model_adapter.agent_factory.OpenAIProvider"
|
||||
) as mock_provider_cls:
|
||||
factory = AgentFactory()
|
||||
agent = factory.create_agent(
|
||||
provider=mock_provider,
|
||||
model_id="gpt-4",
|
||||
output_type=str,
|
||||
system_prompt="You are an AI",
|
||||
deps_type=dict,
|
||||
agent_name="myagent",
|
||||
)
|
||||
mock_provider_cls.assert_called_once_with(api_key="key", base_url="url")
|
||||
mock_model_cls.assert_called_once_with(
|
||||
"gpt-4", provider=mock_provider_cls.return_value
|
||||
)
|
||||
mock_agent_cls.assert_called_once_with(
|
||||
model=mock_model_cls.return_value,
|
||||
name="myagent",
|
||||
system_prompt="You are an AI",
|
||||
output_type=str,
|
||||
deps_type=dict,
|
||||
tools=None,
|
||||
)
|
||||
assert agent == mock_agent_cls.return_value
|
||||
|
||||
|
||||
def test_create_agent_model_not_exist():
|
||||
factory = AgentFactory()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_models = ["gpt-3"]
|
||||
|
||||
with pytest.raises(ModelNotExistError):
|
||||
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
|
||||
|
||||
|
||||
def test_create_agent_invalid_provider_type():
|
||||
factory = AgentFactory()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_type = "unknown"
|
||||
mock_provider.provider_models = ["gpt-4"]
|
||||
|
||||
with pytest.raises(ValueError, match="不支持的协议类型: unknown"):
|
||||
factory.create_agent(mock_provider, "gpt-4", str, "prompt", dict, "agent")
|
||||
@@ -0,0 +1,87 @@
|
||||
"""Pytest 全局 fixture:把 Ray Actor 句柄、PostgresDatabase、loguru 等重副作用模块替换成可控的 stub。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, Dict, Optional
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# Ray actor 句柄存根:测试期不真正连 Ray,由 conftest 注入名字 -> AsyncMock 句柄
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _FakeActorRegistry:
|
||||
"""模拟 ``ray.get_actor`` 行为:测试可往里塞名字 -> AsyncMock。"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._actors: Dict[str, Any] = {}
|
||||
|
||||
def register(self, name: str, handle: Any) -> None:
|
||||
self._actors[name] = handle
|
||||
|
||||
def get(self, name: str, namespace: str = "kilostar"): # noqa: ARG002
|
||||
if name not in self._actors:
|
||||
raise ValueError(f"FakeActorRegistry: actor {name!r} not registered")
|
||||
return self._actors[name]
|
||||
|
||||
def clear(self) -> None:
|
||||
self._actors.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def fake_actors(monkeypatch) -> _FakeActorRegistry:
|
||||
"""把 ``kilostar.utils.ray_hook._get_cached_actor_handle`` 的实现替换为 fake registry。
|
||||
|
||||
用法::
|
||||
|
||||
def test_xxx(fake_actors):
|
||||
gsm = AsyncMock()
|
||||
gsm.get_tool_config.remote = AsyncMock(return_value={"api_key": "k"})
|
||||
fake_actors.register("global_state_machine", types.SimpleNamespace(global_state_machine=gsm))
|
||||
"""
|
||||
registry = _FakeActorRegistry()
|
||||
|
||||
from kilostar.utils import ray_hook
|
||||
|
||||
ray_hook.clear_actor_cache()
|
||||
original = ray_hook._get_cached_actor_handle
|
||||
|
||||
def _stub(actor_name: str):
|
||||
return registry.get(actor_name)
|
||||
|
||||
_stub.cache_clear = lambda: None # type: ignore[attr-defined]
|
||||
monkeypatch.setattr(ray_hook, "_get_cached_actor_handle", _stub)
|
||||
yield registry
|
||||
registry.clear()
|
||||
monkeypatch.setattr(ray_hook, "_get_cached_actor_handle", original)
|
||||
ray_hook.clear_actor_cache()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gsm_handle(fake_actors) -> MagicMock:
|
||||
"""快捷 fixture:注册一个名为 ``global_state_machine`` 的 actor,返回其内部 mock。
|
||||
|
||||
内部 mock 的方法默认全部是 ``AsyncMock``,调用 ``.remote(...)`` 会按 AsyncMock 规则返回。
|
||||
"""
|
||||
gsm = MagicMock()
|
||||
container = types.SimpleNamespace(global_state_machine=gsm)
|
||||
fake_actors.register("global_state_machine", container)
|
||||
return gsm
|
||||
|
||||
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
# 工具:构造一个 ``MCP_AVAILABLE`` 关闭的 mcp_helper 状态
|
||||
# ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mcp_unavailable(monkeypatch):
|
||||
from kilostar.utils import mcp_helper
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", False)
|
||||
yield
|
||||
@@ -1,86 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
from pydantic import ValidationError
|
||||
from kilostar.utils.error import UserNotExistError
|
||||
from kilostar.core.postgres_database import database_exception
|
||||
|
||||
|
||||
@database_exception
|
||||
async def success_func():
|
||||
return "success"
|
||||
|
||||
|
||||
@database_exception
|
||||
async def validation_error_func():
|
||||
raise ValidationError.from_exception_data(title="Mock", line_errors=[])
|
||||
|
||||
|
||||
@database_exception
|
||||
async def integrity_error_func():
|
||||
raise IntegrityError("mock_statement", "mock_params", "mock_orig")
|
||||
|
||||
|
||||
@database_exception
|
||||
async def operational_error_func():
|
||||
raise OperationalError("mock_statement", "mock_params", "mock_orig")
|
||||
|
||||
|
||||
@database_exception
|
||||
async def user_not_exist_error_func():
|
||||
raise UserNotExistError("mock user")
|
||||
|
||||
|
||||
@database_exception
|
||||
async def exception_func():
|
||||
raise Exception("mock generic exception")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_success_func():
|
||||
assert await success_func() == "success"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("kilostar.core.database.database_exception.logger")
|
||||
async def test_validation_error(mock_logger):
|
||||
with pytest.raises(ValidationError):
|
||||
await validation_error_func()
|
||||
mock_logger.error.assert_called_once()
|
||||
assert "对象校验失败" in mock_logger.error.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("kilostar.core.database.database_exception.logger")
|
||||
async def test_integrity_error(mock_logger):
|
||||
with pytest.raises(IntegrityError):
|
||||
await integrity_error_func()
|
||||
mock_logger.error.assert_called_once()
|
||||
assert "数据库完整性错误" in mock_logger.error.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("kilostar.core.database.database_exception.logger")
|
||||
async def test_operational_error(mock_logger):
|
||||
with pytest.raises(OperationalError):
|
||||
await operational_error_func()
|
||||
mock_logger.error.assert_called_once()
|
||||
assert "数据库连接异常" in mock_logger.error.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("kilostar.core.database.database_exception.logger")
|
||||
async def test_user_not_exist_error(mock_logger):
|
||||
result = await user_not_exist_error_func()
|
||||
assert result is None
|
||||
mock_logger.error.assert_called_once()
|
||||
assert "更改密码失败,用户不存在" in mock_logger.error.call_args[0][0]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("kilostar.core.database.database_exception.logger")
|
||||
async def test_generic_exception(mock_logger):
|
||||
with pytest.raises(Exception, match="mock generic exception"):
|
||||
await exception_func()
|
||||
mock_logger.exception.assert_called_once()
|
||||
assert "未预期的数据库错误" in mock_logger.exception.call_args[0][0]
|
||||
@@ -1,190 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def mock_dependencies():
|
||||
with patch("kilostar.core.database.module.user.User") as mock_user_cls:
|
||||
mock_user_cls.user_name = MagicMock()
|
||||
with patch("kilostar.core.database.module.user.select") as mock_select:
|
||||
yield mock_user_cls, mock_select
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_maker():
|
||||
maker = MagicMock()
|
||||
session = AsyncMock()
|
||||
session.add = MagicMock()
|
||||
session.delete = MagicMock()
|
||||
maker.return_value.__aenter__.return_value = session
|
||||
maker.__aenter__.return_value = session
|
||||
maker.__aexit__ = AsyncMock()
|
||||
return maker, session
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_user(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, _ = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.user_name = "testuser"
|
||||
mock_user.hashed_password = "hashedpw"
|
||||
mock_user_cls.return_value = mock_user
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.first.return_value = None
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
user = await db.add_user("testuser", "hashedpw")
|
||||
|
||||
assert user.user_name == "testuser"
|
||||
assert user.hashed_password == "hashedpw"
|
||||
session.add.assert_called_once_with(mock_user)
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(mock_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_success(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_statement = MagicMock()
|
||||
mock_select.return_value.where.return_value = mock_statement
|
||||
|
||||
from kilostar.utils.access import Accessor
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.hashed_password = Accessor.hash_password("old_password")
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
user = await db.change_password("testuser", "old_password", "new_password")
|
||||
|
||||
session.execute.assert_called_once_with(mock_statement)
|
||||
assert user.hashed_password == "new_password"
|
||||
session.add.assert_called_once_with(mock_user)
|
||||
session.commit.assert_called_once()
|
||||
session.refresh.assert_called_once_with(mock_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_user_not_exist(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = None
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
result = await db.change_password("testuser", "old_password", "new_password")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_change_password_wrong_password(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
from kilostar.utils.access import Accessor
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_user.hashed_password = Accessor.hash_password("actual_password")
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
from kilostar.utils.error import UserPasswordError
|
||||
|
||||
with pytest.raises(UserPasswordError):
|
||||
await db.change_password("testuser", "old_password", "new_password")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_success(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_statement = MagicMock()
|
||||
mock_select.return_value.where.return_value = mock_statement
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
await db.delete_user("testuser")
|
||||
session.execute.assert_called_once_with(mock_statement)
|
||||
session.delete.assert_called_once_with(mock_user)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_user_not_exist(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = None
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
result = await db.delete_user("testuser")
|
||||
assert result is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_success(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_statement = MagicMock()
|
||||
mock_select.return_value.where.return_value = mock_statement
|
||||
|
||||
mock_user = MagicMock()
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = mock_user
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
user = await db.login_user("testuser")
|
||||
session.execute.assert_called_once_with(mock_statement)
|
||||
assert user == mock_user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_login_user_not_exist(mock_session_maker, mock_dependencies):
|
||||
mock_user_cls, mock_select = mock_dependencies
|
||||
from kilostar.core.postgres_database.module import AuthDatabase
|
||||
|
||||
maker, session = mock_session_maker
|
||||
db = AuthDatabase(maker)
|
||||
|
||||
mock_exec_result = MagicMock()
|
||||
mock_exec_result.scalar_one_or_none.return_value = None
|
||||
session.execute = AsyncMock(return_value=mock_exec_result)
|
||||
|
||||
result = await db.login_user("testuser")
|
||||
assert result is None
|
||||
@@ -1,15 +0,0 @@
|
||||
from kilostar.core.postgres_database.model import Provider
|
||||
|
||||
|
||||
def test_provider_table():
|
||||
# Provide required fields
|
||||
provider = Provider(
|
||||
provider_title="title",
|
||||
provider_url="url",
|
||||
provider_apikey="key",
|
||||
provider_models=["model_1"],
|
||||
provider_type="type",
|
||||
provider_owner=1,
|
||||
)
|
||||
assert Provider.__tablename__ == "provider"
|
||||
assert provider.provider_title == "title"
|
||||
@@ -1,7 +0,0 @@
|
||||
from kilostar.core.postgres_database.model import User
|
||||
|
||||
|
||||
def test_user_table():
|
||||
user = User(user_id="id", user_name="name", hashed_password="pw")
|
||||
assert User.__tablename__ == "user"
|
||||
assert user.user_name == "name"
|
||||
@@ -1,128 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock, patch
|
||||
import sys
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
|
||||
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "ray":
|
||||
mock_ray = MagicMock()
|
||||
|
||||
def mock_remote(*args, **kwargs):
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return args[0]
|
||||
|
||||
def decorator(cls):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
mock_ray.remote = mock_remote
|
||||
return mock_ray
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
|
||||
builtins.__import__ = mock_import
|
||||
|
||||
for mod in list(sys.modules.keys()):
|
||||
if "kilostar.core.global_state_machine.global_state_machine" in mod or "ray" in mod:
|
||||
del sys.modules[mod]
|
||||
|
||||
from kilostar.core.global_state_machine.global_state_machine import GlobalStateMachine # noqa: E402
|
||||
|
||||
builtins.__import__ = real_import
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_postgres():
|
||||
return MagicMock()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gsm(mock_postgres):
|
||||
manager = GlobalStateMachine(mock_postgres)
|
||||
return manager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_success(gsm, mock_postgres):
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider = MagicMock()
|
||||
mock_provider.provider_title = "MyProvider"
|
||||
mock_provider.provider_url = "url"
|
||||
mock_provider.provider_apikey = "key"
|
||||
mock_provider.provider_models = ["model"]
|
||||
mock_provider.provider_type = "openai"
|
||||
mock_provider_class.create_provider.return_value = mock_provider
|
||||
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
gsm._global_provider_manager.provider_register = {}
|
||||
|
||||
mock_add_provider = AsyncMock()
|
||||
mock_postgres.add_provider_db = MagicMock()
|
||||
mock_postgres.add_provider_db.remote = mock_add_provider
|
||||
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
|
||||
assert gsm._global_provider_manager.provider_register["title"] == mock_provider
|
||||
mock_add_provider.assert_called_once()
|
||||
assert mock_provider.provider_owner == "1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_unsupported(gsm):
|
||||
gsm._global_provider_manager.provider_mapper = {}
|
||||
with patch("kilostar.utils.logger.global_logger.bind") as mock_bind:
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
await gsm.add_provider_wrap("magic", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called_with("Provider type magic is not supported.")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_request_error(gsm):
|
||||
from httpx import RequestError
|
||||
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_provider.side_effect = RequestError(
|
||||
"Network Error", request=MagicMock()
|
||||
)
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("kilostar.utils.logger.global_logger.bind") as mock_bind:
|
||||
from kilostar.utils.error import RetryableError
|
||||
import pytest
|
||||
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
with pytest.raises(RetryableError):
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_generic_error(gsm):
|
||||
mock_provider_class = AsyncMock()
|
||||
mock_provider_class.create_provider.side_effect = ValueError("Some Error")
|
||||
gsm._global_provider_manager.provider_mapper = {"openai": mock_provider_class}
|
||||
|
||||
with patch("kilostar.utils.logger.global_logger.bind") as mock_bind:
|
||||
mock_logger = MagicMock()
|
||||
mock_bind.return_value = mock_logger
|
||||
await gsm.add_provider_wrap("openai", "title", "url", "key", "1")
|
||||
mock_logger.warning.assert_called_once()
|
||||
assert "解析模型列表时发生错误" in mock_logger.warning.call_args[0][0]
|
||||
|
||||
|
||||
def test_get_provider_list_and_get_provider(gsm):
|
||||
mock_provider = MagicMock()
|
||||
gsm._global_provider_manager.provider_register = {"p1": mock_provider}
|
||||
|
||||
assert gsm._global_provider_manager.get_provider_list() == {"p1": mock_provider}
|
||||
assert gsm._global_provider_manager.get_provider("p1") == mock_provider
|
||||
assert gsm._global_provider_manager.get_provider("missing") is None
|
||||
|
||||
|
||||
# noqa: E402
|
||||
@@ -1,32 +0,0 @@
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import (
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
ProviderStatus,
|
||||
)
|
||||
|
||||
|
||||
def test_provider_status():
|
||||
assert ProviderStatus.UP == "up"
|
||||
assert ProviderStatus.DOWN == "down"
|
||||
|
||||
|
||||
def test_provider_args():
|
||||
args = ProviderArgs(
|
||||
provider_title="title",
|
||||
provider_url="url",
|
||||
provider_apikey="key",
|
||||
provider_owner="1",
|
||||
)
|
||||
assert args.provider_title == "title"
|
||||
|
||||
|
||||
def test_provider_model():
|
||||
p = Provider(
|
||||
provider_title="title",
|
||||
provider_url="url",
|
||||
provider_apikey="key",
|
||||
provider_models=["model"],
|
||||
provider_type="openai",
|
||||
)
|
||||
assert p.provider_status == ProviderStatus.UP
|
||||
assert p.provider_owner is None
|
||||
@@ -1,60 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from kilostar.core.global_state_machine.model_provider.claude_provider import (
|
||||
ClaudeProvider,
|
||||
ProviderArgs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args():
|
||||
return ProviderArgs(
|
||||
provider_title="TestClaude",
|
||||
provider_url="https://api.anthropic.com",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_success(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "claude-3-opus-20240229"}, {"id": "claude-3-sonnet-20240229"}]
|
||||
}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await ClaudeProvider._load_models(provider_args)
|
||||
assert models == ["claude-3-opus-20240229", "claude-3-sonnet-20240229"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.claude_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_error(mock_client, provider_args):
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = Exception("network error")
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await ClaudeProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.claude_provider.ClaudeProvider._load_models",
|
||||
return_value=["claude-3"],
|
||||
)
|
||||
async def test_create_provider(mock_load, provider_args):
|
||||
provider = await ClaudeProvider.create_provider(provider_args)
|
||||
assert provider.provider_title == "TestClaude"
|
||||
assert provider.provider_models == ["claude-3"]
|
||||
assert provider.provider_type == "claude"
|
||||
@@ -1,131 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from kilostar.core.global_state_machine.model_provider.openai_provider import (
|
||||
OpenAIProvider,
|
||||
ProviderArgs,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args():
|
||||
return ProviderArgs(
|
||||
provider_title="TestOpenAI",
|
||||
provider_url="https://api.openai.com/v1",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def provider_args_no_v1():
|
||||
return ProviderArgs(
|
||||
provider_title="TestOpenAI",
|
||||
provider_url="https://api.openai.com",
|
||||
provider_apikey="testkey",
|
||||
provider_owner="1",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_success(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {
|
||||
"data": [{"id": "gpt-4"}, {"id": "gpt-3.5-turbo"}]
|
||||
}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == ["gpt-3.5-turbo", "gpt-4"]
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_no_v1(mock_client, provider_args_no_v1):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 200
|
||||
mock_response.json.return_value = {"data": []}
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args_no_v1)
|
||||
assert models == []
|
||||
mock_client_instance.get.assert_called_once_with(
|
||||
"https://api.openai.com/v1/models",
|
||||
headers={"Authorization": "Bearer testkey", "Content-Type": "application/json"},
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_status_error(mock_client, provider_args):
|
||||
mock_response = MagicMock()
|
||||
mock_response.status_code = 401
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.return_value = mock_response
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_request_error(mock_client, provider_args):
|
||||
import httpx
|
||||
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = httpx.RequestError(
|
||||
"network error", request=MagicMock()
|
||||
)
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
import pytest
|
||||
from kilostar.utils.error import RetryableError
|
||||
|
||||
with pytest.raises(RetryableError):
|
||||
await OpenAIProvider._load_models(provider_args)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.httpx.AsyncClient"
|
||||
)
|
||||
async def test_load_models_generic_error(mock_client, provider_args):
|
||||
mock_client_instance = AsyncMock()
|
||||
mock_client_instance.get.side_effect = Exception("generic error")
|
||||
mock_client.return_value.__aenter__.return_value = mock_client_instance
|
||||
|
||||
models = await OpenAIProvider._load_models(provider_args)
|
||||
assert models == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch(
|
||||
"kilostar.core.global_state_machine.model_provider.openai_provider.OpenAIProvider._load_models",
|
||||
return_value=["gpt-4"],
|
||||
)
|
||||
async def test_create_provider(mock_load, provider_args):
|
||||
provider = await OpenAIProvider.create_provider(provider_args)
|
||||
assert provider.provider_title == "TestOpenAI"
|
||||
assert provider.provider_models == ["gpt-4"]
|
||||
assert provider.provider_type == "openai"
|
||||
@@ -1,30 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, AsyncMock
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_provider_manager_init():
|
||||
mock_postgres = MagicMock()
|
||||
mock_provider1 = MagicMock()
|
||||
mock_provider1.provider_title = "title1"
|
||||
|
||||
mock_provider2 = MagicMock()
|
||||
mock_provider2.provider_title = "title2"
|
||||
|
||||
mock_postgres.get_provider = MagicMock()
|
||||
mock_postgres.get_provider.remote = AsyncMock(
|
||||
return_value=[mock_provider1, mock_provider2]
|
||||
)
|
||||
|
||||
manager = ProviderManager(mock_postgres)
|
||||
mock_postgres.provider_database = MagicMock()
|
||||
mock_postgres.provider_database.remote = AsyncMock(
|
||||
return_value=[mock_provider1, mock_provider2]
|
||||
)
|
||||
await manager.init_provider_register(mock_postgres)
|
||||
|
||||
assert "openai" in manager.provider_mapper
|
||||
assert "claude" in manager.provider_mapper
|
||||
|
||||
assert manager.provider_register["title1"] == mock_provider1
|
||||
@@ -1,6 +0,0 @@
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
|
||||
|
||||
def test_global_tool_manager_init():
|
||||
manager = GlobalToolManager()
|
||||
assert isinstance(manager, GlobalToolManager)
|
||||
@@ -1,87 +0,0 @@
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
import sys
|
||||
import builtins
|
||||
|
||||
real_import = builtins.__import__
|
||||
|
||||
|
||||
def mock_import(name, globals=None, locals=None, fromlist=(), level=0):
|
||||
if name == "ray":
|
||||
mock_ray = MagicMock()
|
||||
|
||||
def mock_remote(*args, **kwargs):
|
||||
if len(args) == 1 and callable(args[0]):
|
||||
return args[0]
|
||||
|
||||
def decorator(cls):
|
||||
return cls
|
||||
|
||||
return decorator
|
||||
|
||||
mock_ray.remote = mock_remote
|
||||
return mock_ray
|
||||
return real_import(name, globals, locals, fromlist, level)
|
||||
|
||||
|
||||
builtins.__import__ = mock_import
|
||||
for mod in list(sys.modules.keys()):
|
||||
if "kilostar.core.postgres_database.postgres" in mod or "ray" in mod:
|
||||
del sys.modules[mod]
|
||||
|
||||
from kilostar.core.postgres_database.postgres import PostgresDatabase # noqa: E402
|
||||
|
||||
builtins.__import__ = real_import
|
||||
|
||||
|
||||
@patch("kilostar.core.postgres_database.postgres.create_async_engine")
|
||||
@patch("kilostar.core.postgres_database.postgres.sessionmaker")
|
||||
@patch("kilostar.core.postgres_database.postgres.AuthDatabase")
|
||||
@patch("kilostar.core.postgres_database.postgres.ProviderDatabase")
|
||||
@patch("kilostar.core.postgres_database.postgres.os.environ.get")
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_database(
|
||||
mock_env_get, mock_provider_db, mock_auth_db, mock_sessionmaker, mock_create_engine
|
||||
):
|
||||
def env_side_effect(key):
|
||||
return {
|
||||
"POSTGRES_USER": "testuser",
|
||||
"POSTGRES_PASSWORD": "testpassword",
|
||||
"POSTGRES_HOST": "localhost",
|
||||
"POSTGRES_PORT": "5432",
|
||||
"POSTGRES_DB": "testdb",
|
||||
}.get(key)
|
||||
|
||||
mock_env_get.side_effect = env_side_effect
|
||||
|
||||
mock_engine = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
mock_conn.run_sync = AsyncMock()
|
||||
|
||||
mock_begin_ctx = MagicMock()
|
||||
mock_begin_ctx.__aenter__ = AsyncMock(return_value=mock_conn)
|
||||
mock_begin_ctx.__aexit__ = AsyncMock()
|
||||
mock_engine.begin.return_value = mock_begin_ctx
|
||||
mock_create_engine.return_value = mock_engine
|
||||
|
||||
db = PostgresDatabase()
|
||||
|
||||
mock_create_engine.assert_called_once_with(
|
||||
"postgresql+asyncpg://testuser:testpassword@localhost:5432/testdb", echo=True
|
||||
)
|
||||
mock_auth_db.assert_called_once()
|
||||
mock_provider_db.assert_called_once()
|
||||
mock_auth_db.return_value.get_user_authority = AsyncMock(return_value="test_auth")
|
||||
|
||||
with patch(
|
||||
"kilostar.core.postgres_database.postgres.SQLModel.metadata.create_all"
|
||||
) as mock_create_all:
|
||||
await db.init_db()
|
||||
mock_conn.run_sync.assert_called_once_with(mock_create_all)
|
||||
|
||||
assert await db.get_user_authority(user_id="123") == "test_auth"
|
||||
|
||||
|
||||
# noqa: E402
|
||||
@@ -0,0 +1,185 @@
|
||||
"""``AgentFactory.create_agent`` 的协议分发与异常分支。
|
||||
|
||||
为避免真的实例化 OpenAI / Anthropic / Google 等 provider,所有 ``provider_class``
|
||||
与 ``model_class`` 都被替换为 spy 类,仅记录构造参数。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.adapter.model_adapter import agent_factory as af_mod
|
||||
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.utils.agent_model import DepsModel, ResponseModel
|
||||
from kilostar.utils.error import ModelNotExistError
|
||||
|
||||
|
||||
class _SpyProvider:
|
||||
last_init: Dict[str, Any] = {}
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
type(self).last_init = kwargs
|
||||
|
||||
|
||||
class _SpyModel:
|
||||
last_init: Dict[str, Any] = {}
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
type(self).last_init = {"args": args, "kwargs": kwargs}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def factory(monkeypatch):
|
||||
"""构造 AgentFactory 并将 provider/model class 全部替换成 spy 类。"""
|
||||
af = AgentFactory()
|
||||
for proto in af._models_mapping:
|
||||
af._models_mapping[proto]["model_class"] = type(
|
||||
f"_Model_{proto}", (_SpyModel,), {}
|
||||
)
|
||||
af._models_mapping[proto]["provider_class"] = type(
|
||||
f"_Provider_{proto}", (_SpyProvider,), {}
|
||||
)
|
||||
|
||||
fake_agent = MagicMock(name="Agent")
|
||||
fake_agent_class = MagicMock(name="Agent_class", return_value=fake_agent)
|
||||
monkeypatch.setattr(af_mod, "Agent", fake_agent_class)
|
||||
return af, fake_agent_class, fake_agent
|
||||
|
||||
|
||||
def _provider(provider_type: str = "openai") -> Provider:
|
||||
return Provider(
|
||||
provider_title="t",
|
||||
provider_url="https://example.com",
|
||||
provider_apikey="sk-123",
|
||||
provider_models=["m1", "m2"],
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_returns_agent_instance(factory):
|
||||
af, agent_cls, fake_agent = factory
|
||||
result = af.create_agent(
|
||||
provider=_provider("openai"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
assert result is fake_agent
|
||||
agent_cls.assert_called_once()
|
||||
kwargs = agent_cls.call_args.kwargs
|
||||
assert kwargs["name"] == "hello"
|
||||
assert kwargs["system_prompt"] == "sys"
|
||||
assert kwargs["tools"] == []
|
||||
assert kwargs["toolsets"] == []
|
||||
|
||||
|
||||
def test_create_agent_passes_through_tools_and_toolsets(factory):
|
||||
af, agent_cls, _ = factory
|
||||
tools = [lambda: 1]
|
||||
toolsets = [object()]
|
||||
af.create_agent(
|
||||
provider=_provider("openai"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
tools=tools,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
kwargs = agent_cls.call_args.kwargs
|
||||
assert kwargs["tools"] is tools
|
||||
assert kwargs["toolsets"] is toolsets
|
||||
|
||||
|
||||
def test_create_agent_unknown_model_raises(factory):
|
||||
af, _, _ = factory
|
||||
with pytest.raises(ModelNotExistError):
|
||||
af.create_agent(
|
||||
provider=_provider("openai"),
|
||||
model_id="not-in-list",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
|
||||
|
||||
def test_create_agent_unknown_protocol_raises(factory):
|
||||
af, _, _ = factory
|
||||
bad = _provider("openai")
|
||||
bad.provider_type = "weird-protocol"
|
||||
with pytest.raises(ValueError):
|
||||
af.create_agent(
|
||||
provider=bad,
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
|
||||
|
||||
def test_openai_protocol_passes_api_key_and_base_url(factory):
|
||||
af, _, _ = factory
|
||||
af.create_agent(
|
||||
provider=_provider("openai"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
init = af._models_mapping["openai"]["provider_class"].last_init
|
||||
assert init["api_key"] == "sk-123"
|
||||
assert init["base_url"] == "https://example.com"
|
||||
|
||||
|
||||
def test_claude_protocol_passes_api_key_only(factory):
|
||||
af, _, _ = factory
|
||||
af.create_agent(
|
||||
provider=_provider("claude"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
init = af._models_mapping["claude"]["provider_class"].last_init
|
||||
assert init["api_key"] == "sk-123"
|
||||
assert "base_url" not in init
|
||||
|
||||
|
||||
def test_gemini_protocol_uses_kwarg_model_name(factory):
|
||||
af, _, _ = factory
|
||||
af.create_agent(
|
||||
provider=_provider("gemini"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
init = af._models_mapping["gemini"]["model_class"].last_init
|
||||
assert init["kwargs"].get("model_name") == "m1"
|
||||
assert "provider" in init["kwargs"]
|
||||
|
||||
|
||||
def test_non_gemini_uses_positional_model_id(factory):
|
||||
af, _, _ = factory
|
||||
af.create_agent(
|
||||
provider=_provider("deepseek"),
|
||||
model_id="m1",
|
||||
output_type=ResponseModel,
|
||||
system_prompt="sys",
|
||||
deps_type=DepsModel,
|
||||
agent_name="hello",
|
||||
)
|
||||
init = af._models_mapping["deepseek"]["model_class"].last_init
|
||||
assert init["args"] == ("m1",)
|
||||
@@ -0,0 +1,77 @@
|
||||
"""``api/__init__.py`` 全局异常兜底 handler 与 CORS 中间件。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_global_handler():
|
||||
"""构造一个最小 app,套用与生产相同的兜底 handler。"""
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("api")
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
_logger.exception(
|
||||
f"Unhandled exception on {request.method} {request.url.path}: {exc}"
|
||||
)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"message": "服务内部错误,请稍后重试"},
|
||||
)
|
||||
|
||||
@app.get("/boom")
|
||||
async def boom():
|
||||
raise RuntimeError("internal stack trace with secrets")
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unhandled_exception_returns_masked_500(app_with_global_handler):
|
||||
transport = ASGITransport(app=app_with_global_handler, raise_app_exceptions=False)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/boom")
|
||||
|
||||
assert resp.status_code == 500
|
||||
body = resp.json()
|
||||
assert body == {"message": "服务内部错误,请稍后重试"}
|
||||
# 关键:不能把内部 traceback 透出
|
||||
assert "internal stack trace with secrets" not in resp.text
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cors_preflight_with_explicit_origin():
|
||||
app = FastAPI()
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["https://app.example.com"],
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
@app.get("/x")
|
||||
async def _x():
|
||||
return {"ok": True}
|
||||
|
||||
transport = ASGITransport(app=app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.options(
|
||||
"/x",
|
||||
headers={
|
||||
"Origin": "https://app.example.com",
|
||||
"Access-Control-Request-Method": "GET",
|
||||
},
|
||||
)
|
||||
|
||||
assert resp.headers.get("access-control-allow-origin") == "https://app.example.com"
|
||||
assert resp.headers.get("access-control-allow-credentials") == "true"
|
||||
@@ -0,0 +1,75 @@
|
||||
"""``api/chat.py`` 中 ``_ask_regulatory`` 与 RegulatoryNode.working 的接线。
|
||||
|
||||
历史上这里调用的是 ``regulatory_node.handle_chat_message``,但 ``RegulatoryNode``
|
||||
上从未定义该方法 —— 是从老串联架构遗留下来的死代码路径。新定位下 chat 入口
|
||||
应直接调 ``working(MessageRequest)``,本测试钉死这个契约不再回退。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class _FakeActorRef:
|
||||
def __init__(self, target):
|
||||
self._target = target
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._target, item)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_regulatory_calls_working_and_extracts_reply(monkeypatch):
|
||||
from kilostar.api import chat as chat_module
|
||||
from kilostar.core.individual.regulatory_node.template import MessageResponse
|
||||
|
||||
fake_resp = MessageResponse(
|
||||
platform="client", platform_id="chat-1", reply_message="你好"
|
||||
)
|
||||
|
||||
regulatory = MagicMock()
|
||||
regulatory.working = MagicMock()
|
||||
regulatory.working.remote = AsyncMock(return_value=fake_resp)
|
||||
|
||||
def _fake_hook(name):
|
||||
assert name == "regulatory_node"
|
||||
return SimpleNamespace(regulatory_node=_FakeActorRef(regulatory))
|
||||
|
||||
monkeypatch.setattr(chat_module, "ray_actor_hook", _fake_hook)
|
||||
|
||||
out = await chat_module._ask_regulatory(
|
||||
user_id="alice", chat_id="chat-1", message="hi"
|
||||
)
|
||||
|
||||
assert out == "你好"
|
||||
# 调用契约:MessageRequest,且 platform_id 取自 chat_id
|
||||
args, kwargs = regulatory.working.remote.call_args
|
||||
payload = args[0] if args else kwargs.get("payload")
|
||||
assert payload.platform == "client"
|
||||
assert payload.user_name == "alice"
|
||||
assert payload.platform_id == "chat-1"
|
||||
assert payload.message == "hi"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_ask_regulatory_returns_none_when_node_returns_none(monkeypatch):
|
||||
"""节点降级返回 None 时,上层应静默不写回 chat history。"""
|
||||
from kilostar.api import chat as chat_module
|
||||
|
||||
regulatory = MagicMock()
|
||||
regulatory.working = MagicMock()
|
||||
regulatory.working.remote = AsyncMock(return_value=None)
|
||||
|
||||
monkeypatch.setattr(
|
||||
chat_module,
|
||||
"ray_actor_hook",
|
||||
lambda name: SimpleNamespace(regulatory_node=_FakeActorRef(regulatory)),
|
||||
)
|
||||
|
||||
out = await chat_module._ask_regulatory(
|
||||
user_id="bob", chat_id="chat-2", message="hello"
|
||||
)
|
||||
assert out is None
|
||||
@@ -0,0 +1,142 @@
|
||||
"""``api/resource.py`` Custom toolset 路由:归属鉴权。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from kilostar.api.resource import resource_router
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
|
||||
|
||||
def _fake_user(user_id: str = "alice"):
|
||||
return TokenData(user_id=user_id, username=user_id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def app_with_user(monkeypatch):
|
||||
"""挂上 resource_router;用 dependency_overrides 跳过 JWT,并把 get_authority 默认放成 USER。"""
|
||||
app = FastAPI()
|
||||
app.include_router(resource_router)
|
||||
app.dependency_overrides[Accessor.get_current_user] = lambda: _fake_user("alice")
|
||||
|
||||
# 默认把权限置为 USER;具体 case 内部可再 monkeypatch 覆盖
|
||||
async def _default_authority(uid):
|
||||
return UserAuthority.USER
|
||||
|
||||
monkeypatch.setattr(
|
||||
"kilostar.utils.check_user.role_check.get_authority", _default_authority
|
||||
)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_custom_toolset_forbidden_for_non_owner(
|
||||
app_with_user, fake_actors
|
||||
):
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_custom_toolset = types.SimpleNamespace(
|
||||
remote=AsyncMock(
|
||||
return_value={"toolset_id": "t1", "owner_id": "bob", "tools": []}
|
||||
)
|
||||
)
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=app_with_user)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/api/v1/resource/custom-toolset/t1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_custom_toolset_allowed_for_owner(app_with_user, fake_actors):
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_custom_toolset = types.SimpleNamespace(
|
||||
remote=AsyncMock(
|
||||
return_value={"toolset_id": "t1", "owner_id": "alice", "tools": []}
|
||||
)
|
||||
)
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=app_with_user)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/api/v1/resource/custom-toolset/t1")
|
||||
|
||||
assert resp.status_code == 200
|
||||
assert resp.json()["owner_id"] == "alice"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_custom_toolset_allowed_for_admin(
|
||||
app_with_user, fake_actors, monkeypatch
|
||||
):
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_custom_toolset = types.SimpleNamespace(
|
||||
remote=AsyncMock(
|
||||
return_value={"toolset_id": "t1", "owner_id": "bob", "tools": []}
|
||||
)
|
||||
)
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
async def _admin(uid):
|
||||
return UserAuthority.SUPER_ADMINISTRATOR
|
||||
|
||||
monkeypatch.setattr(
|
||||
"kilostar.utils.check_user.role_check.get_authority", _admin
|
||||
)
|
||||
|
||||
transport = ASGITransport(app=app_with_user)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/api/v1/resource/custom-toolset/t1")
|
||||
|
||||
assert resp.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_custom_toolsets_filters_by_owner(app_with_user, fake_actors):
|
||||
all_sets = [
|
||||
{"toolset_id": "t1", "owner_id": "alice"},
|
||||
{"toolset_id": "t2", "owner_id": "bob"},
|
||||
]
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.list_custom_toolsets = types.SimpleNamespace(
|
||||
remote=AsyncMock(return_value=all_sets)
|
||||
)
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=app_with_user)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/api/v1/resource/custom-toolset")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert len(body["toolsets"]) == 1
|
||||
assert body["toolsets"][0]["toolset_id"] == "t1"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_toolset_forbidden_for_non_owner(
|
||||
app_with_user, fake_actors
|
||||
):
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_custom_toolset = types.SimpleNamespace(
|
||||
remote=AsyncMock(
|
||||
return_value={"toolset_id": "t1", "owner_id": "bob"}
|
||||
)
|
||||
)
|
||||
delete_mock = AsyncMock(return_value=True)
|
||||
gsm.delete_custom_toolset = types.SimpleNamespace(remote=delete_mock)
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=app_with_user)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.delete("/api/v1/resource/custom-toolset/t1")
|
||||
|
||||
assert resp.status_code == 403
|
||||
delete_mock.assert_not_called()
|
||||
@@ -0,0 +1,81 @@
|
||||
"""``api/health.py`` 健康探针端点。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import types
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from kilostar.api.health import health_router
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def health_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
app.include_router(health_router)
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_liveness_returns_alive(health_app):
|
||||
transport = ASGITransport(app=health_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health/live")
|
||||
assert resp.status_code == 200
|
||||
assert resp.json() == {"status": "alive"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_all_ok(health_app, fake_actors):
|
||||
pg = types.SimpleNamespace()
|
||||
pg.ping = types.SimpleNamespace(remote=AsyncMock(return_value=True))
|
||||
fake_actors.register("postgres_database", pg)
|
||||
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_skill_list = types.SimpleNamespace(remote=AsyncMock(return_value=[]))
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=health_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health/ready")
|
||||
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["status"] == "ready"
|
||||
assert body["checks"] == {"postgres": True, "global_state_machine": True}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_postgres_down(health_app, fake_actors):
|
||||
pg = types.SimpleNamespace()
|
||||
pg.ping = types.SimpleNamespace(
|
||||
remote=AsyncMock(side_effect=RuntimeError("db down"))
|
||||
)
|
||||
fake_actors.register("postgres_database", pg)
|
||||
|
||||
gsm = types.SimpleNamespace()
|
||||
gsm.get_skill_list = types.SimpleNamespace(remote=AsyncMock(return_value=[]))
|
||||
fake_actors.register("global_state_machine", gsm)
|
||||
|
||||
transport = ASGITransport(app=health_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health/ready")
|
||||
|
||||
assert resp.status_code == 503
|
||||
body = resp.json()
|
||||
assert body["status"] == "not_ready"
|
||||
assert body["checks"]["postgres"] is False
|
||||
assert body["checks"]["global_state_machine"] is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_readiness_actor_not_registered(health_app, fake_actors):
|
||||
transport = ASGITransport(app=health_app)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/health/ready")
|
||||
|
||||
assert resp.status_code == 503
|
||||
assert resp.json()["status"] == "not_ready"
|
||||
@@ -0,0 +1,168 @@
|
||||
"""``api/platform/onebot.py`` 中纯函数与事件分派的覆盖。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import HTTPException
|
||||
|
||||
from kilostar.api.platform import onebot as onebot_mod
|
||||
from kilostar.core.individual.regulatory_node.template import MessageResponse
|
||||
|
||||
|
||||
# ─── _verify_token ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_verify_token_skipped_when_env_missing(monkeypatch):
|
||||
monkeypatch.delenv("ONEBOT_ACCESS_TOKEN", raising=False)
|
||||
onebot_mod._verify_token(None)
|
||||
onebot_mod._verify_token("anything")
|
||||
|
||||
|
||||
def test_verify_token_accepts_bearer_prefix(monkeypatch):
|
||||
monkeypatch.setenv("ONEBOT_ACCESS_TOKEN", "expected")
|
||||
onebot_mod._verify_token("Bearer expected")
|
||||
|
||||
|
||||
def test_verify_token_accepts_token_prefix(monkeypatch):
|
||||
monkeypatch.setenv("ONEBOT_ACCESS_TOKEN", "expected")
|
||||
onebot_mod._verify_token("Token expected")
|
||||
|
||||
|
||||
def test_verify_token_accepts_raw_token(monkeypatch):
|
||||
monkeypatch.setenv("ONEBOT_ACCESS_TOKEN", "expected")
|
||||
onebot_mod._verify_token("expected")
|
||||
|
||||
|
||||
def test_verify_token_rejects_missing(monkeypatch):
|
||||
monkeypatch.setenv("ONEBOT_ACCESS_TOKEN", "expected")
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
onebot_mod._verify_token(None)
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
def test_verify_token_rejects_wrong_token(monkeypatch):
|
||||
monkeypatch.setenv("ONEBOT_ACCESS_TOKEN", "expected")
|
||||
with pytest.raises(HTTPException) as exc:
|
||||
onebot_mod._verify_token("Bearer wrong")
|
||||
assert exc.value.status_code == 401
|
||||
|
||||
|
||||
# ─── _extract_plain_text ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_plain_text_handles_string():
|
||||
assert onebot_mod._extract_plain_text("hello") == "hello"
|
||||
|
||||
|
||||
def test_extract_plain_text_handles_segment_array():
|
||||
seg = [
|
||||
{"type": "text", "data": {"text": "hello "}},
|
||||
{"type": "image", "data": {"file": "1.png"}},
|
||||
{"type": "text", "data": {"text": "world"}},
|
||||
]
|
||||
assert onebot_mod._extract_plain_text(seg) == "hello world"
|
||||
|
||||
|
||||
def test_extract_plain_text_handles_unknown_type():
|
||||
assert onebot_mod._extract_plain_text(None) == ""
|
||||
assert onebot_mod._extract_plain_text({"foo": "bar"}) == ""
|
||||
|
||||
|
||||
# ─── _dispatch_event ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def regulatory_actor(fake_actors):
|
||||
"""注入一个 regulatory_node Mock,默认返回固定的 MessageResponse。"""
|
||||
inner = MagicMock()
|
||||
inner.working.remote = AsyncMock(
|
||||
return_value=MessageResponse(
|
||||
platform="onebot", platform_id="private:1234", reply_message="pong"
|
||||
)
|
||||
)
|
||||
fake_actors.register("regulatory_node", inner)
|
||||
return inner
|
||||
|
||||
|
||||
async def test_dispatch_event_ignores_non_message(regulatory_actor):
|
||||
res = await onebot_mod._dispatch_event(
|
||||
{"post_type": "meta_event", "meta_event_type": "heartbeat"}
|
||||
)
|
||||
assert res is None
|
||||
regulatory_actor.working.remote.assert_not_called()
|
||||
|
||||
|
||||
async def test_dispatch_event_ignores_empty_text(regulatory_actor):
|
||||
res = await onebot_mod._dispatch_event(
|
||||
{
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"user_id": 1234,
|
||||
"message": " ",
|
||||
}
|
||||
)
|
||||
assert res is None
|
||||
regulatory_actor.working.remote.assert_not_called()
|
||||
|
||||
|
||||
async def test_dispatch_event_private_message_returns_quick_reply(regulatory_actor):
|
||||
payload = {
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"user_id": 1234,
|
||||
"sender": {"nickname": "alice"},
|
||||
"message": "ping",
|
||||
}
|
||||
res = await onebot_mod._dispatch_event(payload)
|
||||
assert res is not None
|
||||
assert res["reply"] == "pong"
|
||||
assert "at_sender" not in res
|
||||
assert res["_target"]["message_type"] == "private"
|
||||
assert res["_target"]["user_id"] == 1234
|
||||
|
||||
|
||||
async def test_dispatch_event_group_message_includes_at_sender(regulatory_actor):
|
||||
payload = {
|
||||
"post_type": "message",
|
||||
"message_type": "group",
|
||||
"user_id": 1234,
|
||||
"group_id": 5678,
|
||||
"sender": {"card": "alice", "nickname": "fallback"},
|
||||
"message": [{"type": "text", "data": {"text": "ping"}}],
|
||||
}
|
||||
res = await onebot_mod._dispatch_event(payload)
|
||||
assert res["reply"] == "pong"
|
||||
assert res["at_sender"] is False
|
||||
assert res["_target"]["group_id"] == 5678
|
||||
|
||||
|
||||
async def test_dispatch_event_swallows_actor_error(regulatory_actor):
|
||||
regulatory_actor.working.remote = AsyncMock(side_effect=RuntimeError("ray fail"))
|
||||
payload = {
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"user_id": 1,
|
||||
"message": "hi",
|
||||
}
|
||||
res = await onebot_mod._dispatch_event(payload)
|
||||
assert res is None
|
||||
|
||||
|
||||
async def test_dispatch_event_returns_none_when_reply_empty(fake_actors):
|
||||
inner = MagicMock()
|
||||
inner.working.remote = AsyncMock(
|
||||
return_value=MessageResponse(
|
||||
platform="onebot", platform_id="private:1", reply_message=""
|
||||
)
|
||||
)
|
||||
fake_actors.register("regulatory_node", inner)
|
||||
payload = {
|
||||
"post_type": "message",
|
||||
"message_type": "private",
|
||||
"user_id": 1,
|
||||
"message": "hi",
|
||||
}
|
||||
res = await onebot_mod._dispatch_event(payload)
|
||||
assert res is None
|
||||
@@ -0,0 +1,50 @@
|
||||
"""``api/resource.py`` 中的脱敏纯函数。"""
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.api.resource import _mask_config, _mask_secret
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"value,expected",
|
||||
[
|
||||
("", ""),
|
||||
("short", "***"),
|
||||
("12345678", "***"),
|
||||
("1234567890", "1234***7890"),
|
||||
("a-very-long-token-string", "a-ve***ring"),
|
||||
(None, None),
|
||||
(12345, 12345),
|
||||
],
|
||||
)
|
||||
def test_mask_secret_behaviour(value, expected):
|
||||
assert _mask_secret(value) == expected
|
||||
|
||||
|
||||
def test_mask_config_masks_known_sensitive_keys():
|
||||
raw = {
|
||||
"api_key": "1234567890",
|
||||
"TOKEN": "1234567890",
|
||||
"secret_value": "1234567890",
|
||||
"DB_PASSWORD": "1234567890",
|
||||
"non_sensitive": "1234567890",
|
||||
"extra": 42,
|
||||
}
|
||||
masked = _mask_config(raw)
|
||||
assert masked["api_key"] == "1234***7890"
|
||||
assert masked["TOKEN"] == "1234***7890"
|
||||
assert masked["secret_value"] == "1234***7890"
|
||||
assert masked["DB_PASSWORD"] == "1234***7890"
|
||||
assert masked["non_sensitive"] == "1234567890"
|
||||
assert masked["extra"] == 42
|
||||
|
||||
|
||||
def test_mask_config_handles_empty_input():
|
||||
assert _mask_config({}) == {}
|
||||
|
||||
|
||||
def test_mask_config_does_not_mutate_input():
|
||||
raw = {"api_key": "1234567890", "url": "https://x"}
|
||||
snapshot = dict(raw)
|
||||
_mask_config(raw)
|
||||
assert raw == snapshot
|
||||
@@ -0,0 +1,117 @@
|
||||
"""kilostar.utils.crypto 加解密模块测试。"""
|
||||
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from kilostar.utils.crypto import (
|
||||
CryptoError,
|
||||
decrypt_dict_secrets,
|
||||
decrypt_secret,
|
||||
encrypt_dict_secrets,
|
||||
encrypt_secret,
|
||||
is_encrypted,
|
||||
_is_sensitive_key,
|
||||
_get_fernet,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _set_secret_key(monkeypatch):
|
||||
key = Fernet.generate_key().decode()
|
||||
monkeypatch.setenv("KILOSTAR_SECRET_KEY", key)
|
||||
_get_fernet.cache_clear()
|
||||
yield
|
||||
_get_fernet.cache_clear()
|
||||
|
||||
|
||||
class TestEncryptDecrypt:
|
||||
def test_round_trip(self):
|
||||
plain = "my-secret-token-12345"
|
||||
cipher = encrypt_secret(plain)
|
||||
assert cipher.startswith("v1:")
|
||||
assert decrypt_secret(cipher) == plain
|
||||
|
||||
def test_empty_string_passthrough(self):
|
||||
assert encrypt_secret("") == ""
|
||||
assert decrypt_secret("") == ""
|
||||
|
||||
def test_non_encrypted_passthrough(self):
|
||||
assert decrypt_secret("plain-text") == "plain-text"
|
||||
|
||||
def test_is_encrypted(self):
|
||||
cipher = encrypt_secret("hello")
|
||||
assert is_encrypted(cipher) is True
|
||||
assert is_encrypted("hello") is False
|
||||
assert is_encrypted("") is False
|
||||
|
||||
|
||||
class TestDictSecrets:
|
||||
def test_encrypt_dict_secrets_targets_sensitive_keys(self):
|
||||
data = {"api_key": "abc123", "url": "https://x"}
|
||||
encrypted = encrypt_dict_secrets(data)
|
||||
assert is_encrypted(encrypted["api_key"])
|
||||
assert encrypted["url"] == "https://x"
|
||||
|
||||
def test_decrypt_dict_secrets_round_trip(self):
|
||||
data = {"token": "secret", "name": "foo"}
|
||||
encrypted = encrypt_dict_secrets(data)
|
||||
decrypted = decrypt_dict_secrets(encrypted)
|
||||
assert decrypted == data
|
||||
|
||||
def test_already_encrypted_not_double_encrypted(self):
|
||||
data = {"api_key": "abc123"}
|
||||
enc1 = encrypt_dict_secrets(data)
|
||||
enc2 = encrypt_dict_secrets(enc1)
|
||||
assert enc1["api_key"] == enc2["api_key"]
|
||||
|
||||
def test_non_dict_passthrough(self):
|
||||
assert encrypt_dict_secrets("not a dict") == "not a dict"
|
||||
assert decrypt_dict_secrets(42) == 42
|
||||
|
||||
|
||||
class TestSensitiveKeyDetection:
|
||||
@pytest.mark.parametrize(
|
||||
"key,expected",
|
||||
[
|
||||
("api_key", True),
|
||||
("API_KEY", True),
|
||||
("provider_apikey", True),
|
||||
("token", True),
|
||||
("access_token", True),
|
||||
("secret", True),
|
||||
("password", True),
|
||||
("db_password", True),
|
||||
("url", False),
|
||||
("name", False),
|
||||
("transport", False),
|
||||
],
|
||||
)
|
||||
def test_is_sensitive_key(self, key, expected):
|
||||
assert _is_sensitive_key(key) is expected
|
||||
|
||||
|
||||
class TestMissingKey:
|
||||
def test_raises_when_key_not_set(self, monkeypatch):
|
||||
monkeypatch.delenv("KILOSTAR_SECRET_KEY", raising=False)
|
||||
_get_fernet.cache_clear()
|
||||
with pytest.raises(CryptoError, match="KILOSTAR_SECRET_KEY"):
|
||||
encrypt_secret("hello")
|
||||
|
||||
def test_raises_on_invalid_key(self, monkeypatch):
|
||||
monkeypatch.setenv("KILOSTAR_SECRET_KEY", "not-a-valid-fernet-key")
|
||||
_get_fernet.cache_clear()
|
||||
with pytest.raises(CryptoError, match="格式无效"):
|
||||
encrypt_secret("hello")
|
||||
|
||||
|
||||
class TestDecryptWithWrongKey:
|
||||
def test_wrong_key_raises(self, monkeypatch):
|
||||
cipher = encrypt_secret("hello")
|
||||
new_key = Fernet.generate_key().decode()
|
||||
monkeypatch.setenv("KILOSTAR_SECRET_KEY", new_key)
|
||||
_get_fernet.cache_clear()
|
||||
with pytest.raises(CryptoError, match="解密失败"):
|
||||
decrypt_secret(cipher)
|
||||
@@ -0,0 +1,72 @@
|
||||
"""``database_exception`` 装饰器:透传异常并通过 logger 上报。"""
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from sqlalchemy.exc import IntegrityError, OperationalError
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.utils.error import UserNotExistError
|
||||
|
||||
|
||||
async def test_normal_path_returns_value():
|
||||
@database_exception
|
||||
async def ok() -> int:
|
||||
return 42
|
||||
|
||||
assert await ok() == 42
|
||||
|
||||
|
||||
async def test_passes_args_and_kwargs():
|
||||
@database_exception
|
||||
async def add(a, b, *, c) -> int:
|
||||
return a + b + c
|
||||
|
||||
assert await add(1, 2, c=3) == 6
|
||||
|
||||
|
||||
async def test_validation_error_propagates():
|
||||
class Model(BaseModel):
|
||||
x: int
|
||||
|
||||
@database_exception
|
||||
async def boom() -> None:
|
||||
Model(x="not-int") # type: ignore[arg-type]
|
||||
|
||||
with pytest.raises(ValidationError):
|
||||
await boom()
|
||||
|
||||
|
||||
async def test_integrity_error_propagates():
|
||||
@database_exception
|
||||
async def boom() -> None:
|
||||
raise IntegrityError("stmt", {}, Exception("dup"))
|
||||
|
||||
with pytest.raises(IntegrityError):
|
||||
await boom()
|
||||
|
||||
|
||||
async def test_operational_error_propagates():
|
||||
@database_exception
|
||||
async def boom() -> None:
|
||||
raise OperationalError("stmt", {}, Exception("conn"))
|
||||
|
||||
with pytest.raises(OperationalError):
|
||||
await boom()
|
||||
|
||||
|
||||
async def test_user_not_exist_error_propagates():
|
||||
@database_exception
|
||||
async def boom() -> None:
|
||||
raise UserNotExistError("missing")
|
||||
|
||||
with pytest.raises(UserNotExistError):
|
||||
await boom()
|
||||
|
||||
|
||||
async def test_unexpected_error_propagates():
|
||||
@database_exception
|
||||
async def boom() -> None:
|
||||
raise RuntimeError("unexpected")
|
||||
|
||||
with pytest.raises(RuntimeError):
|
||||
await boom()
|
||||
@@ -0,0 +1,194 @@
|
||||
"""``GlobalStateMachine`` 中 MCP/ToolConfig/CustomToolset 注册表测试。
|
||||
|
||||
GSM 现在走 PostgresDatabase Actor,这里绕过 Ray 直接构造实例,
|
||||
用 AsyncMock 模拟 postgres_database 的 remote 调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.global_state_machine import global_state_machine as gsm_module
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gsm_instance(monkeypatch):
|
||||
GSMClass = gsm_module.GlobalStateMachine.__ray_actor_class__
|
||||
obj = GSMClass.__new__(GSMClass)
|
||||
obj._mcp_servers = {}
|
||||
obj._tool_configs = {}
|
||||
obj._custom_toolsets = {}
|
||||
obj._global_provider_manager = MagicMock()
|
||||
obj._global_tool_manager = MagicMock()
|
||||
obj._global_tool_manager.is_third_party_tool = lambda name: name.startswith("tp_")
|
||||
obj._global_tool_manager.rebuild_custom_toolsets = MagicMock()
|
||||
obj._global_skill_manager = MagicMock()
|
||||
obj._global_individual_manager = MagicMock()
|
||||
obj.postgres_database = MagicMock()
|
||||
# 新加的 object-store 快照状态
|
||||
obj._config_version = 0
|
||||
obj._current_ref = None
|
||||
# 这套老测试覆盖的是注册表行为,不关心快照发布;
|
||||
# 把 _publish_snapshot 替换成 no-op 计数器,避免触达 ray.put 与 manager 内部细节
|
||||
obj._publish_count = 0
|
||||
|
||||
def _stub_publish():
|
||||
obj._publish_count += 1
|
||||
obj._config_version += 1
|
||||
|
||||
obj._publish_snapshot = _stub_publish
|
||||
return obj
|
||||
|
||||
|
||||
# ─── MCP server registry ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_mcp_server(gsm_instance):
|
||||
obj = gsm_instance
|
||||
saved = {"server_id": "fs", "name": "fs", "transport": "stdio"}
|
||||
obj.postgres_database.upsert_mcp_server.remote = AsyncMock(return_value=saved)
|
||||
|
||||
ok = await obj.add_mcp_server("fs", {"name": "fs", "transport": "stdio"})
|
||||
assert ok is True
|
||||
assert obj._mcp_servers["fs"] == saved
|
||||
|
||||
|
||||
def test_get_mcp_server_configs_returns_copy(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._mcp_servers["fs"] = {"name": "fs", "transport": "stdio"}
|
||||
res1 = obj.get_mcp_server_configs()
|
||||
res1["fs"] = {"mutated": True}
|
||||
res2 = obj.get_mcp_server_configs()
|
||||
assert res2["fs"]["name"] == "fs"
|
||||
|
||||
|
||||
def test_get_mcp_server_returns_none_when_missing(gsm_instance):
|
||||
assert gsm_instance.get_mcp_server("nope") is None
|
||||
|
||||
|
||||
def test_list_mcp_servers_includes_server_id(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._mcp_servers["fs"] = {"name": "fs", "transport": "stdio"}
|
||||
listed = obj.list_mcp_servers()
|
||||
assert listed[0]["server_id"] == "fs"
|
||||
assert listed[0]["name"] == "fs"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_mcp_server(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._mcp_servers["fs"] = {"name": "fs"}
|
||||
obj.postgres_database.delete_mcp_server_db.remote = AsyncMock(return_value=True)
|
||||
|
||||
assert await obj.delete_mcp_server("fs") is True
|
||||
assert "fs" not in obj._mcp_servers
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_unknown_mcp_server(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj.postgres_database.delete_mcp_server_db.remote = AsyncMock(return_value=False)
|
||||
assert await obj.delete_mcp_server("nope") is False
|
||||
|
||||
|
||||
# ─── tool_configs ───────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_set_and_get_tool_config(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj.postgres_database.upsert_tool_config.remote = AsyncMock(
|
||||
return_value={"tool_name": "tavily_search", "config": {"api_key": "xxx"}}
|
||||
)
|
||||
await obj.set_tool_config("tavily_search", {"api_key": "xxx"})
|
||||
assert obj.get_tool_config("tavily_search") == {"api_key": "xxx"}
|
||||
|
||||
|
||||
def test_get_unknown_tool_config_returns_empty(gsm_instance):
|
||||
assert gsm_instance.get_tool_config("not_exist") == {}
|
||||
|
||||
|
||||
def test_get_tool_config_is_isolated_copy(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._tool_configs["tavily_search"] = {"api_key": "xxx"}
|
||||
snapshot = obj.get_tool_config("tavily_search")
|
||||
snapshot["api_key"] = "changed"
|
||||
assert obj.get_tool_config("tavily_search") == {"api_key": "xxx"}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_tool_config(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._tool_configs["tavily_search"] = {"api_key": "xxx"}
|
||||
obj.postgres_database.delete_tool_config_db.remote = AsyncMock(return_value=True)
|
||||
assert await obj.delete_tool_config("tavily_search") is True
|
||||
assert obj.get_tool_config("tavily_search") == {}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_unknown_tool_config(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj.postgres_database.delete_tool_config_db.remote = AsyncMock(return_value=False)
|
||||
assert await obj.delete_tool_config("not_exist") is False
|
||||
|
||||
|
||||
def test_list_tool_configs(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._tool_configs["tavily_search"] = {"api_key": "xxx"}
|
||||
obj._tool_configs["notion"] = {"token": "yyy"}
|
||||
raw = obj.list_tool_configs()
|
||||
assert raw["tavily_search"] == {"api_key": "xxx"}
|
||||
assert raw["notion"] == {"token": "yyy"}
|
||||
|
||||
|
||||
# ─── Custom Toolset ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_custom_toolset_success(gsm_instance):
|
||||
obj = gsm_instance
|
||||
saved = {"toolset_id": "t1", "name": "my-set", "tools": ["tp_a", "tp_b"]}
|
||||
obj.postgres_database.upsert_custom_toolset.remote = AsyncMock(return_value=saved)
|
||||
|
||||
result = await obj.add_custom_toolset(
|
||||
toolset_id="t1", name="my-set", tools=["tp_a", "tp_b"]
|
||||
)
|
||||
assert result == saved
|
||||
assert obj._custom_toolsets["t1"] == saved
|
||||
obj._global_tool_manager.rebuild_custom_toolsets.assert_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_custom_toolset_rejects_system_tools(gsm_instance):
|
||||
obj = gsm_instance
|
||||
with pytest.raises(ValueError, match="不合法"):
|
||||
await obj.add_custom_toolset(
|
||||
toolset_id="t2", name="bad", tools=["system_tool"]
|
||||
)
|
||||
|
||||
|
||||
def test_list_custom_toolsets(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._custom_toolsets["t1"] = {"toolset_id": "t1", "name": "a", "tools": []}
|
||||
assert len(obj.list_custom_toolsets()) == 1
|
||||
|
||||
|
||||
def test_get_custom_toolset(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._custom_toolsets["t1"] = {"toolset_id": "t1", "name": "a"}
|
||||
assert obj.get_custom_toolset("t1")["name"] == "a"
|
||||
assert obj.get_custom_toolset("nope") is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_custom_toolset(gsm_instance):
|
||||
obj = gsm_instance
|
||||
obj._custom_toolsets["t1"] = {"toolset_id": "t1"}
|
||||
obj.postgres_database.delete_custom_toolset.remote = AsyncMock(return_value=True)
|
||||
assert await obj.delete_custom_toolset("t1") is True
|
||||
assert "t1" not in obj._custom_toolsets
|
||||
obj._global_tool_manager.rebuild_custom_toolsets.assert_called()
|
||||
@@ -0,0 +1,359 @@
|
||||
"""GSM 配置快照(Object Store 读路径)相关测试。
|
||||
|
||||
主要验证:
|
||||
|
||||
- ``GSMSnapshot`` 数据类可被 cloudpickle 序列化(ray.put 的隐式约束)
|
||||
- ``_build_snapshot`` 正确从 6 类内存状态打包配置
|
||||
- ``_publish_snapshot`` 让 version 单调递增并刷新 ObjectRef
|
||||
- 写入路径(add_individual / set_tool_config / 等)会自动发布新快照
|
||||
- ``fetch_snapshot`` 客户端:版本号一致时走本地缓存,不一致时重拉
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import pickle
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
# cloudpickle 是 ray 的传递依赖,不直接列在 pyproject 里 —— 通过 ray._private 拿
|
||||
from ray import cloudpickle
|
||||
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import (
|
||||
GSMSnapshot,
|
||||
fetch_snapshot,
|
||||
reset_local_cache,
|
||||
)
|
||||
|
||||
|
||||
def test_empty_snapshot_can_cloudpickle_roundtrip():
|
||||
"""空 snapshot 序列化反序列化语义不变(ray.put 的最低约束)。"""
|
||||
snap = GSMSnapshot()
|
||||
blob = cloudpickle.dumps(snap)
|
||||
restored: GSMSnapshot = cloudpickle.loads(blob)
|
||||
assert restored.version == 0
|
||||
assert restored.providers == {}
|
||||
assert restored.individuals == {}
|
||||
|
||||
|
||||
def test_snapshot_with_real_data_roundtrip():
|
||||
"""带真实 Provider + 函数引用 + dict 数据的 snapshot 也能 round-trip。"""
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import (
|
||||
Provider,
|
||||
)
|
||||
|
||||
def _sample_tool(query: str) -> str:
|
||||
return f"echo:{query}"
|
||||
|
||||
snap = GSMSnapshot(
|
||||
version=42,
|
||||
providers={
|
||||
"p1": Provider(
|
||||
provider_title="p1",
|
||||
provider_url="http://x",
|
||||
provider_apikey="sk-x",
|
||||
provider_models=["gpt-4o"],
|
||||
provider_type="openai",
|
||||
),
|
||||
},
|
||||
individuals={"agent-a": {"agent_id": "agent-a", "model_id": "gpt-4o"}},
|
||||
tool_funcs={"echo": _sample_tool},
|
||||
)
|
||||
blob = cloudpickle.dumps(snap)
|
||||
restored: GSMSnapshot = cloudpickle.loads(blob)
|
||||
assert restored.version == 42
|
||||
assert restored.providers["p1"].provider_title == "p1"
|
||||
assert restored.individuals["agent-a"]["model_id"] == "gpt-4o"
|
||||
# 模块级函数 cloudpickle 后仍可调用
|
||||
# 注意:此处函数是测试模块的局部,cloudpickle 会把字节码一并序列化
|
||||
assert restored.tool_funcs["echo"]("hi") == "echo:hi"
|
||||
|
||||
|
||||
# ─── GSM actor 集成(绕过 @ray.remote 直接构造) ────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def gsm_instance(monkeypatch):
|
||||
from kilostar.core.global_state_machine.global_state_machine import (
|
||||
GlobalStateMachine,
|
||||
)
|
||||
|
||||
cls = GlobalStateMachine.__ray_actor_class__
|
||||
obj = cls.__new__(cls)
|
||||
# 手动还原 __init__ 副作用
|
||||
from kilostar.core.global_state_machine.individual_manager import (
|
||||
GlobalIndividualManager,
|
||||
)
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
|
||||
obj._global_provider_manager = ProviderManager(postgres=None)
|
||||
obj._global_tool_manager = GlobalToolManager()
|
||||
obj._global_skill_manager = GlobalSkillManager()
|
||||
obj._global_individual_manager = GlobalIndividualManager()
|
||||
obj._mcp_servers = {}
|
||||
obj._tool_configs = {}
|
||||
obj._custom_toolsets = {}
|
||||
obj._config_version = 0
|
||||
obj._current_ref = None
|
||||
obj.postgres_database = MagicMock()
|
||||
|
||||
# ray.put 在测试沙箱里因 psutil PID 检查失败,mock 成"返回一个 sentinel ref"
|
||||
# 我们关心的是 _publish_snapshot 的语义流,不是真把对象塞进 plasma
|
||||
import kilostar.core.global_state_machine.global_state_machine as gsm_mod
|
||||
|
||||
counter = {"n": 0}
|
||||
|
||||
def _fake_put(snapshot):
|
||||
counter["n"] += 1
|
||||
return f"fake-ref-{counter['n']}"
|
||||
|
||||
monkeypatch.setattr(gsm_mod.ray, "put", _fake_put)
|
||||
return obj
|
||||
|
||||
|
||||
def test_build_snapshot_picks_up_all_six_categories(gsm_instance):
|
||||
"""_build_snapshot 应正确从 GSM 内存的 6 类数据打包。"""
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import (
|
||||
Provider,
|
||||
)
|
||||
|
||||
gsm_instance._global_provider_manager.provider_register["p1"] = Provider(
|
||||
provider_title="p1",
|
||||
provider_url="http://x",
|
||||
provider_apikey="k",
|
||||
provider_models=[],
|
||||
provider_type="openai",
|
||||
)
|
||||
gsm_instance._global_individual_manager._individuals["a1"] = {"agent_id": "a1"}
|
||||
gsm_instance._mcp_servers["s1"] = {"server_id": "s1"}
|
||||
gsm_instance._tool_configs["t1"] = {"key": "v"}
|
||||
gsm_instance._custom_toolsets["ts1"] = {"toolset_id": "ts1"}
|
||||
|
||||
snap = gsm_instance._build_snapshot()
|
||||
|
||||
assert "p1" in snap.providers
|
||||
assert "a1" in snap.individuals
|
||||
assert "s1" in snap.mcp_servers
|
||||
assert "t1" in snap.tool_configs
|
||||
assert "ts1" in snap.custom_toolsets
|
||||
|
||||
|
||||
def test_build_snapshot_exposes_system_tools_by_scope(gsm_instance):
|
||||
"""系统工具按 scope 分桶的工具名清单要随快照发布出去(客户端重建 toolset 用)。"""
|
||||
tm = gsm_instance._global_tool_manager
|
||||
# 模拟 tool_manager 内部状态:default scope 有 file_reader,control_node 有 approval
|
||||
def _f1():
|
||||
return "f1"
|
||||
|
||||
def _f2():
|
||||
return "f2"
|
||||
|
||||
tm._tool_funcs.clear()
|
||||
tm._tool_funcs["default"]["file_reader"] = _f1
|
||||
tm._tool_funcs["control_node"]["approval"] = _f2
|
||||
|
||||
snap = gsm_instance._build_snapshot()
|
||||
assert snap.system_tools_by_scope.get("default") == ["file_reader"]
|
||||
assert snap.system_tools_by_scope.get("control_node") == ["approval"]
|
||||
# tool_funcs 拍平后两者都应存在
|
||||
assert set(snap.tool_funcs.keys()) == {"file_reader", "approval"}
|
||||
|
||||
|
||||
def test_publish_snapshot_increments_version(gsm_instance):
|
||||
assert gsm_instance._config_version == 0
|
||||
assert gsm_instance._current_ref is None
|
||||
|
||||
gsm_instance._publish_snapshot()
|
||||
v1 = gsm_instance._config_version
|
||||
ref1 = gsm_instance._current_ref
|
||||
assert v1 == 1
|
||||
assert ref1 is not None
|
||||
|
||||
gsm_instance._publish_snapshot()
|
||||
assert gsm_instance._config_version == 2
|
||||
assert gsm_instance._current_ref is not ref1 # 新 put 应是新 ref
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_config_ref_lazy_publishes_when_empty(gsm_instance):
|
||||
"""从未发布过快照时,current_config_ref 应自动发布一次而不是返回 None。"""
|
||||
version, ref = await gsm_instance.current_config_ref()
|
||||
assert version == 1
|
||||
assert ref is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_current_version_is_lightweight(gsm_instance):
|
||||
gsm_instance._publish_snapshot()
|
||||
gsm_instance._publish_snapshot()
|
||||
assert await gsm_instance.current_version() == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_individual_publishes_new_snapshot(gsm_instance):
|
||||
"""写入路径 add_individual 应自动 +1 version。"""
|
||||
before = gsm_instance._config_version
|
||||
gsm_instance.add_individual("agent-x", {"model_id": "gpt-4o"})
|
||||
after = gsm_instance._config_version
|
||||
assert after == before + 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_wrap_publishes_new_snapshot(gsm_instance):
|
||||
"""add_provider_wrap 即便走 mock 适配器也应该最终发布一次新快照。"""
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import (
|
||||
Provider,
|
||||
)
|
||||
|
||||
fake_provider = Provider(
|
||||
provider_title="my-openai",
|
||||
provider_url="http://x",
|
||||
provider_apikey="k",
|
||||
provider_models=[],
|
||||
provider_type="openai",
|
||||
)
|
||||
gsm_instance._global_provider_manager.provider_mapper["openai"] = MagicMock()
|
||||
gsm_instance._global_provider_manager.provider_mapper[
|
||||
"openai"
|
||||
].create_provider = AsyncMock(return_value=fake_provider)
|
||||
gsm_instance.postgres_database.add_provider_db = MagicMock()
|
||||
gsm_instance.postgres_database.add_provider_db.remote = AsyncMock()
|
||||
|
||||
before = gsm_instance._config_version
|
||||
await gsm_instance.add_provider_wrap(
|
||||
provider_type="openai",
|
||||
provider_title="my-openai",
|
||||
provider_url="http://x",
|
||||
provider_apikey="k",
|
||||
provider_owner="alice",
|
||||
)
|
||||
after = gsm_instance._config_version
|
||||
assert after == before + 1
|
||||
|
||||
|
||||
# ─── fetch_snapshot 客户端缓存 ────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_snapshot_uses_local_cache_when_version_matches():
|
||||
"""模拟 GSM actor,验证版本号一致时不走 ray.get。"""
|
||||
reset_local_cache()
|
||||
snap = GSMSnapshot(version=5, providers={"p": MagicMock()})
|
||||
|
||||
# mock GSM handle:第一次 fetch 全走,第二次只 current_version
|
||||
fake_gsm = MagicMock()
|
||||
fake_gsm.current_version = MagicMock()
|
||||
fake_gsm.current_version.remote = AsyncMock(return_value=5)
|
||||
fake_gsm.current_config_ref = MagicMock()
|
||||
|
||||
# 提前把缓存预热成 v5(模拟之前已经 fetch 过)
|
||||
from kilostar.core.global_state_machine import gsm_snapshot as snap_mod
|
||||
|
||||
snap_mod._local_cache["version"] = 5
|
||||
snap_mod._local_cache["snapshot"] = snap
|
||||
|
||||
# 不 mock current_config_ref —— 如果它被调用了,AttributeError 会让测试失败
|
||||
fake_gsm.current_config_ref.remote = AsyncMock(
|
||||
side_effect=AssertionError("不应触发:缓存版本一致时不应调 current_config_ref")
|
||||
)
|
||||
|
||||
result = await fetch_snapshot(gsm_actor=fake_gsm)
|
||||
assert result is snap
|
||||
fake_gsm.current_version.remote.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_snapshot_refetches_when_version_changes(monkeypatch):
|
||||
"""版本号变了应重新 ray.get 拉新 snapshot。"""
|
||||
reset_local_cache()
|
||||
new_snap = GSMSnapshot(version=10)
|
||||
|
||||
fake_gsm = MagicMock()
|
||||
fake_gsm.current_version = MagicMock()
|
||||
fake_gsm.current_version.remote = AsyncMock(return_value=10)
|
||||
fake_gsm.current_config_ref = MagicMock()
|
||||
fake_gsm.current_config_ref.remote = AsyncMock(return_value=(10, "fake-ref"))
|
||||
|
||||
# mock ray.get 让它直接返回我们准备的 snap
|
||||
import kilostar.core.global_state_machine.gsm_snapshot as snap_mod
|
||||
|
||||
monkeypatch.setattr(snap_mod.ray, "get", lambda ref: new_snap)
|
||||
|
||||
result = await fetch_snapshot(gsm_actor=fake_gsm)
|
||||
assert result is new_snap
|
||||
fake_gsm.current_config_ref.remote.assert_awaited_once()
|
||||
# 缓存应已更新到 v10
|
||||
assert snap_mod._local_cache["version"] == 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_fetch_snapshot_use_cache_false_skips_cache(monkeypatch):
|
||||
"""``use_cache=False`` 直接走 current_config_ref,不读本地缓存。"""
|
||||
reset_local_cache()
|
||||
fresh = GSMSnapshot(version=1)
|
||||
|
||||
fake_gsm = MagicMock()
|
||||
fake_gsm.current_config_ref = MagicMock()
|
||||
fake_gsm.current_config_ref.remote = AsyncMock(return_value=(1, "ref"))
|
||||
|
||||
import kilostar.core.global_state_machine.gsm_snapshot as snap_mod
|
||||
|
||||
monkeypatch.setattr(snap_mod.ray, "get", lambda ref: fresh)
|
||||
|
||||
result = await fetch_snapshot(gsm_actor=fake_gsm, use_cache=False)
|
||||
assert result is fresh
|
||||
|
||||
|
||||
# ─── build_toolsets_for_scope 客户端 helper ────────────────────────
|
||||
|
||||
|
||||
def test_build_toolsets_for_scope_assembles_system_and_custom():
|
||||
"""客户端按 snapshot 的 system_tools_by_scope + custom_toolsets 现场组装。"""
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import (
|
||||
build_toolsets_for_scope,
|
||||
)
|
||||
|
||||
def _sys_default():
|
||||
return "d"
|
||||
|
||||
def _sys_scope():
|
||||
return "s"
|
||||
|
||||
def _tp_a():
|
||||
return "a"
|
||||
|
||||
snap = GSMSnapshot(
|
||||
tool_funcs={"sys_default": _sys_default, "sys_scope": _sys_scope},
|
||||
third_party_funcs={"tp_a": _tp_a},
|
||||
system_tools_by_scope={
|
||||
"default": ["sys_default"],
|
||||
"control_node": ["sys_scope"],
|
||||
},
|
||||
custom_toolsets={
|
||||
"grp": {"toolset_id": "grp", "tools": ["tp_a"]},
|
||||
},
|
||||
)
|
||||
|
||||
result = build_toolsets_for_scope(snap, "control_node")
|
||||
# 应包含两个 system bucket + 一个 custom toolset
|
||||
assert len(result) == 3
|
||||
ids = [getattr(t, "id", None) for t in result]
|
||||
assert ids == ["system::default", "system::control_node", "custom::grp"]
|
||||
|
||||
|
||||
def test_build_toolsets_for_scope_skips_empty_buckets():
|
||||
"""没有工具的 scope 不应产出 toolset,避免空 FunctionToolset 噪声。"""
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import (
|
||||
build_toolsets_for_scope,
|
||||
)
|
||||
|
||||
snap = GSMSnapshot(
|
||||
tool_funcs={},
|
||||
system_tools_by_scope={"default": [], "control_node": []},
|
||||
custom_toolsets={},
|
||||
)
|
||||
assert build_toolsets_for_scope(snap, "control_node") == []
|
||||
@@ -0,0 +1,93 @@
|
||||
"""``GlobalToolManager`` 在真实仓库目录上的扫描结果。
|
||||
|
||||
期望(基于现有 plugin 目录):
|
||||
- approval / file_reader:is_system=True,进 system toolset
|
||||
- tavily_search:is_system=False,category=search,进 third_party
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def manager() -> GlobalToolManager:
|
||||
return GlobalToolManager()
|
||||
|
||||
|
||||
def test_metadata_contains_known_plugins(manager: GlobalToolManager):
|
||||
names = set(manager.tool_metadata.keys())
|
||||
assert {"approval", "file_reader", "tavily_search"} <= names
|
||||
|
||||
|
||||
def test_system_and_third_party_classification(manager: GlobalToolManager):
|
||||
sys_tools = {m["name"] for m in manager.get_system_tools()}
|
||||
tp_tools = {m["name"] for m in manager.get_third_party_tools()}
|
||||
assert {"approval", "file_reader"} <= sys_tools
|
||||
assert "tavily_search" in tp_tools
|
||||
assert sys_tools.isdisjoint(tp_tools)
|
||||
|
||||
|
||||
def test_get_tools_by_category_returns_correct_buckets(manager: GlobalToolManager):
|
||||
by_system = {m["name"] for m in manager.get_tools_by_category("system")}
|
||||
by_search = {m["name"] for m in manager.get_tools_by_category("search")}
|
||||
assert {"approval", "file_reader"} <= by_system
|
||||
assert "tavily_search" in by_search
|
||||
|
||||
|
||||
def test_is_third_party_tool(manager: GlobalToolManager):
|
||||
assert manager.is_third_party_tool("tavily_search") is True
|
||||
assert manager.is_third_party_tool("approval") is False
|
||||
assert manager.is_third_party_tool("nonexistent") is False
|
||||
|
||||
|
||||
def test_get_toolsets_for_scope_returns_function_toolsets(manager: GlobalToolManager):
|
||||
from pydantic_ai.toolsets import FunctionToolset
|
||||
|
||||
sets = manager.get_toolsets_for_scope("control_node")
|
||||
assert sets, "control_node 应该至少有 system toolset"
|
||||
for ts in sets:
|
||||
assert isinstance(ts, FunctionToolset)
|
||||
|
||||
|
||||
def test_system_toolsets_have_correct_ids(manager: GlobalToolManager):
|
||||
sets = manager.get_toolsets_for_scope("consciousness_node")
|
||||
ids = {getattr(ts, "id", None) for ts in sets}
|
||||
for tid in ids:
|
||||
assert tid.startswith("system::") or tid.startswith("custom::")
|
||||
|
||||
|
||||
def test_no_mcp_category_in_metadata_for_local_plugins(manager: GlobalToolManager):
|
||||
for m in manager.tool_metadata.values():
|
||||
if m["category"] == "mcp":
|
||||
assert m["name"] not in manager._third_party_funcs
|
||||
for scope_funcs in manager._tool_funcs.values():
|
||||
assert m["name"] not in scope_funcs
|
||||
|
||||
|
||||
def test_unknown_scope_returns_only_default_and_custom(manager: GlobalToolManager):
|
||||
sets = manager.get_toolsets_for_scope("not_a_real_scope")
|
||||
for ts in sets:
|
||||
tid = getattr(ts, "id", "")
|
||||
assert tid.endswith("::default") or tid.startswith("custom::")
|
||||
|
||||
|
||||
def test_rebuild_custom_toolsets(manager: GlobalToolManager):
|
||||
custom_defs = {
|
||||
"grp1": {"tools": ["tavily_search"], "name": "search-group"},
|
||||
}
|
||||
manager.rebuild_custom_toolsets(custom_defs)
|
||||
assert "grp1" in manager._custom_toolsets
|
||||
sets = manager.get_toolsets_for_scope("default")
|
||||
custom_ids = [getattr(ts, "id", "") for ts in sets if "custom::" in getattr(ts, "id", "")]
|
||||
assert any("grp1" in cid for cid in custom_ids)
|
||||
manager.rebuild_custom_toolsets({})
|
||||
assert manager._custom_toolsets == {}
|
||||
|
||||
|
||||
def test_get_personal_tools_compat(manager: GlobalToolManager):
|
||||
assert manager.get_personal_tools() == manager.get_third_party_tools()
|
||||
|
||||
|
||||
def test_get_non_system_tools_compat(manager: GlobalToolManager):
|
||||
assert manager.get_non_system_tools() == manager.get_third_party_tools()
|
||||
@@ -0,0 +1,210 @@
|
||||
"""Regulatory / Consciousness / Control 三个核心节点的 working/_run 分支逻辑测试。
|
||||
|
||||
绕过 ``@ray.remote`` 装饰,直接通过 ``__ray_actor_class__`` 取出原始类,
|
||||
mock 掉内部的 pydantic-ai Agent,验证节点对各类输入的分发与异常吞吐。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
# ─── RegulatoryNode ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def regulatory_instance():
|
||||
from kilostar.core.individual.regulatory_node.regulatory_node import (
|
||||
RegulatoryNode,
|
||||
)
|
||||
cls = RegulatoryNode.__ray_actor_class__
|
||||
obj = cls.__new__(cls)
|
||||
from kilostar.utils.logger import get_logger
|
||||
obj.logger = get_logger("regulatory_node")
|
||||
obj.agent = None
|
||||
return obj
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regulatory_run_returns_response_with_platform_filled(
|
||||
regulatory_instance,
|
||||
):
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
|
||||
fake_response = MessageResponse(
|
||||
platform=None, platform_id=None, reply_message="hi"
|
||||
)
|
||||
agent_run_result = SimpleNamespace(output=fake_response)
|
||||
regulatory_instance.agent = MagicMock()
|
||||
regulatory_instance.agent.run = AsyncMock(return_value=agent_run_result)
|
||||
|
||||
req = MessageRequest(
|
||||
platform="client",
|
||||
user_name="alice",
|
||||
platform_id="abc",
|
||||
message="hello",
|
||||
)
|
||||
out = await regulatory_instance.working(req)
|
||||
|
||||
assert out is fake_response
|
||||
assert out.platform == "client"
|
||||
assert out.platform_id == "abc"
|
||||
regulatory_instance.agent.run.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_regulatory_run_swallows_exception_returns_none(regulatory_instance):
|
||||
from kilostar.core.individual.regulatory_node.template import MessageRequest
|
||||
|
||||
regulatory_instance.agent = MagicMock()
|
||||
regulatory_instance.agent.run = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
req = MessageRequest(
|
||||
platform="onebot",
|
||||
user_name="bob",
|
||||
platform_id="x",
|
||||
message="hello",
|
||||
)
|
||||
out = await regulatory_instance.working(req)
|
||||
assert out is None
|
||||
|
||||
|
||||
# ─── ControlNode ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def control_instance():
|
||||
from kilostar.core.individual.control_node.control_node import ControlNode
|
||||
cls = ControlNode.__ray_actor_class__
|
||||
obj = cls.__new__(cls)
|
||||
from kilostar.utils.logger import get_logger
|
||||
obj.logger = get_logger("control_node")
|
||||
obj.agent = None
|
||||
return obj
|
||||
|
||||
|
||||
def _make_workflow_step():
|
||||
from kilostar.core.work.workflow.workflow import WorkflowStep
|
||||
|
||||
return WorkflowStep(
|
||||
step=1,
|
||||
name="do something",
|
||||
action="execute the thing",
|
||||
inputs=None,
|
||||
outputs="result",
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_working_returns_for_workflow_output(control_instance):
|
||||
from kilostar.core.individual.control_node.template import (
|
||||
ForWorkflow,
|
||||
ForWorkflowInput,
|
||||
)
|
||||
|
||||
step = _make_workflow_step()
|
||||
expected = ForWorkflow(output="done")
|
||||
agent_run_result = SimpleNamespace(output=expected)
|
||||
|
||||
control_instance.agent = MagicMock()
|
||||
control_instance.agent.run = AsyncMock(return_value=agent_run_result)
|
||||
|
||||
out = await control_instance.working(ForWorkflowInput(workflow_step=step))
|
||||
assert out is expected
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_control_working_swallows_exception_returns_none(control_instance):
|
||||
from kilostar.core.individual.control_node.template import ForWorkflowInput
|
||||
|
||||
step = _make_workflow_step()
|
||||
control_instance.agent = MagicMock()
|
||||
control_instance.agent.run = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
out = await control_instance.working(ForWorkflowInput(workflow_step=step))
|
||||
assert out is None
|
||||
|
||||
|
||||
# ─── ConsciousnessNode ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def consciousness_instance():
|
||||
from kilostar.core.individual.consciousness_node.consciousness_node import (
|
||||
ConsciousnessNode,
|
||||
)
|
||||
cls = ConsciousnessNode.__ray_actor_class__
|
||||
obj = cls.__new__(cls)
|
||||
from kilostar.utils.logger import get_logger
|
||||
obj.logger = get_logger("consciousness_node")
|
||||
obj.agent = None
|
||||
return obj
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consciousness_working_dispatches_workflow_engine_input(
|
||||
consciousness_instance,
|
||||
):
|
||||
from kilostar.core.individual.consciousness_node.template import (
|
||||
ForWorkflowEngine,
|
||||
ForWorkflowEngineInput,
|
||||
)
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
|
||||
from kilostar.core.work.workflow.model import WorkflowMetadata
|
||||
|
||||
workflow = KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
expected = ForWorkflowEngine(workflow=workflow, reasoning="r")
|
||||
agent_run_result = SimpleNamespace(output=expected)
|
||||
|
||||
consciousness_instance.agent = MagicMock()
|
||||
consciousness_instance.agent.run = AsyncMock(return_value=agent_run_result)
|
||||
|
||||
out = await consciousness_instance.working(
|
||||
ForWorkflowEngineInput(original_command="cmd", available_skills=[])
|
||||
)
|
||||
assert out is expected
|
||||
assert isinstance(out, ForWorkflowEngine)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consciousness_working_returns_none_on_unknown_output(
|
||||
consciousness_instance,
|
||||
):
|
||||
"""Agent 返回的不是三种已知 ForXxx 类型时,working 应返回 None。"""
|
||||
from kilostar.core.individual.consciousness_node.template import (
|
||||
ForWorkflowEngineInput,
|
||||
)
|
||||
|
||||
agent_run_result = SimpleNamespace(output="unexpected string")
|
||||
consciousness_instance.agent = MagicMock()
|
||||
consciousness_instance.agent.run = AsyncMock(return_value=agent_run_result)
|
||||
|
||||
out = await consciousness_instance.working(
|
||||
ForWorkflowEngineInput(original_command="cmd", available_skills=[])
|
||||
)
|
||||
assert out is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consciousness_working_swallows_exception(consciousness_instance):
|
||||
from kilostar.core.individual.consciousness_node.template import (
|
||||
ForWorkflowEngineInput,
|
||||
)
|
||||
|
||||
consciousness_instance.agent = MagicMock()
|
||||
consciousness_instance.agent.run = AsyncMock(side_effect=RuntimeError("boom"))
|
||||
|
||||
out = await consciousness_instance.working(
|
||||
ForWorkflowEngineInput(original_command="cmd", available_skills=[])
|
||||
)
|
||||
assert out is None
|
||||
@@ -0,0 +1,41 @@
|
||||
"""``plugin/tool_plugin`` 下各工具的元数据类正确性。
|
||||
|
||||
``BaseToolData`` 本身不带 ``name`` 字段;工具名以目录名为准(由 ``GlobalToolManager``
|
||||
扫描时注入到 ``tool_metadata`` 中)。这里只验证子类对 BaseToolData 字段的覆写。
|
||||
"""
|
||||
|
||||
from kilostar.plugin.tool_plugin.approval.approval import ApprovalToolData
|
||||
from kilostar.plugin.tool_plugin.file_reader import FileReaderToolData
|
||||
from kilostar.plugin.tool_plugin.tavily_search import TavilySearchToolData
|
||||
|
||||
|
||||
def test_approval_metadata():
|
||||
data = ApprovalToolData()
|
||||
assert data.is_system is True
|
||||
assert data.category == "system"
|
||||
assert "control_node" in data.action_scope
|
||||
assert "consciousness_node" in data.action_scope
|
||||
|
||||
|
||||
def test_file_reader_metadata():
|
||||
data = FileReaderToolData()
|
||||
assert data.is_system is True
|
||||
assert data.category == "system"
|
||||
assert "control_node" in data.action_scope
|
||||
|
||||
|
||||
def test_tavily_search_metadata():
|
||||
data = TavilySearchToolData()
|
||||
assert data.is_system is False
|
||||
assert data.category == "search"
|
||||
assert "control_node" in data.action_scope
|
||||
assert "consciousness_node" in data.action_scope
|
||||
# 默认配置 schema 含 api_key 字段(用于 GSM 配置面板)
|
||||
assert "api_key" in data.config_args
|
||||
|
||||
|
||||
def test_base_tool_extra_allowed():
|
||||
"""``ConfigDict(extra="allow")`` 允许子类外的 KV 也能装进来。"""
|
||||
data = ApprovalToolData(some_extension="ok") # type: ignore[call-arg]
|
||||
assert data.model_extra is not None
|
||||
assert data.model_extra.get("some_extension") == "ok"
|
||||
@@ -0,0 +1,110 @@
|
||||
"""``ProviderManager.add_provider`` happy path:
|
||||
- mock 掉具体 Provider 适配器的 ``create_provider``;
|
||||
- 验证内存注册表写入 + postgres_database.add_provider_db.remote 被正确调用。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_happy_path_writes_register_and_db():
|
||||
pm = ProviderManager(postgres=None)
|
||||
|
||||
fake_provider = Provider(
|
||||
provider_title="my-openai",
|
||||
provider_url="https://api.openai.com",
|
||||
provider_apikey="sk-xxx",
|
||||
provider_models=["gpt-4o"],
|
||||
provider_type="openai",
|
||||
provider_owner="alice",
|
||||
)
|
||||
|
||||
pm.provider_mapper["openai"] = MagicMock()
|
||||
pm.provider_mapper["openai"].create_provider = AsyncMock(return_value=fake_provider)
|
||||
|
||||
postgres = MagicMock()
|
||||
postgres.add_provider_db = MagicMock()
|
||||
postgres.add_provider_db.remote = AsyncMock(return_value=None)
|
||||
|
||||
await pm.add_provider(
|
||||
provider_type="openai",
|
||||
provider_title="my-openai",
|
||||
provider_url="https://api.openai.com",
|
||||
provider_apikey="sk-xxx",
|
||||
provider_owner="alice",
|
||||
postgres_database=postgres,
|
||||
)
|
||||
|
||||
assert "my-openai" in pm.provider_register
|
||||
assert pm.provider_register["my-openai"] is fake_provider
|
||||
postgres.add_provider_db.remote.assert_awaited_once()
|
||||
kwargs = postgres.add_provider_db.remote.await_args.kwargs
|
||||
assert kwargs["provider_title"] == "my-openai"
|
||||
assert kwargs["provider_apikey"] == "sk-xxx"
|
||||
assert kwargs["provider_models"] == ["gpt-4o"]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_unknown_type_returns_none(caplog):
|
||||
pm = ProviderManager(postgres=None)
|
||||
postgres = MagicMock()
|
||||
postgres.add_provider_db = MagicMock()
|
||||
postgres.add_provider_db.remote = AsyncMock()
|
||||
|
||||
result = await pm.add_provider(
|
||||
provider_type="not_supported",
|
||||
provider_title="x",
|
||||
provider_url="u",
|
||||
provider_apikey="k",
|
||||
provider_owner="o",
|
||||
postgres_database=postgres,
|
||||
)
|
||||
|
||||
assert result is None
|
||||
assert "x" not in pm.provider_register
|
||||
postgres.add_provider_db.remote.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_provider_network_error_raises_retryable():
|
||||
"""网络异常应被包装为 RetryableError 重抛。"""
|
||||
import httpx
|
||||
from kilostar.utils.error import RetryableError
|
||||
|
||||
pm = ProviderManager(postgres=None)
|
||||
pm.provider_mapper["openai"] = MagicMock()
|
||||
pm.provider_mapper["openai"].create_provider = AsyncMock(
|
||||
side_effect=httpx.ConnectError("network down")
|
||||
)
|
||||
|
||||
postgres = MagicMock()
|
||||
postgres.add_provider_db = MagicMock()
|
||||
postgres.add_provider_db.remote = AsyncMock()
|
||||
|
||||
with pytest.raises(RetryableError):
|
||||
await pm.add_provider(
|
||||
provider_type="openai",
|
||||
provider_title="x",
|
||||
provider_url="u",
|
||||
provider_apikey="k",
|
||||
provider_owner="o",
|
||||
postgres_database=postgres,
|
||||
)
|
||||
|
||||
|
||||
def test_get_provider_list_returns_internal_dict():
|
||||
pm = ProviderManager(postgres=None)
|
||||
pm.provider_register["a"] = "fake"
|
||||
assert pm.get_provider_list()["a"] == "fake"
|
||||
|
||||
|
||||
def test_get_provider_returns_none_when_missing():
|
||||
pm = ProviderManager(postgres=None)
|
||||
assert pm.get_provider("nope") is None
|
||||
@@ -0,0 +1,66 @@
|
||||
"""``regulatory_node/template.py`` 中的请求/响应/依赖模型校验。"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
RegulatoryNodeDeps,
|
||||
)
|
||||
|
||||
|
||||
def test_message_request_valid_for_client_platform():
|
||||
req = MessageRequest(
|
||||
platform="client",
|
||||
user_name="alice",
|
||||
platform_id=None,
|
||||
message="hi",
|
||||
)
|
||||
assert req.platform == "client"
|
||||
assert req.platform_id is None
|
||||
|
||||
|
||||
def test_message_request_valid_for_onebot_platform():
|
||||
req = MessageRequest(
|
||||
platform="onebot",
|
||||
user_name="alice",
|
||||
platform_id="group:1",
|
||||
message="hi",
|
||||
)
|
||||
assert req.platform == "onebot"
|
||||
|
||||
|
||||
def test_message_request_rejects_unknown_platform():
|
||||
with pytest.raises(ValidationError):
|
||||
MessageRequest(
|
||||
platform="qq", # type: ignore[arg-type]
|
||||
user_name="alice",
|
||||
platform_id="x",
|
||||
message="hi",
|
||||
)
|
||||
|
||||
|
||||
def test_message_response_requires_reply_message():
|
||||
with pytest.raises(ValidationError):
|
||||
MessageResponse(
|
||||
platform="client",
|
||||
platform_id="x",
|
||||
) # type: ignore[call-arg]
|
||||
|
||||
|
||||
def test_message_response_full_round_trip():
|
||||
res = MessageResponse(
|
||||
platform="client",
|
||||
platform_id="anonymous",
|
||||
reply_message="hello",
|
||||
)
|
||||
dumped = res.model_dump()
|
||||
assert dumped["reply_message"] == "hello"
|
||||
assert dumped["platform"] == "client"
|
||||
|
||||
|
||||
def test_regulatory_node_deps_defaults():
|
||||
deps = RegulatoryNodeDeps(platform="client", user_name="alice", time="now")
|
||||
assert deps.retry_count == 0
|
||||
assert deps.error_history == ""
|
||||
@@ -0,0 +1,105 @@
|
||||
"""``kilostar.utils.request_context``:双层 ID 的 contextvars 与 logger 集成。
|
||||
|
||||
覆盖:
|
||||
|
||||
- ``request_id`` / ``trace_id`` 默认空、bind 后可读、reset 后还原
|
||||
- ``request_id_scope`` / ``trace_id_scope`` 上下文管理器
|
||||
- ``snapshot`` / ``apply_snapshot`` 跨边界透传
|
||||
- logger 切面:``contextvars`` 中的值会自动写入 ``record["extra"]``
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.utils import request_context as rc
|
||||
|
||||
|
||||
def test_default_values_are_empty():
|
||||
assert rc.get_request_id() == ""
|
||||
assert rc.get_trace_id() == ""
|
||||
|
||||
|
||||
def test_bind_and_reset_request_id():
|
||||
token = rc.bind_request_id("req-abc")
|
||||
try:
|
||||
assert rc.get_request_id() == "req-abc"
|
||||
finally:
|
||||
rc.reset_request_id(token)
|
||||
assert rc.get_request_id() == ""
|
||||
|
||||
|
||||
def test_bind_and_reset_trace_id():
|
||||
token = rc.bind_trace_id("trace-xyz")
|
||||
try:
|
||||
assert rc.get_trace_id() == "trace-xyz"
|
||||
finally:
|
||||
rc.reset_trace_id(token)
|
||||
assert rc.get_trace_id() == ""
|
||||
|
||||
|
||||
def test_request_id_scope():
|
||||
with rc.request_id_scope("req-1") as rid:
|
||||
assert rid == "req-1"
|
||||
assert rc.get_request_id() == "req-1"
|
||||
assert rc.get_request_id() == ""
|
||||
|
||||
|
||||
def test_trace_id_scope_nested():
|
||||
with rc.trace_id_scope("outer"):
|
||||
assert rc.get_trace_id() == "outer"
|
||||
with rc.trace_id_scope("inner"):
|
||||
assert rc.get_trace_id() == "inner"
|
||||
assert rc.get_trace_id() == "outer"
|
||||
assert rc.get_trace_id() == ""
|
||||
|
||||
|
||||
def test_snapshot_returns_current_ids():
|
||||
with rc.request_id_scope("r1"), rc.trace_id_scope("t1"):
|
||||
snap = rc.snapshot()
|
||||
assert snap == {"request_id": "r1", "trace_id": "t1"}
|
||||
|
||||
|
||||
def test_apply_snapshot_restores_after_exit():
|
||||
snap = {"request_id": "r2", "trace_id": "t2"}
|
||||
with rc.apply_snapshot(snap):
|
||||
assert rc.get_request_id() == "r2"
|
||||
assert rc.get_trace_id() == "t2"
|
||||
assert rc.get_request_id() == ""
|
||||
assert rc.get_trace_id() == ""
|
||||
|
||||
|
||||
def test_apply_snapshot_handles_none():
|
||||
"""传 None 应是 no-op,不报错。"""
|
||||
with rc.apply_snapshot(None):
|
||||
assert rc.get_request_id() == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_contextvars_isolated_between_concurrent_tasks():
|
||||
"""两个并发的 asyncio task 各自的 trace_id 不应互相串味。"""
|
||||
results: dict[str, str] = {}
|
||||
|
||||
async def worker(name: str, trace_id: str) -> None:
|
||||
with rc.trace_id_scope(trace_id):
|
||||
await asyncio.sleep(0)
|
||||
results[name] = rc.get_trace_id()
|
||||
|
||||
await asyncio.gather(
|
||||
worker("a", "trace-a"),
|
||||
worker("b", "trace-b"),
|
||||
)
|
||||
assert results == {"a": "trace-a", "b": "trace-b"}
|
||||
|
||||
|
||||
def test_new_request_id_format():
|
||||
rid = rc.new_request_id()
|
||||
assert rid.startswith("req-")
|
||||
assert len(rid) > len("req-")
|
||||
|
||||
|
||||
def test_new_request_id_custom_prefix():
|
||||
rid = rc.new_request_id("ws")
|
||||
assert rid.startswith("ws-")
|
||||
@@ -0,0 +1,107 @@
|
||||
"""``request_id_middleware``:请求级 ID 入口生成/继承 + 响应头回写。
|
||||
|
||||
覆盖:
|
||||
|
||||
- 没传 X-Request-Id 时 middleware 生成新 ID 并写回响应头
|
||||
- 传了 X-Request-Id 时被尊重并回写
|
||||
- 路由处理器内可以从 contextvars 读到当前 request_id
|
||||
- 异常路径下 contextvars 也能被正确 reset(不会泄漏到下一请求)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI, Request
|
||||
from httpx import AsyncClient, ASGITransport
|
||||
|
||||
from kilostar.utils import request_context as rc
|
||||
from kilostar.utils.request_context import (
|
||||
bind_request_id,
|
||||
new_request_id,
|
||||
reset_request_id,
|
||||
)
|
||||
|
||||
|
||||
def _build_app() -> FastAPI:
|
||||
app = FastAPI()
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_id_middleware(request: Request, call_next):
|
||||
incoming = request.headers.get("X-Request-Id", "").strip()
|
||||
request_id = incoming or new_request_id()
|
||||
token = bind_request_id(request_id)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
finally:
|
||||
reset_request_id(token)
|
||||
response.headers["X-Request-Id"] = request_id
|
||||
return response
|
||||
|
||||
@app.get("/whoami")
|
||||
async def whoami():
|
||||
return {"request_id": rc.get_request_id()}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_generates_request_id_when_header_absent():
|
||||
transport = ASGITransport(app=_build_app(), raise_app_exceptions=False)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get("/whoami")
|
||||
|
||||
assert resp.status_code == 200
|
||||
rid = resp.headers.get("X-Request-Id")
|
||||
assert rid and rid.startswith("req-")
|
||||
assert resp.json()["request_id"] == rid
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_inherits_request_id_from_header():
|
||||
transport = ASGITransport(app=_build_app(), raise_app_exceptions=False)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
resp = await client.get(
|
||||
"/whoami", headers={"X-Request-Id": "client-supplied-123"}
|
||||
)
|
||||
assert resp.headers.get("X-Request-Id") == "client-supplied-123"
|
||||
assert resp.json()["request_id"] == "client-supplied-123"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_request_id_reset_after_request():
|
||||
"""两次请求的 request_id 不应互相串味(contextvars 在 finally 里被 reset)。"""
|
||||
transport = ASGITransport(app=_build_app(), raise_app_exceptions=False)
|
||||
async with AsyncClient(transport=transport, base_url="http://test") as client:
|
||||
r1 = await client.get("/whoami")
|
||||
r2 = await client.get("/whoami")
|
||||
assert r1.headers["X-Request-Id"] != r2.headers["X-Request-Id"]
|
||||
|
||||
|
||||
def test_logger_picks_up_contextvars():
|
||||
"""logger 切面:contextvars 中的值会被 patcher 注入到 record.extra。"""
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
captured: list[dict] = []
|
||||
|
||||
def _format(record):
|
||||
# loguru 的 format 函数收到 record dict 本身,可以直接读 extra
|
||||
captured.append(dict(record["extra"]))
|
||||
return "{message}\n"
|
||||
|
||||
from loguru import logger as _global
|
||||
|
||||
handler_id = _global.add(lambda _msg: None, format=_format, level="DEBUG")
|
||||
try:
|
||||
with rc.trace_id_scope("trace-from-ctx"), rc.request_id_scope("req-from-ctx"):
|
||||
log = get_logger("test_actor")
|
||||
log.info("hello")
|
||||
finally:
|
||||
_global.remove(handler_id)
|
||||
|
||||
assert captured, "应至少捕获一条日志"
|
||||
# 找到我们的那条(避免被并发中其他 logger 干扰)
|
||||
matched = [c for c in captured if c.get("actor_name") == "test_actor"]
|
||||
assert matched, f"未捕获到来自 test_actor 的日志,全部 captured={captured}"
|
||||
last = matched[-1]
|
||||
assert last.get("trace_id") == "trace-from-ctx"
|
||||
assert last.get("request_id") == "req-from-ctx"
|
||||
@@ -0,0 +1,88 @@
|
||||
"""kilostar.utils.error 类型层级与状态码测试。
|
||||
|
||||
异常体系两条主轴:
|
||||
- ``BusinessError``:4xx;
|
||||
- ``InfraError``:5xx,下分 ``RetryableError`` / ``NonRetryableError``。
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.utils import error
|
||||
|
||||
|
||||
def test_two_main_branches_under_kilostar_error():
|
||||
assert issubclass(error.BusinessError, error.KiloStarError)
|
||||
assert issubclass(error.InfraError, error.KiloStarError)
|
||||
assert not issubclass(error.BusinessError, error.InfraError)
|
||||
assert not issubclass(error.InfraError, error.BusinessError)
|
||||
|
||||
|
||||
def test_retryable_and_nonretryable_under_infra():
|
||||
assert issubclass(error.RetryableError, error.InfraError)
|
||||
assert issubclass(error.NonRetryableError, error.InfraError)
|
||||
assert not issubclass(error.RetryableError, error.NonRetryableError)
|
||||
assert not issubclass(error.NonRetryableError, error.RetryableError)
|
||||
|
||||
|
||||
def test_demand_error_is_business_4xx():
|
||||
assert issubclass(error.DemandError, error.BusinessError)
|
||||
assert error.DemandError.http_status == 400
|
||||
|
||||
|
||||
def test_workflow_exit_is_business_not_workflow_error():
|
||||
"""WorkflowExit 在新体系里是预期退出(4xx),与 WorkflowError(5xx) 完全分家。"""
|
||||
assert issubclass(error.WorkflowExit, error.BusinessError)
|
||||
assert not issubclass(error.WorkflowExit, error.WorkflowError)
|
||||
assert error.WorkflowExit.http_status == 400
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"child,parent",
|
||||
[
|
||||
(error.UserNotExistError, error.UserError),
|
||||
(error.UserPasswordError, error.UserError),
|
||||
(error.ProviderNotExistError, error.ProviderError),
|
||||
(error.UserError, error.BusinessError),
|
||||
(error.ProviderError, error.BusinessError),
|
||||
(error.ModelNotExistError, error.BusinessError),
|
||||
(error.WorkflowError, error.InfraError),
|
||||
],
|
||||
)
|
||||
def test_subclass_hierarchy(child, parent):
|
||||
assert issubclass(child, parent)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"exc_cls,expected_status",
|
||||
[
|
||||
(error.UserNotExistError, 404),
|
||||
(error.UserPasswordError, 401),
|
||||
(error.ProviderNotExistError, 404),
|
||||
(error.ModelNotExistError, 404),
|
||||
(error.WorkflowExit, 400),
|
||||
(error.RetryableError, 503),
|
||||
(error.NonRetryableError, 500),
|
||||
(error.WorkflowError, 500),
|
||||
],
|
||||
)
|
||||
def test_http_status_mapping(exc_cls, expected_status):
|
||||
assert exc_cls.http_status == expected_status
|
||||
|
||||
|
||||
def test_each_error_has_code_attr():
|
||||
"""所有自定义异常都应带可读的 ``code``,用于 API 响应中 ``{"code": ...}``。"""
|
||||
for cls in [
|
||||
error.KiloStarError,
|
||||
error.BusinessError,
|
||||
error.InfraError,
|
||||
error.UserNotExistError,
|
||||
error.RetryableError,
|
||||
error.WorkflowError,
|
||||
]:
|
||||
assert isinstance(cls.code, str) and cls.code
|
||||
|
||||
|
||||
def test_errors_can_be_raised_and_caught():
|
||||
with pytest.raises(error.UserNotExistError) as exc:
|
||||
raise error.UserNotExistError("missing")
|
||||
assert "missing" in str(exc.value)
|
||||
@@ -0,0 +1,44 @@
|
||||
"""``utils.get_tool`` 在真实仓库目录上的加载行为。"""
|
||||
|
||||
from kilostar.utils import get_tool
|
||||
from kilostar.utils.get_tool import (
|
||||
_get_tool_func,
|
||||
del_tool_cache,
|
||||
load_tools_from_list,
|
||||
)
|
||||
|
||||
|
||||
def setup_function(_func):
|
||||
"""每个测试前清空模块级缓存,避免相互影响。"""
|
||||
get_tool._tool_cache.clear()
|
||||
|
||||
|
||||
def test_load_existing_tool_via_load_tools_from_list():
|
||||
tools = load_tools_from_list(["file_reader"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].__name__ == "file_reader"
|
||||
|
||||
|
||||
def test_loader_caches_function():
|
||||
f1 = _get_tool_func("file_reader")
|
||||
f2 = _get_tool_func("file_reader")
|
||||
assert f1 is f2
|
||||
assert "file_reader" in get_tool._tool_cache
|
||||
|
||||
|
||||
def test_del_tool_cache_removes_entry():
|
||||
_get_tool_func("file_reader")
|
||||
assert "file_reader" in get_tool._tool_cache
|
||||
del_tool_cache("file_reader")
|
||||
assert "file_reader" not in get_tool._tool_cache
|
||||
|
||||
|
||||
def test_load_unknown_tool_returns_none_and_keeps_others():
|
||||
tools = load_tools_from_list(["file_reader", "definitely_not_exist"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].__name__ == "file_reader"
|
||||
|
||||
|
||||
def test_load_tools_from_list_handles_none_and_empty():
|
||||
assert load_tools_from_list(None) == []
|
||||
assert load_tools_from_list([]) == []
|
||||
@@ -0,0 +1,244 @@
|
||||
"""``utils.mcp_helper`` 的纯逻辑路径:构建 toolset、合并、列出工具。
|
||||
|
||||
由于 stdio MCP 实例化会立即解析 command/args,但不会主动启动子进程(只在 ``async with``
|
||||
进入后才连),因此 ``build_mcp_toolsets`` 可以在测试里直接验证。
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.utils import mcp_helper
|
||||
|
||||
|
||||
# ─── build_mcp_toolsets ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_build_mcp_toolsets_when_mcp_unavailable_returns_empty(monkeypatch):
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", False)
|
||||
assert mcp_helper.build_mcp_toolsets({"a": {"transport": "stdio"}}) == []
|
||||
|
||||
|
||||
def test_build_mcp_toolsets_skips_unsupported_transport():
|
||||
configs = {"weird": {"transport": "ftp", "name": "weird"}}
|
||||
assert mcp_helper.build_mcp_toolsets(configs) == []
|
||||
|
||||
|
||||
def test_build_mcp_toolsets_creates_one_per_supported_config():
|
||||
configs = {
|
||||
"fs": {
|
||||
"name": "fs",
|
||||
"transport": "stdio",
|
||||
"command": "echo",
|
||||
"args": ["hi"],
|
||||
"tool_prefix": "fs",
|
||||
},
|
||||
"remote_sse": {
|
||||
"name": "remote_sse",
|
||||
"transport": "sse",
|
||||
"url": "http://localhost:9000/sse",
|
||||
},
|
||||
"remote_http": {
|
||||
"name": "remote_http",
|
||||
"transport": "http",
|
||||
"url": "http://localhost:9001",
|
||||
},
|
||||
}
|
||||
toolsets = mcp_helper.build_mcp_toolsets(configs)
|
||||
assert len(toolsets) == 3
|
||||
ids = {getattr(t, "id", None) for t in toolsets}
|
||||
assert ids == {"fs", "remote_sse", "remote_http"}
|
||||
|
||||
|
||||
# ─── get_mcp_toolsets_from_gsm ───────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_get_mcp_toolsets_from_gsm_returns_empty_when_unavailable(
|
||||
monkeypatch,
|
||||
):
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", False)
|
||||
assert await mcp_helper.get_mcp_toolsets_from_gsm() == []
|
||||
|
||||
|
||||
async def test_get_mcp_toolsets_from_gsm_swallows_errors_with_no_actor(monkeypatch):
|
||||
"""没有注册 actor 时返回空列表而非抛出。"""
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
async def boom(*_a, **_kw):
|
||||
raise ValueError("no actor")
|
||||
|
||||
from kilostar.core.global_state_machine import gsm_snapshot as snap_mod
|
||||
|
||||
monkeypatch.setattr(snap_mod, "fetch_snapshot", boom)
|
||||
assert await mcp_helper.get_mcp_toolsets_from_gsm() == []
|
||||
|
||||
|
||||
async def test_get_mcp_toolsets_from_gsm_uses_configs_via_snapshot(monkeypatch):
|
||||
"""注入一个返回 mcp 配置的 snapshot,验证最终走通到 build_mcp_toolsets。"""
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
from kilostar.core.global_state_machine import gsm_snapshot as snap_mod
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
|
||||
|
||||
snap = GSMSnapshot(
|
||||
version=1,
|
||||
mcp_servers={
|
||||
"fs": {
|
||||
"name": "fs",
|
||||
"transport": "stdio",
|
||||
"command": "echo",
|
||||
"args": ["hi"],
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
async def _fake_fetch(**_kw):
|
||||
return snap
|
||||
|
||||
monkeypatch.setattr(snap_mod, "fetch_snapshot", _fake_fetch)
|
||||
|
||||
toolsets = await mcp_helper.get_mcp_toolsets_from_gsm()
|
||||
assert len(toolsets) == 1
|
||||
assert getattr(toolsets[0], "id", None) == "fs"
|
||||
|
||||
|
||||
# ─── get_all_toolsets_for_scope ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_get_all_toolsets_for_scope_merges_local_and_mcp(monkeypatch):
|
||||
"""本地 toolset 列表和 mcp toolset 列表都应该被拼接。"""
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
local_a = MagicMock(name="local-system")
|
||||
local_b = MagicMock(name="local-personal")
|
||||
|
||||
from kilostar.core.global_state_machine import gsm_snapshot as snap_mod
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
|
||||
|
||||
snap = GSMSnapshot(version=1, mcp_servers={})
|
||||
|
||||
async def _fake_fetch(**_kw):
|
||||
return snap
|
||||
|
||||
monkeypatch.setattr(snap_mod, "fetch_snapshot", _fake_fetch)
|
||||
monkeypatch.setattr(
|
||||
snap_mod,
|
||||
"build_toolsets_for_scope",
|
||||
lambda s, scope: [local_a, local_b],
|
||||
)
|
||||
|
||||
result = await mcp_helper.get_all_toolsets_for_scope("control_node")
|
||||
assert result == [local_a, local_b]
|
||||
|
||||
|
||||
async def test_get_all_toolsets_for_scope_local_failure_does_not_block_mcp(
|
||||
monkeypatch,
|
||||
):
|
||||
"""本地 toolset 拉取失败时仍然要返回 mcp toolset。"""
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
from kilostar.core.global_state_machine import gsm_snapshot as snap_mod
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
|
||||
|
||||
# local 路径:fetch 成功但 build_toolsets_for_scope 抛错
|
||||
snap = GSMSnapshot(
|
||||
version=1,
|
||||
mcp_servers={
|
||||
"fs": {
|
||||
"name": "fs",
|
||||
"transport": "stdio",
|
||||
"command": "echo",
|
||||
"args": ["hi"],
|
||||
}
|
||||
},
|
||||
)
|
||||
|
||||
async def _fake_fetch(**_kw):
|
||||
return snap
|
||||
|
||||
def _broken_build(*_a, **_kw):
|
||||
raise RuntimeError("boom")
|
||||
|
||||
monkeypatch.setattr(snap_mod, "fetch_snapshot", _fake_fetch)
|
||||
monkeypatch.setattr(snap_mod, "build_toolsets_for_scope", _broken_build)
|
||||
|
||||
result = await mcp_helper.get_all_toolsets_for_scope("control_node")
|
||||
assert len(result) == 1
|
||||
assert getattr(result[0], "id", None) == "fs"
|
||||
|
||||
|
||||
# ─── list_mcp_tools_for_configs ──────────────────────────────────────────────
|
||||
|
||||
|
||||
async def test_list_mcp_tools_for_configs_returns_tool_names(monkeypatch):
|
||||
"""模拟 server.get_tools() 返回若干带 name 的工具,结果应被抽取出来。"""
|
||||
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
class _FakeTool:
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name = name
|
||||
|
||||
class _FakeServer:
|
||||
def __init__(self, server_id: str, tools):
|
||||
self.id = server_id
|
||||
self._tools = tools
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get_tools(self):
|
||||
return self._tools
|
||||
|
||||
fake_servers = [_FakeServer("fs", [_FakeTool("read"), _FakeTool("write")])]
|
||||
monkeypatch.setattr(
|
||||
mcp_helper, "build_mcp_toolsets", lambda configs: list(fake_servers)
|
||||
)
|
||||
|
||||
result = await mcp_helper.list_mcp_tools_for_configs(
|
||||
{"fs": {"name": "fs", "transport": "stdio"}}
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0]["server_id"] == "fs"
|
||||
assert result[0]["tools"] == ["read", "write"]
|
||||
assert "error" not in result[0]
|
||||
|
||||
|
||||
async def test_list_mcp_tools_for_configs_records_error_on_failure(monkeypatch):
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", True)
|
||||
|
||||
class _BrokenServer:
|
||||
def __init__(self) -> None:
|
||||
self.id = "broken"
|
||||
|
||||
async def __aenter__(self):
|
||||
raise RuntimeError("connect fail")
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
async def get_tools(self): # pragma: no cover - 不会被调用
|
||||
return []
|
||||
|
||||
monkeypatch.setattr(
|
||||
mcp_helper, "build_mcp_toolsets", lambda configs: [_BrokenServer()]
|
||||
)
|
||||
result = await mcp_helper.list_mcp_tools_for_configs(
|
||||
{"broken": {"name": "broken", "transport": "stdio"}}
|
||||
)
|
||||
assert len(result) == 1
|
||||
assert result[0]["server_id"] == "broken"
|
||||
assert result[0]["tools"] == []
|
||||
assert "connect fail" in result[0]["error"]
|
||||
|
||||
|
||||
async def test_list_mcp_tools_for_configs_when_mcp_unavailable(monkeypatch):
|
||||
monkeypatch.setattr(mcp_helper, "_MCP_AVAILABLE", False)
|
||||
assert await mcp_helper.list_mcp_tools_for_configs({"x": {}}) == []
|
||||
@@ -0,0 +1,46 @@
|
||||
"""``ray_hook`` 中纯逻辑容器与 actor 句柄缓存的行为。"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.utils.ray_hook import ActorList
|
||||
|
||||
|
||||
def test_actor_list_attribute_set_get_delete():
|
||||
actors = ActorList()
|
||||
actors.foo = "bar"
|
||||
assert actors.foo == "bar"
|
||||
del actors.foo
|
||||
with pytest.raises(AttributeError):
|
||||
_ = actors.foo
|
||||
|
||||
|
||||
def test_actor_list_missing_raises_attribute_error():
|
||||
actors = ActorList()
|
||||
with pytest.raises(AttributeError):
|
||||
_ = actors.not_exist
|
||||
|
||||
|
||||
def test_actor_list_delete_missing_raises_attribute_error():
|
||||
actors = ActorList()
|
||||
with pytest.raises(AttributeError):
|
||||
del actors.not_exist
|
||||
|
||||
|
||||
def test_ray_actor_hook_uses_fake_registry(fake_actors):
|
||||
"""``ray_actor_hook`` 通过 fake registry 取 actor 并组装成 ActorList。"""
|
||||
handle = MagicMock()
|
||||
fake_actors.register("postgres_database", handle)
|
||||
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
actors = ray_actor_hook("postgres_database")
|
||||
assert actors.postgres_database is handle
|
||||
|
||||
|
||||
def test_ray_actor_hook_unknown_actor_raises(fake_actors):
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
ray_actor_hook("does_not_exist")
|
||||
@@ -0,0 +1,87 @@
|
||||
"""``retry_on_retryable_error`` 装饰器:覆盖同步/异步、重试次数、退避语义。"""
|
||||
|
||||
import asyncio
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.utils.error import NonRetryableError, RetryableError
|
||||
from kilostar.utils.retry import retry_on_retryable_error
|
||||
|
||||
|
||||
def test_sync_retries_until_success(monkeypatch):
|
||||
sleep_calls: list[float] = []
|
||||
monkeypatch.setattr("time.sleep", lambda s: sleep_calls.append(s))
|
||||
|
||||
counter = {"n": 0}
|
||||
|
||||
@retry_on_retryable_error(max_retries=3, base_delay=1)
|
||||
def flaky() -> str:
|
||||
counter["n"] += 1
|
||||
if counter["n"] < 3:
|
||||
raise RetryableError("temp")
|
||||
return "ok"
|
||||
|
||||
assert flaky() == "ok"
|
||||
assert counter["n"] == 3
|
||||
# 第 1、2 次失败前各 sleep 一次:base * 2**0 = 1, base * 2**1 = 2
|
||||
assert sleep_calls == [1, 2]
|
||||
|
||||
|
||||
def test_sync_reraises_after_exhaustion(monkeypatch):
|
||||
monkeypatch.setattr("time.sleep", lambda _s: None)
|
||||
|
||||
@retry_on_retryable_error(max_retries=2, base_delay=1)
|
||||
def always_fail() -> None:
|
||||
raise RetryableError("nope")
|
||||
|
||||
with pytest.raises(RetryableError):
|
||||
always_fail()
|
||||
|
||||
|
||||
def test_sync_does_not_retry_on_non_retryable(monkeypatch):
|
||||
sleep_calls: list[float] = []
|
||||
monkeypatch.setattr("time.sleep", lambda s: sleep_calls.append(s))
|
||||
|
||||
@retry_on_retryable_error(max_retries=5, base_delay=1)
|
||||
def boom() -> None:
|
||||
raise NonRetryableError("hard")
|
||||
|
||||
with pytest.raises(NonRetryableError):
|
||||
boom()
|
||||
assert sleep_calls == []
|
||||
|
||||
|
||||
async def test_async_retries_until_success(monkeypatch):
|
||||
sleep_calls: list[float] = []
|
||||
|
||||
async def fake_sleep(s: float) -> None:
|
||||
sleep_calls.append(s)
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||
|
||||
counter = {"n": 0}
|
||||
|
||||
@retry_on_retryable_error(max_retries=4, base_delay=2)
|
||||
async def flaky() -> int:
|
||||
counter["n"] += 1
|
||||
if counter["n"] < 2:
|
||||
raise RetryableError("temp")
|
||||
return 42
|
||||
|
||||
assert await flaky() == 42
|
||||
assert counter["n"] == 2
|
||||
assert sleep_calls == [2] # base * 2**0
|
||||
|
||||
|
||||
async def test_async_reraises_after_exhaustion(monkeypatch):
|
||||
async def fake_sleep(_s: float) -> None:
|
||||
return None
|
||||
|
||||
monkeypatch.setattr(asyncio, "sleep", fake_sleep)
|
||||
|
||||
@retry_on_retryable_error(max_retries=2, base_delay=1)
|
||||
async def boom() -> None:
|
||||
raise RetryableError("nope")
|
||||
|
||||
with pytest.raises(RetryableError):
|
||||
await boom()
|
||||
@@ -0,0 +1,149 @@
|
||||
"""``ConsciousnessNode.start_workflow_design`` fire workflow ray task 的提交逻辑。
|
||||
|
||||
历史上这里有一个常驻的 ``WorkflowRunningEngine`` actor 做中转,现已删除:
|
||||
workflow 是一次性、有头有尾的执行,更适合直接以 ray task 形式触发。
|
||||
本测试保证 ConsciousnessNode 在工作流生成后正确 fire ``run_workflow_task``,
|
||||
并通过 ``put_pending`` 推送 SSE 进度(节点端写 pending → API 端 SSE 读 pending)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.work.workflow import workflow_engine as engine_module
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
|
||||
from kilostar.core.work.workflow.model import WorkflowMetadata
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def consciousness_instance():
|
||||
from kilostar.core.individual.consciousness_node.consciousness_node import (
|
||||
ConsciousnessNode,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
cls = ConsciousnessNode.__ray_actor_class__
|
||||
obj = cls.__new__(cls)
|
||||
obj.logger = get_logger("consciousness_node")
|
||||
obj.agent = None
|
||||
return obj
|
||||
|
||||
|
||||
class _FakeActorRef:
|
||||
"""模拟 ``ray_actor_hook("name").<name>`` 的链式取属性返回值。"""
|
||||
|
||||
def __init__(self, target):
|
||||
self._target = target
|
||||
|
||||
def __getattr__(self, item):
|
||||
return getattr(self._target, item)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_workflow_design_fires_run_workflow_task(
|
||||
consciousness_instance, monkeypatch
|
||||
):
|
||||
"""快乐路径:working 返回 ForWorkflowEngine,应 fire run_workflow_task 且推送 pending。"""
|
||||
from kilostar.core.individual.consciousness_node.template import (
|
||||
ForWorkflowEngine,
|
||||
)
|
||||
|
||||
wf = KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
|
||||
consciousness_instance.working = AsyncMock(
|
||||
return_value=ForWorkflowEngine(workflow=wf, reasoning="r")
|
||||
)
|
||||
|
||||
postgres = MagicMock()
|
||||
postgres.get_all_worker_individual = MagicMock()
|
||||
postgres.get_all_worker_individual.remote = AsyncMock(return_value=[])
|
||||
postgres.update_workflow_status = MagicMock()
|
||||
postgres.update_workflow_status.remote = AsyncMock()
|
||||
|
||||
pending_writes: list[tuple[str, str]] = []
|
||||
|
||||
gwm = MagicMock()
|
||||
gwm.put_pending = MagicMock()
|
||||
gwm.put_pending.remote = AsyncMock(
|
||||
side_effect=lambda tid, msg: pending_writes.append((tid, msg))
|
||||
)
|
||||
|
||||
def _fake_hook(name):
|
||||
if name == "postgres_database":
|
||||
return SimpleNamespace(postgres_database=_FakeActorRef(postgres))
|
||||
if name == "global_workflow_manager":
|
||||
return SimpleNamespace(global_workflow_manager=_FakeActorRef(gwm))
|
||||
raise KeyError(name)
|
||||
|
||||
import kilostar.core.individual.consciousness_node.consciousness_node as cmod
|
||||
|
||||
monkeypatch.setattr(cmod, "ray_actor_hook", _fake_hook)
|
||||
|
||||
captured: dict = {}
|
||||
|
||||
def _fake_task_remote(workflow_dict, trace_id):
|
||||
captured["workflow_dict"] = workflow_dict
|
||||
captured["trace_id"] = trace_id
|
||||
return MagicMock()
|
||||
|
||||
monkeypatch.setattr(engine_module.run_workflow_task, "remote", _fake_task_remote)
|
||||
|
||||
await consciousness_instance.start_workflow_design("trace-123", "do something")
|
||||
|
||||
assert captured["trace_id"] == "trace-123"
|
||||
assert captured["workflow_dict"]["title"] == "t"
|
||||
# SSE 推送方向必须是 pending(put_pending)
|
||||
assert any("正在为您构建" in msg for _, msg in pending_writes)
|
||||
assert any("即将开始执行" in msg for _, msg in pending_writes)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_start_workflow_design_failed_path_marks_failed(
|
||||
consciousness_instance, monkeypatch
|
||||
):
|
||||
"""working 返回 None / 不匹配类型时应推送失败提示并把 workflow 状态置为 failed。"""
|
||||
consciousness_instance.working = AsyncMock(return_value=None)
|
||||
|
||||
postgres = MagicMock()
|
||||
postgres.get_all_worker_individual = MagicMock()
|
||||
postgres.get_all_worker_individual.remote = AsyncMock(return_value=[])
|
||||
postgres.update_workflow_status = MagicMock()
|
||||
postgres.update_workflow_status.remote = AsyncMock()
|
||||
|
||||
pending_writes: list[tuple[str, str]] = []
|
||||
gwm = MagicMock()
|
||||
gwm.put_pending = MagicMock()
|
||||
gwm.put_pending.remote = AsyncMock(
|
||||
side_effect=lambda tid, msg: pending_writes.append((tid, msg))
|
||||
)
|
||||
|
||||
def _fake_hook(name):
|
||||
if name == "postgres_database":
|
||||
return SimpleNamespace(postgres_database=_FakeActorRef(postgres))
|
||||
if name == "global_workflow_manager":
|
||||
return SimpleNamespace(global_workflow_manager=_FakeActorRef(gwm))
|
||||
raise KeyError(name)
|
||||
|
||||
import kilostar.core.individual.consciousness_node.consciousness_node as cmod
|
||||
|
||||
monkeypatch.setattr(cmod, "ray_actor_hook", _fake_hook)
|
||||
|
||||
fired: list = []
|
||||
monkeypatch.setattr(
|
||||
engine_module.run_workflow_task,
|
||||
"remote",
|
||||
lambda *a, **kw: fired.append((a, kw)),
|
||||
)
|
||||
|
||||
await consciousness_instance.start_workflow_design("trace-x", "cmd")
|
||||
|
||||
assert fired == [] # 没有 fire ray task
|
||||
postgres.update_workflow_status.remote.assert_awaited_with("trace-x", "failed")
|
||||
assert any("生成失败" in msg for _, msg in pending_writes)
|
||||
@@ -0,0 +1,290 @@
|
||||
"""Workflow graph 引擎本身(节点跳转 / 失败处理 / 派发 / HITL)的单元测试。
|
||||
|
||||
不经过 ConsciousnessNode 入口,也不依赖 ray runtime —— 直接调用
|
||||
``run_workflow_graph(workflow_data, trace_id, deps=...)``,注入一套全部用
|
||||
``AsyncMock``/lambda 实现的 ``WorkflowDeps``。验证 graph 驱动确实把状态机跑通。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.work.workflow.workflow_engine import (
|
||||
WorkflowDeps,
|
||||
WorkflowGraphState,
|
||||
run_workflow_graph,
|
||||
)
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
|
||||
|
||||
def _make_deps(
|
||||
*,
|
||||
skill_outputs: List[Tuple[str, bool]] | None = None,
|
||||
consciousness_outputs: List[Tuple[str, bool]] | None = None,
|
||||
received_replies: List[str] | None = None,
|
||||
) -> tuple[WorkflowDeps, Dict[str, List[Any]]]:
|
||||
"""构造一对 ``(deps, sink)``:sink 收集所有 IO 调用,方便断言。
|
||||
|
||||
``skill_outputs`` / ``consciousness_outputs`` 是按顺序消费的 ``(text, success)``
|
||||
队列,run_skill / run_consciousness 每被调一次取一个。``received_replies``
|
||||
用于 HumanApproval 节点的 get_received。
|
||||
"""
|
||||
sink: Dict[str, List[Any]] = {
|
||||
"upsert": [],
|
||||
"status": [],
|
||||
"pending": [],
|
||||
"received_calls": [],
|
||||
"skill_calls": [],
|
||||
"consciousness_calls": [],
|
||||
}
|
||||
|
||||
skill_queue = list(skill_outputs or [])
|
||||
consc_queue = list(consciousness_outputs or [])
|
||||
reply_queue = list(received_replies or [])
|
||||
|
||||
upsert = AsyncMock(side_effect=lambda *a, **kw: sink["upsert"].append((a, kw)))
|
||||
status = AsyncMock(side_effect=lambda tid, st: sink["status"].append((tid, st)))
|
||||
pending = AsyncMock(side_effect=lambda tid, msg: sink["pending"].append((tid, msg)))
|
||||
|
||||
async def _get_received(tid: str) -> str:
|
||||
sink["received_calls"].append(tid)
|
||||
return reply_queue.pop(0) if reply_queue else ""
|
||||
|
||||
async def _run_skill(step, state):
|
||||
sink["skill_calls"].append((step, state.current_step_index))
|
||||
if not skill_queue:
|
||||
return "(no fixture)", True
|
||||
return skill_queue.pop(0)
|
||||
|
||||
async def _run_consciousness(step, state):
|
||||
sink["consciousness_calls"].append((step, state.current_step_index))
|
||||
if not consc_queue:
|
||||
return "(no fixture)", True
|
||||
return consc_queue.pop(0)
|
||||
|
||||
return (
|
||||
WorkflowDeps(
|
||||
upsert_workflow_context=upsert,
|
||||
update_workflow_status=status,
|
||||
put_pending=pending,
|
||||
get_received=_get_received,
|
||||
run_skill=_run_skill,
|
||||
run_consciousness=_run_consciousness,
|
||||
),
|
||||
sink,
|
||||
)
|
||||
|
||||
|
||||
# ─── 基本路径 ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_skill_individual_path_completes_linear_steps():
|
||||
"""两步顺序工作流 (skill_individual):成功 → 成功,应推进到 COMPLETED。"""
|
||||
deps, sink = _make_deps(skill_outputs=[("ok1", True), ("ok2", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do1", "outputs": "o1",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
{"step": 2, "name": "s2", "action": "do2", "outputs": "o2",
|
||||
"node": "skill_individual", "agent_id": "a2"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-ok", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# run_skill 被调了两次,consciousness 没被调
|
||||
assert len(sink["skill_calls"]) == 2
|
||||
assert len(sink["consciousness_calls"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_consciousness_node_path_dispatches_to_consciousness():
|
||||
"""node=consciousness_node 的 step 应被派给 run_consciousness。"""
|
||||
deps, sink = _make_deps(
|
||||
consciousness_outputs=[("ok-from-consciousness", True)]
|
||||
)
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "summarize",
|
||||
"node": "consciousness_node"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-c", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
assert len(sink["consciousness_calls"]) == 1
|
||||
assert len(sink["skill_calls"]) == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mixed_path_dispatches_correctly():
|
||||
"""混合 step:第一步 skill_individual,第二步 consciousness_node。"""
|
||||
deps, sink = _make_deps(
|
||||
skill_outputs=[("o1", True)],
|
||||
consciousness_outputs=[("o2", True)],
|
||||
)
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
{"step": 2, "name": "s2", "action": "review",
|
||||
"node": "consciousness_node"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-mix", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
assert len(sink["skill_calls"]) == 1
|
||||
assert len(sink["consciousness_calls"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_node_type_falls_to_failed():
|
||||
"""未识别的 node 类型应直接收尾 FAILED,不静默跑成功。"""
|
||||
deps, sink = _make_deps()
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "phantom_node"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-unknown", deps=deps)
|
||||
assert final == WorkflowStatus.FAILED.value
|
||||
assert len(sink["skill_calls"]) == 0
|
||||
assert len(sink["consciousness_calls"]) == 0
|
||||
|
||||
|
||||
# ─── logic_gate 跳转语义 ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_if_pass_exit_short_circuits():
|
||||
"""``logic_gate.if_pass=exit`` 应在该步成功后立即收尾。"""
|
||||
deps, sink = _make_deps(skill_outputs=[("ok", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do", "outputs": "o",
|
||||
"node": "skill_individual", "agent_id": "a1",
|
||||
"logic_gate": {"if_pass": "exit", "if_fail": "jump_to_step_1"}},
|
||||
{"step": 2, "name": "s2", "action": "skipped",
|
||||
"node": "skill_individual", "agent_id": "a2"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-exit", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# 第二步不应被派发
|
||||
assert len(sink["skill_calls"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_without_jump_marks_failed():
|
||||
"""run_skill 报失败且 if_fail 不指向跳转 → FAILED。"""
|
||||
deps, sink = _make_deps(skill_outputs=[("boom", False)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-fail", deps=deps)
|
||||
assert final == WorkflowStatus.FAILED.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_failure_jumps_back_then_succeeds():
|
||||
"""步 2 失败但 if_fail=jump_to_step_1,回到步 1 重跑后继续到步 2 成功。"""
|
||||
deps, sink = _make_deps(
|
||||
skill_outputs=[
|
||||
("o1-first", True), # step 1 第一次成功
|
||||
("boom", False), # step 2 第一次失败 → 跳回 step 1
|
||||
("o1-retry", True), # step 1 重跑
|
||||
("o2-final", True), # step 2 第二次成功
|
||||
]
|
||||
)
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
{"step": 2, "name": "s2", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a2",
|
||||
"logic_gate": {"if_fail": "jump_to_step_1", "if_pass": "continue"}},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-jump", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# skill 应被调 4 次
|
||||
assert len(sink["skill_calls"]) == 4
|
||||
|
||||
|
||||
# ─── HITL ────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_human_approval_approve_continues():
|
||||
"""require_approval=True 时进入 HumanApproval;用户回 approve 后继续执行。"""
|
||||
deps, sink = _make_deps(
|
||||
skill_outputs=[("ok", True)],
|
||||
received_replies=["approve"],
|
||||
)
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "danger", "action": "rm -rf /",
|
||||
"node": "skill_individual", "agent_id": "a1",
|
||||
"require_approval": True},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-hitl-ok", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# get_received 被调用过一次
|
||||
assert sink["received_calls"] == ["trace-hitl-ok"]
|
||||
# 应该有一条"需要人工审批"的 SSE
|
||||
msgs = [m for _, m in sink["pending"]]
|
||||
assert any("需要人工审批" in m for m in msgs)
|
||||
# skill 才被实际派发
|
||||
assert len(sink["skill_calls"]) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_human_approval_reject_aborts():
|
||||
"""用户回 reject → 不执行 step,整个工作流落到 FAILED。"""
|
||||
deps, sink = _make_deps(
|
||||
skill_outputs=[("ok", True)],
|
||||
received_replies=["reject"],
|
||||
)
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "danger", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1",
|
||||
"require_approval": True},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(workflow_data, "trace-hitl-no", deps=deps)
|
||||
assert final == WorkflowStatus.FAILED.value
|
||||
# skill 不应被派发
|
||||
assert len(sink["skill_calls"]) == 0
|
||||
|
||||
|
||||
# ─── 边界 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_work_link_completes_immediately():
|
||||
"""空 work_link:Initialize → Dispatch 直接判定越界 → Finalize(COMPLETED)。"""
|
||||
deps, sink = _make_deps()
|
||||
final = await run_workflow_graph({"work_link": []}, "trace-empty", deps=deps)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
msgs = [m for _, m in sink["pending"]]
|
||||
assert not any("执行步骤" in m for m in msgs)
|
||||
assert any("工作流执行完成" in m for m in msgs)
|
||||
|
||||
|
||||
def test_workflow_graph_state_defaults():
|
||||
"""State 默认值确保各字段类型契约稳定。"""
|
||||
state = WorkflowGraphState(trace_id="t")
|
||||
assert state.blackboard == {}
|
||||
assert state.work_link == []
|
||||
assert state.current_step_index == 0
|
||||
assert state.final_status == WorkflowStatus.RUNNING.value
|
||||
assert state.logs == []
|
||||
assert state.original_command == ""
|
||||
@@ -0,0 +1,386 @@
|
||||
"""``PostgresStatePersistence`` 与 graph resume 路径单元测试。
|
||||
|
||||
不依赖真实 postgres / ray —— 用两个 lambda 模拟 read/write 即可。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
|
||||
from kilostar.core.work.workflow.workflow_engine import (
|
||||
WorkflowDeps,
|
||||
resume_workflow_graph,
|
||||
run_workflow_graph,
|
||||
workflow_graph,
|
||||
)
|
||||
from kilostar.core.work.workflow.graph_persistence import (
|
||||
PostgresStatePersistence,
|
||||
)
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
|
||||
|
||||
def _make_in_memory_io() -> tuple[
|
||||
Dict[str, Any], "callable", "callable"
|
||||
]:
|
||||
"""构造一对 (db_state, write_fn, read_fn):测试用模拟 postgres。"""
|
||||
db: Dict[str, Any] = {}
|
||||
|
||||
async def write(trace_id: str, history: Any) -> None:
|
||||
db[trace_id] = history
|
||||
|
||||
async def read(trace_id: str) -> Optional[Any]:
|
||||
return db.get(trace_id)
|
||||
|
||||
return db, write, read
|
||||
|
||||
|
||||
def _make_deps(
|
||||
*,
|
||||
skill_outputs: List[Tuple[str, bool]] | None = None,
|
||||
received_replies: List[str] | None = None,
|
||||
) -> tuple[WorkflowDeps, Dict[str, List[Any]]]:
|
||||
sink: Dict[str, List[Any]] = {"skill_calls": [], "pending": []}
|
||||
skill_q = list(skill_outputs or [])
|
||||
reply_q = list(received_replies or [])
|
||||
|
||||
upsert = AsyncMock()
|
||||
status = AsyncMock()
|
||||
pending = AsyncMock(side_effect=lambda tid, msg: sink["pending"].append((tid, msg)))
|
||||
|
||||
async def _get_received(tid):
|
||||
if reply_q:
|
||||
return reply_q.pop(0)
|
||||
return ""
|
||||
|
||||
async def _run_skill(step, state):
|
||||
sink["skill_calls"].append((step.get("name"), state.current_step_index))
|
||||
if not skill_q:
|
||||
return "(no fixture)", True
|
||||
return skill_q.pop(0)
|
||||
|
||||
async def _run_consciousness(step, state): # 不会被本测试触发
|
||||
return "(consc)", True
|
||||
|
||||
return (
|
||||
WorkflowDeps(
|
||||
upsert_workflow_context=upsert,
|
||||
update_workflow_status=status,
|
||||
put_pending=pending,
|
||||
get_received=_get_received,
|
||||
run_skill=_run_skill,
|
||||
run_consciousness=_run_consciousness,
|
||||
),
|
||||
sink,
|
||||
)
|
||||
|
||||
|
||||
# ─── PostgresStatePersistence: 写穿 ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_persistence_writes_history_on_each_node():
|
||||
"""每经过一个节点边界,DB 都应该被更新一次(覆盖式写最新 history)。"""
|
||||
db, write, read = _make_in_memory_io()
|
||||
persistence = PostgresStatePersistence(
|
||||
trace_id="t1", write_history=write, read_history=read
|
||||
)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
|
||||
deps, _ = _make_deps(skill_outputs=[("ok", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(
|
||||
workflow_data, "t1", deps=deps, persistence=persistence
|
||||
)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# DB 里应当有 history 且 history 至少包含一个节点 snapshot
|
||||
assert "t1" in db
|
||||
history = db["t1"]
|
||||
assert isinstance(history, list)
|
||||
assert len(history) >= 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_postgres_persistence_swallows_db_errors_during_run():
|
||||
"""DB 写入失败不应中断 graph 运行(只在 snapshot_end 必须 succeed)。"""
|
||||
db, _, read = _make_in_memory_io()
|
||||
|
||||
write_calls = {"n": 0}
|
||||
|
||||
async def flaky_write(trace_id, history):
|
||||
write_calls["n"] += 1
|
||||
# 中途失败几次但最终一次(snapshot_end)成功
|
||||
if write_calls["n"] < 3:
|
||||
raise RuntimeError("transient db error")
|
||||
db[trace_id] = history
|
||||
|
||||
persistence = PostgresStatePersistence(
|
||||
trace_id="t-flaky", write_history=flaky_write, read_history=read
|
||||
)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
|
||||
deps, _ = _make_deps(skill_outputs=[("ok", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(
|
||||
workflow_data, "t-flaky", deps=deps, persistence=persistence
|
||||
)
|
||||
# 即便中途多次写入失败,graph 仍然跑完
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
|
||||
|
||||
# ─── hydrate / resume:从 DB 续跑 ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_hydrate_returns_false_when_no_record():
|
||||
"""DB 没记录时 hydrate 应返回 False,调用方走 fresh start。"""
|
||||
db, write, read = _make_in_memory_io()
|
||||
persistence = PostgresStatePersistence(
|
||||
trace_id="t-empty", write_history=write, read_history=read
|
||||
)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
assert await persistence.hydrate() is False
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_resume_continues_from_persisted_history(monkeypatch):
|
||||
"""先用 Graph.iter 跑半截留下"created"snapshot,然后用 resume 把剩余节点跑完。
|
||||
|
||||
断言:resume 路径只触发 1 次 skill(第 2 步),第 1 步是从 hydrate 恢复的、
|
||||
不再被重新执行。
|
||||
"""
|
||||
from kilostar.core.work.workflow.workflow_engine import (
|
||||
Initialize,
|
||||
WorkflowGraphState,
|
||||
)
|
||||
|
||||
db, write, read = _make_in_memory_io()
|
||||
persistence_a = PostgresStatePersistence(
|
||||
trace_id="t-resume", write_history=write, read_history=read
|
||||
)
|
||||
persistence_a.set_graph_types(workflow_graph)
|
||||
|
||||
deps_a, sink_a = _make_deps(skill_outputs=[("step-1-ok", True), ("step-2-ok", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
{"step": 2, "name": "s2", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
]
|
||||
}
|
||||
|
||||
# 用 iter 手动驱动:跑 Initialize → Dispatch → SkillStep(跑完第 1 步),
|
||||
# 然后停下,让最后一个还没跑的节点(应该是 Dispatch)保持 created
|
||||
state = WorkflowGraphState(
|
||||
trace_id="t-resume",
|
||||
blackboard={},
|
||||
work_link=list(workflow_data["work_link"]),
|
||||
original_command="",
|
||||
)
|
||||
async with workflow_graph.iter(
|
||||
Initialize(),
|
||||
state=state,
|
||||
deps=deps_a,
|
||||
persistence=persistence_a,
|
||||
) as run:
|
||||
# Initialize → Dispatch → SkillStep → Dispatch (这一步停下)
|
||||
await run.next() # 跑 Initialize,next_node = Dispatch
|
||||
await run.next() # 跑 Dispatch,next_node = SkillStep
|
||||
await run.next() # 跑 SkillStep,next_node = Dispatch(第 2 个)
|
||||
# 此时不再调 next,第 2 个 Dispatch 仍是 created 状态
|
||||
assert len(sink_a["skill_calls"]) == 1, "中断时只应跑了 1 步"
|
||||
|
||||
# 第二阶段:用同份 history resume
|
||||
persistence_b = PostgresStatePersistence(
|
||||
trace_id="t-resume", write_history=write, read_history=read
|
||||
)
|
||||
persistence_b.set_graph_types(workflow_graph)
|
||||
assert await persistence_b.hydrate() is True
|
||||
|
||||
deps_b, sink_b = _make_deps(skill_outputs=[("step-2-resumed", True)])
|
||||
final = await resume_workflow_graph(
|
||||
"t-resume", deps=deps_b, persistence=persistence_b
|
||||
)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
# 关键断言:resume 只执行了第 2 步,第 1 步没被重复
|
||||
assert len(sink_b["skill_calls"]) == 1
|
||||
assert sink_b["skill_calls"][0][0] == "s2"
|
||||
|
||||
|
||||
# ─── HumanApproval idempotent resume ──────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_human_approval_idempotent_on_resume():
|
||||
"""HumanApproval 节点在 resume 后不应重复给前端推 put_pending。
|
||||
|
||||
流程:
|
||||
1. 第一次跑:进 HumanApproval → put_pending 1 次 → 用户没回复 → 中断
|
||||
2. 第二次跑(resume):用户已回复 approve → HumanApproval 第二次进入但
|
||||
不应该再 put_pending;只读 reply 通过。
|
||||
"""
|
||||
from kilostar.core.work.workflow.workflow_engine import (
|
||||
Initialize,
|
||||
WorkflowGraphState,
|
||||
)
|
||||
|
||||
db, write, read = _make_in_memory_io()
|
||||
persistence_a = PostgresStatePersistence(
|
||||
trace_id="t-hitl", write_history=write, read_history=read
|
||||
)
|
||||
persistence_a.set_graph_types(workflow_graph)
|
||||
|
||||
# 第一次:reply 队列空 → get_received 返回空串 → HumanApproval 走拒绝路径之前
|
||||
# 我们要在 put_pending 后停下,所以用 iter 手动驱动,跑到 HumanApproval 内部前停住。
|
||||
# 简化策略:直接让第一次跑完到 Finalize FAILED(reply=""),第二次 resume 验证幂等
|
||||
# 不太合适——FAILED 后没法 resume。改用:第一次 reply=""(拒绝),用 graph_state
|
||||
# 里 approvals_notified 字段来直接验证幂等性更纯粹。
|
||||
|
||||
deps_a, sink_a = _make_deps(
|
||||
skill_outputs=[("step-1", True)],
|
||||
received_replies=[""], # 空 reply → HumanApproval 拒绝
|
||||
)
|
||||
workflow_data_step = {
|
||||
"step": 1, "name": "approve-me", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1",
|
||||
"require_approval": True,
|
||||
}
|
||||
|
||||
# 第一次跑:单步 require_approval=True,reply="" → put_pending 1 次 + Finalize FAILED
|
||||
state = WorkflowGraphState(
|
||||
trace_id="t-hitl",
|
||||
blackboard={},
|
||||
work_link=[dict(workflow_data_step)],
|
||||
original_command="",
|
||||
)
|
||||
async with workflow_graph.iter(
|
||||
Initialize(),
|
||||
state=state,
|
||||
deps=deps_a,
|
||||
persistence=persistence_a,
|
||||
) as run:
|
||||
async for _ in run:
|
||||
pass
|
||||
pending_count_first = len(sink_a["pending"])
|
||||
# 至少有审批提示 + 最终 FAILED 提示这两类 pending
|
||||
approval_msgs_first = [
|
||||
m for _, m in sink_a["pending"] if "需要人工审批" in m
|
||||
]
|
||||
assert len(approval_msgs_first) == 1, "第一次跑必须发审批提示一次"
|
||||
# 关键:state 已经记下这次通知
|
||||
assert state.approvals_notified == [0]
|
||||
|
||||
# 第二次:直接构造一个 state,approvals_notified 里已有 0,再次进 HumanApproval
|
||||
# 模拟"resume 后 graph 重新进入审批节点"——put_pending 不应再发
|
||||
deps_b, sink_b = _make_deps(
|
||||
received_replies=["approve"], # 这次用户回复 approve
|
||||
)
|
||||
state_resume = WorkflowGraphState(
|
||||
trace_id="t-hitl",
|
||||
blackboard={},
|
||||
work_link=[dict(workflow_data_step)],
|
||||
original_command="",
|
||||
approvals_notified=[0], # 关键:之前已通知过
|
||||
)
|
||||
persistence_b = PostgresStatePersistence(
|
||||
trace_id="t-hitl-resume", write_history=write, read_history=read
|
||||
)
|
||||
persistence_b.set_graph_types(workflow_graph)
|
||||
# 直接从 Dispatch 起跑(跳过 Initialize 避免它再次 update_workflow_status)
|
||||
from kilostar.core.work.workflow.workflow_engine import Dispatch
|
||||
|
||||
async with workflow_graph.iter(
|
||||
Dispatch(),
|
||||
state=state_resume,
|
||||
deps=deps_b,
|
||||
persistence=persistence_b,
|
||||
) as run:
|
||||
async for _ in run:
|
||||
pass
|
||||
|
||||
approval_msgs_second = [
|
||||
m for _, m in sink_b["pending"] if "需要人工审批" in m
|
||||
]
|
||||
# 关键断言:resume 路径上不应该再出现"需要人工审批"提示
|
||||
assert len(approval_msgs_second) == 0
|
||||
# approve 通过 → require_approval 置 False → 真正跑到 SkillStep
|
||||
assert len(sink_b["skill_calls"]) == 1
|
||||
assert sink_b["skill_calls"][0][0] == "approve-me"
|
||||
|
||||
|
||||
# ─── mermaid 高亮 ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mermaid_highlight_visited_nodes_from_history():
|
||||
"""跑一次 graph 后,dump_json 出来的 history 应足以提取出 visited 节点类名。
|
||||
|
||||
这个测试模拟 ``/graph`` API 的过滤逻辑:
|
||||
- 只取 status="success" 的 NodeSnapshot
|
||||
- id 形如 "ClassName:hash",截前缀
|
||||
- 去重保序
|
||||
|
||||
然后用提取出的 visited 调 ``mermaid_code(highlighted_nodes=...)``,应能在
|
||||
输出里看到对应 ``class X highlighted`` 行。
|
||||
"""
|
||||
import json
|
||||
|
||||
db, write, read = _make_in_memory_io()
|
||||
persistence = PostgresStatePersistence(
|
||||
trace_id="t-mermaid", write_history=write, read_history=read
|
||||
)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
|
||||
deps, _ = _make_deps(skill_outputs=[("ok", True)])
|
||||
workflow_data = {
|
||||
"work_link": [
|
||||
{"step": 1, "name": "s1", "action": "do",
|
||||
"node": "skill_individual", "agent_id": "a1"},
|
||||
]
|
||||
}
|
||||
final = await run_workflow_graph(
|
||||
workflow_data, "t-mermaid", deps=deps, persistence=persistence
|
||||
)
|
||||
assert final == WorkflowStatus.COMPLETED.value
|
||||
|
||||
# 从 DB 读 history(实际是 list[dict])
|
||||
history = db["t-mermaid"]
|
||||
assert isinstance(history, list)
|
||||
|
||||
# 复刻 API 的过滤逻辑
|
||||
seen: set[str] = set()
|
||||
visited: list[str] = []
|
||||
for entry in history:
|
||||
if not isinstance(entry, dict) or entry.get("kind") != "node":
|
||||
continue
|
||||
if entry.get("status") != "success":
|
||||
continue
|
||||
sid = entry.get("id") or ""
|
||||
cls_name = sid.split(":", 1)[0] if sid else ""
|
||||
if cls_name and cls_name not in seen:
|
||||
seen.add(cls_name)
|
||||
visited.append(cls_name)
|
||||
|
||||
# 至少应该 visit 过 Initialize / Dispatch / SkillStep
|
||||
assert "Initialize" in visited
|
||||
assert "Dispatch" in visited
|
||||
assert "SkillStep" in visited
|
||||
|
||||
mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited)
|
||||
# 高亮一定会带 classDef + class 行
|
||||
assert "classDef" in mermaid
|
||||
assert "class Initialize" in mermaid
|
||||
assert "class SkillStep" in mermaid
|
||||
@@ -0,0 +1,88 @@
|
||||
"""``KiloStarWorkflow`` 的 ``validate_workflow_integrity`` 校验器覆盖。"""
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow, WorkflowStep
|
||||
from kilostar.core.work.workflow.model import LogicGate, WorkflowMetadata
|
||||
|
||||
|
||||
def _step(
|
||||
n: int, *, fail_target: str | None = None, name: str | None = None
|
||||
) -> WorkflowStep:
|
||||
return WorkflowStep(
|
||||
step=n,
|
||||
name=name or f"step_{n}",
|
||||
action="noop",
|
||||
logic_gate=LogicGate(if_fail=fail_target) if fail_target else None,
|
||||
)
|
||||
|
||||
|
||||
def test_valid_workflow_passes():
|
||||
wf = KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[
|
||||
_step(1),
|
||||
_step(2, fail_target="jump_to_step_1"),
|
||||
_step(3),
|
||||
],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert wf.title == "t"
|
||||
assert len(wf.work_link) == 3
|
||||
|
||||
|
||||
def test_non_continuous_steps_rejected():
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[_step(1), _step(3)],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert "工作链步数不连续" in str(exc.value)
|
||||
|
||||
|
||||
def test_jump_target_out_of_range_rejected():
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[
|
||||
_step(1),
|
||||
_step(2, fail_target="jump_to_step_99"),
|
||||
],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert "越界" in str(exc.value)
|
||||
|
||||
|
||||
def test_jump_target_below_one_rejected():
|
||||
with pytest.raises(ValidationError) as exc:
|
||||
KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[
|
||||
_step(1, fail_target="jump_to_step_0"),
|
||||
],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert "越界" in str(exc.value)
|
||||
|
||||
|
||||
def test_logic_gate_without_jump_is_ignored():
|
||||
"""没有 ``jump_to_step_`` 前缀的 if_fail 不进入越界判断。"""
|
||||
wf = KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[
|
||||
_step(1, fail_target="abort"),
|
||||
],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert wf.work_link[0].logic_gate.if_fail == "abort"
|
||||
|
||||
|
||||
def test_default_trace_id_is_set():
|
||||
wf = KiloStarWorkflow(
|
||||
title="t",
|
||||
work_link=[_step(1)],
|
||||
workflow_metadata=WorkflowMetadata(),
|
||||
)
|
||||
assert wf.trace_id and isinstance(wf.trace_id, str)
|
||||
@@ -1,101 +0,0 @@
|
||||
import sys
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
# Mock dependencies before importing the module under test
|
||||
class MockHTTPException(Exception):
|
||||
def __init__(self, status_code, detail=None, headers=None):
|
||||
self.status_code = status_code
|
||||
self.detail = detail
|
||||
self.headers = headers
|
||||
|
||||
|
||||
class MockValidationError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
mock_fastapi = MagicMock()
|
||||
mock_fastapi.HTTPException = MockHTTPException
|
||||
mock_fastapi.status.HTTP_401_UNAUTHORIZED = 401
|
||||
|
||||
mock_pydantic = MagicMock()
|
||||
mock_pydantic.ValidationError = MockValidationError
|
||||
|
||||
sys.modules["fastapi"] = mock_fastapi
|
||||
sys.modules["pydantic"] = mock_pydantic
|
||||
sys.modules["sqlmodel"] = MagicMock()
|
||||
sys.modules["passlib"] = MagicMock()
|
||||
sys.modules["passlib.context"] = MagicMock()
|
||||
sys.modules["kilostar.core.database.table.user"] = MagicMock()
|
||||
|
||||
import pytest # noqa: E402
|
||||
import jwt # noqa: E402
|
||||
from kilostar.utils.access import Accessor # noqa: E402
|
||||
|
||||
|
||||
def test_decode_token_success():
|
||||
"""Test successful token decoding."""
|
||||
token = "valid.token.here"
|
||||
payload = {"user_id": "123", "username": "testuser", "exp": 1234567890}
|
||||
|
||||
with patch("jwt.decode", return_value=payload) as mock_decode:
|
||||
with patch("kilostar.utils.access.TokenData") as mock_token_data_cls:
|
||||
mock_token_data_instance = MagicMock()
|
||||
mock_token_data_cls.return_value = mock_token_data_instance
|
||||
|
||||
result = Accessor._decode_token(token)
|
||||
|
||||
mock_decode.assert_called_once()
|
||||
mock_token_data_cls.assert_called_once_with(**payload)
|
||||
assert result == mock_token_data_instance
|
||||
|
||||
|
||||
def test_decode_token_expired():
|
||||
"""Test token decoding with an expired token."""
|
||||
token = "expired.token.here"
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch("jwt.decode", side_effect=jwt.ExpiredSignatureError):
|
||||
with patch("kilostar.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "Token 已过期"
|
||||
|
||||
|
||||
def test_decode_token_invalid():
|
||||
"""Test token decoding with an invalid token."""
|
||||
token = "invalid.token.here"
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch("jwt.decode", side_effect=jwt.InvalidTokenError):
|
||||
with patch("kilostar.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
|
||||
|
||||
def test_decode_token_validation_error():
|
||||
"""Test token decoding with a payload that fails validation."""
|
||||
token = "valid.jwt.invalid.payload"
|
||||
payload = {"wrong": "payload"}
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
with patch("jwt.decode", return_value=payload):
|
||||
with patch("kilostar.utils.access.TokenData", side_effect=MockValidationError):
|
||||
with patch("kilostar.utils.access.ValidationError", MockValidationError):
|
||||
with patch("kilostar.utils.access.HTTPException", HTTPException):
|
||||
with pytest.raises(HTTPException) as excinfo:
|
||||
Accessor._decode_token(token)
|
||||
|
||||
assert excinfo.value.status_code == 401
|
||||
assert excinfo.value.detail == "无效的认证凭证"
|
||||
|
||||
|
||||
# noqa: E402
|
||||
Reference in New Issue
Block a user