Files
KiloStar/tests/unit/test_gsm_snapshot.py
zhaoxi b15eeb9e74 fix(toolset): 工具传递改为展开的 tools 列表,不再用 FunctionToolset 包装
前端/DB 仍用 toolset 做逻辑分组管理,但传给 pydantic-ai Agent 时
把 toolset 内的 callable 展开为 tools=[] 扁平列表,MCP server 等
需要 toolset 语义的单独走 toolsets=[] 参数。解决工具"存在但调不了"的问题。

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
2026-06-05 19:05:59 +00:00

353 lines
12 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""GSM 配置快照(Object Store 读路径)相关测试。
主要验证:
- ``GSMSnapshot`` 数据类可被 cloudpickle 序列化(ray.put 的隐式约束)
- ``_build_snapshot`` 正确从 6 类内存状态打包配置
- ``_publish_snapshot`` 让 version 单调递增并刷新 ObjectRef
- 写入路径(add_individual / set_tool_config / 等)会自动发布新快照
- ``fetch_snapshot`` 客户端:版本号一致时走本地缓存,不一致时重拉
"""
from __future__ import annotations
import asyncio
import pickle
from types import SimpleNamespace
from unittest.mock import AsyncMock, MagicMock
import pytest
# cloudpickle is a transitive dependency of ray and is not listed directly in pyproject, so it is imported through the ray package.
from ray import cloudpickle
from kilostar.core.global_state_machine.gsm_snapshot import (
GSMSnapshot,
fetch_snapshot,
reset_local_cache,
)
def test_empty_snapshot_can_cloudpickle_roundtrip():
    """A default-constructed snapshot keeps its meaning across a cloudpickle round trip.

    This is the minimum requirement for passing the object to ``ray.put``.
    """
    payload = cloudpickle.dumps(GSMSnapshot())
    roundtripped: GSMSnapshot = cloudpickle.loads(payload)

    # Defaults must come back untouched: version 0 and empty registries.
    assert roundtripped.version == 0
    assert roundtripped.providers == {}
    assert roundtripped.individuals == {}
def test_snapshot_with_real_data_roundtrip():
    """A snapshot holding a real Provider, a function reference and dict data round-trips."""
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    def _sample_tool(query: str) -> str:
        return f"echo:{query}"

    provider = Provider(
        provider_title="p1",
        provider_url="http://x",
        provider_apikey="sk-x",
        provider_models=["gpt-4o"],
        provider_type="openai",
    )
    original = GSMSnapshot(
        version=42,
        providers={"p1": provider},
        individuals={"agent-a": {"agent_id": "agent-a", "model_id": "gpt-4o"}},
        tool_funcs={"echo": _sample_tool},
    )

    restored: GSMSnapshot = cloudpickle.loads(cloudpickle.dumps(original))

    assert restored.version == 42
    assert restored.providers["p1"].provider_title == "p1"
    assert restored.individuals["agent-a"]["model_id"] == "gpt-4o"
    # ``_sample_tool`` is local to this test; the restored copy must still be
    # callable after the round trip (per the original note, cloudpickle ships
    # the function's bytecode along with it).
    assert restored.tool_funcs["echo"]("hi") == "echo:hi"
# ─── GSM actor 集成(绕过 @ray.remote 直接构造) ────────────────────
@pytest.fixture
def gsm_instance(monkeypatch):
    """Provide a bare ``GlobalStateMachine`` built without ``@ray.remote``.

    The undecorated class is read from ``__ray_actor_class__`` and
    instantiated through ``__new__``, so ``__init__`` never runs; the
    attributes listed below are assigned by hand in its place.
    ``ray.put`` (as seen by the GSM module) is replaced with a stub that
    returns a distinct string on every call.
    """
    import kilostar.core.global_state_machine.global_state_machine as gsm_mod
    from kilostar.core.global_state_machine.individual_manager import (
        GlobalIndividualManager,
    )
    from kilostar.core.global_state_machine.provider_manager import ProviderManager
    from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
    from kilostar.core.global_state_machine.tool_manager import GlobalToolManager

    actor_cls = gsm_mod.GlobalStateMachine.__ray_actor_class__
    instance = actor_cls.__new__(actor_cls)

    # Recreate the side effects of __init__ by hand.
    manual_state = {
        "_global_provider_manager": ProviderManager(postgres=None),
        "_global_tool_manager": GlobalToolManager(),
        "_global_skill_manager": GlobalSkillManager(),
        "_global_individual_manager": GlobalIndividualManager(),
        "_mcp_servers": {},
        "_tool_configs": {},
        "_custom_toolsets": {},
        "_config_version": 0,
        "_current_ref": None,
        "postgres_database": MagicMock(),
    }
    for attr_name, value in manual_state.items():
        setattr(instance, attr_name, value)

    # Per the original note, ray.put fails in the test sandbox on a psutil PID
    # check.  The tests care about the flow of _publish_snapshot, not about
    # really storing the object, so a sentinel ref is returned instead.
    put_calls = 0

    def _fake_put(snapshot):
        nonlocal put_calls
        put_calls += 1
        return f"fake-ref-{put_calls}"

    monkeypatch.setattr(gsm_mod.ray, "put", _fake_put)
    return instance
def test_build_snapshot_picks_up_all_six_categories(gsm_instance):
    """``_build_snapshot`` must pack the configuration held in GSM memory.

    Five registries are seeded here (providers, individuals, MCP servers,
    tool configs, custom toolsets) and each seeded key has to appear in the
    built snapshot.
    """
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    seeded_provider = Provider(
        provider_title="p1",
        provider_url="http://x",
        provider_apikey="k",
        provider_models=[],
        provider_type="openai",
    )
    gsm_instance._global_provider_manager.provider_register["p1"] = seeded_provider
    gsm_instance._global_individual_manager._individuals["a1"] = {"agent_id": "a1"}
    gsm_instance._mcp_servers["s1"] = {"server_id": "s1"}
    gsm_instance._tool_configs["t1"] = {"key": "v"}
    gsm_instance._custom_toolsets["ts1"] = {"toolset_id": "ts1"}

    snap = gsm_instance._build_snapshot()

    # (seeded key, snapshot field that must contain it)
    expected_membership = (
        ("p1", "providers"),
        ("a1", "individuals"),
        ("s1", "mcp_servers"),
        ("t1", "tool_configs"),
        ("ts1", "custom_toolsets"),
    )
    for key, field_name in expected_membership:
        assert key in getattr(snap, field_name)
def test_build_snapshot_exposes_system_tools_by_scope(gsm_instance):
    """The per-scope lists of system tool names must be published with the snapshot.

    Per the original note, clients use these lists to rebuild their toolsets.
    """

    def _f1():
        return "f1"

    def _f2():
        return "f2"

    # Simulate the tool manager's internal state: the "default" scope holds
    # file_reader, the "control_node" scope holds approval.
    # NOTE(review): assumes ``_tool_funcs`` creates a missing scope bucket on
    # access (e.g. a defaultdict) — confirm against GlobalToolManager.
    tool_manager = gsm_instance._global_tool_manager
    seeded = {
        "default": ("file_reader", _f1),
        "control_node": ("approval", _f2),
    }
    tool_manager._tool_funcs.clear()
    for scope, (tool_name, func) in seeded.items():
        tool_manager._tool_funcs[scope][tool_name] = func

    snap = gsm_instance._build_snapshot()

    by_scope = snap.system_tools_by_scope
    assert by_scope.get("default") == ["file_reader"]
    assert by_scope.get("control_node") == ["approval"]
    # After flattening, tool_funcs must contain both tools.
    assert set(snap.tool_funcs) == {"file_reader", "approval"}
def test_publish_snapshot_increments_version(gsm_instance):
    """Each ``_publish_snapshot`` call bumps the version by one and replaces the ref.

    Starts from the fixture's pristine state (version 0, no ref), publishes
    twice, and checks the version after each publish as well as the fact
    that the second publish produced a different ref than the first.
    """
    assert gsm_instance._config_version == 0
    assert gsm_instance._current_ref is None

    gsm_instance._publish_snapshot()
    first_version = gsm_instance._config_version
    first_ref = gsm_instance._current_ref
    assert first_version == 1
    assert first_ref is not None

    gsm_instance._publish_snapshot()
    assert gsm_instance._config_version == 2
    # Compare by value, not identity: an ``is not`` check is satisfied by any
    # freshly created object, even one equal to the previous ref, so it could
    # not detect the same ref being published twice.
    assert gsm_instance._current_ref != first_ref
@pytest.mark.asyncio
async def test_current_config_ref_lazy_publishes_when_empty(gsm_instance):
    """With nothing published yet, ``current_config_ref`` publishes once instead of returning ``None``."""
    published = await gsm_instance.current_config_ref()
    version, ref = published

    assert ref is not None
    assert version == 1
@pytest.mark.asyncio
async def test_current_version_is_lightweight(gsm_instance):
    """After two publishes, ``current_version`` reports 2."""
    for _ in range(2):
        gsm_instance._publish_snapshot()

    reported = await gsm_instance.current_version()
    assert reported == 2
@pytest.mark.asyncio
async def test_add_individual_publishes_new_snapshot(gsm_instance):
    """The ``add_individual`` write path must bump the version by exactly one."""
    starting_version = gsm_instance._config_version

    gsm_instance.add_individual("agent-x", {"model_id": "gpt-4o"})

    assert gsm_instance._config_version == starting_version + 1
@pytest.mark.asyncio
async def test_add_provider_wrap_publishes_new_snapshot(gsm_instance):
    """``add_provider_wrap`` must end up publishing one new snapshot, even with a mocked adapter."""
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    fake_provider = Provider(
        provider_title="my-openai",
        provider_url="http://x",
        provider_apikey="k",
        provider_models=[],
        provider_type="openai",
    )

    # Adapter whose create_provider coroutine hands back the canned provider.
    adapter = MagicMock()
    adapter.create_provider = AsyncMock(return_value=fake_provider)
    gsm_instance._global_provider_manager.provider_mapper["openai"] = adapter

    # Database handle whose ``add_provider_db.remote`` can be awaited.
    add_provider_db = MagicMock()
    add_provider_db.remote = AsyncMock()
    gsm_instance.postgres_database.add_provider_db = add_provider_db

    request = {
        "provider_type": "openai",
        "provider_title": "my-openai",
        "provider_url": "http://x",
        "provider_apikey": "k",
        "provider_owner": "alice",
    }
    version_before = gsm_instance._config_version
    await gsm_instance.add_provider_wrap(**request)

    assert gsm_instance._config_version == version_before + 1
# ─── fetch_snapshot 客户端缓存 ────────────────────────────────────
@pytest.mark.asyncio
async def test_fetch_snapshot_uses_local_cache_when_version_matches():
    """With a matching version, ``fetch_snapshot`` serves the locally cached snapshot.

    The fake GSM handle reports version 5 and the local cache is pre-warmed
    with a v5 snapshot, so only ``current_version`` should be awaited and
    ``current_config_ref`` must never be called.
    """
    from kilostar.core.global_state_machine import gsm_snapshot as snap_mod

    reset_local_cache()
    snap = GSMSnapshot(version=5, providers={"p": MagicMock()})

    fake_gsm = MagicMock()
    fake_gsm.current_version = MagicMock()
    fake_gsm.current_version.remote = AsyncMock(return_value=5)
    fake_gsm.current_config_ref = MagicMock()
    # Tripwire: calling current_config_ref raises, so a cache miss fails the
    # test loudly if the exception propagates.
    fake_gsm.current_config_ref.remote = AsyncMock(
        side_effect=AssertionError("不应触发:缓存版本一致时不应调 current_config_ref")
    )

    # Pre-warm the cache to v5, as if an earlier fetch had already happened.
    snap_mod._local_cache["version"] = 5
    snap_mod._local_cache["snapshot"] = snap
    try:
        result = await fetch_snapshot(gsm_actor=fake_gsm)

        assert result is snap
        fake_gsm.current_version.remote.assert_awaited_once()
        # The tripwire alone is not sufficient: if fetch_snapshot caught the
        # exception the call would go unnoticed, so check it explicitly.
        fake_gsm.current_config_ref.remote.assert_not_called()
    finally:
        # Do not leak the hand-written cache entry into other tests.
        reset_local_cache()
@pytest.mark.asyncio
async def test_fetch_snapshot_refetches_when_version_changes(monkeypatch):
    """A changed version number must trigger a fresh ``ray.get`` of the snapshot."""
    import kilostar.core.global_state_machine.gsm_snapshot as snap_mod

    reset_local_cache()
    new_snap = GSMSnapshot(version=10)

    version_endpoint = MagicMock()
    version_endpoint.remote = AsyncMock(return_value=10)
    ref_endpoint = MagicMock()
    ref_endpoint.remote = AsyncMock(return_value=(10, "fake-ref"))

    fake_gsm = MagicMock()
    fake_gsm.current_version = version_endpoint
    fake_gsm.current_config_ref = ref_endpoint

    # Stub ray.get so it returns the prepared snapshot directly.
    def _fake_get(ref):
        return new_snap

    monkeypatch.setattr(snap_mod.ray, "get", _fake_get)

    result = await fetch_snapshot(gsm_actor=fake_gsm)

    assert result is new_snap
    fake_gsm.current_config_ref.remote.assert_awaited_once()
    # The local cache must now hold v10.
    assert snap_mod._local_cache["version"] == 10
@pytest.mark.asyncio
async def test_fetch_snapshot_use_cache_false_skips_cache(monkeypatch):
    """``use_cache=False`` goes straight to ``current_config_ref`` and returns what ``ray.get`` yields."""
    import kilostar.core.global_state_machine.gsm_snapshot as snap_mod

    reset_local_cache()
    fresh = GSMSnapshot(version=1)

    ref_endpoint = MagicMock()
    ref_endpoint.remote = AsyncMock(return_value=(1, "ref"))
    fake_gsm = MagicMock()
    fake_gsm.current_config_ref = ref_endpoint

    def _fake_get(ref):
        return fresh

    monkeypatch.setattr(snap_mod.ray, "get", _fake_get)

    result = await fetch_snapshot(gsm_actor=fake_gsm, use_cache=False)
    assert result is fresh
# ─── build_tools_for_scope 客户端 helper ────────────────────────
def test_build_tools_for_scope_assembles_system_and_custom():
    """``build_tools_for_scope`` expands ``custom_toolsets`` + ``all_funcs`` into a flat list of callables."""
    from kilostar.core.global_state_machine.gsm_snapshot import (
        build_tools_for_scope,
    )

    def _sys_default():
        return "d"

    def _sys_scope():
        return "s"

    def _tp_a():
        return "a"

    funcs_by_name = {
        "sys_default": _sys_default,
        "sys_scope": _sys_scope,
        "tp_a": _tp_a,
    }
    toolsets = {
        "system_basic": {
            "toolset_id": "system_basic",
            "tools": ["sys_default", "sys_scope"],
        },
        "grp": {"toolset_id": "grp", "tools": ["tp_a"]},
    }
    snap = GSMSnapshot(all_funcs=funcs_by_name, custom_toolsets=toolsets)

    tools = build_tools_for_scope(snap, "control_node")

    assert len(tools) == 3
    assert tools == [_sys_default, _sys_scope, _tp_a]
def test_build_tools_for_scope_skips_empty_buckets():
    """A snapshot with no functions and no toolsets yields an empty tool list."""
    from kilostar.core.global_state_machine.gsm_snapshot import (
        build_tools_for_scope,
    )

    empty_snapshot = GSMSnapshot(all_funcs={}, custom_toolsets={})

    tools = build_tools_for_scope(empty_snapshot, "control_node")
    assert tools == []