b15eeb9e74
前端/DB 仍用 toolset 做逻辑分组管理,但传给 pydantic-ai Agent 时 把 toolset 内的 callable 展开为 tools=[] 扁平列表,MCP server 等 需要 toolset 语义的单独走 toolsets=[] 参数。解决工具"存在但调不了"的问题。 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
353 lines
12 KiB
Python
353 lines
12 KiB
Python
"""Tests for the GSM config snapshot (the Object Store read path).

Covers:

- ``GSMSnapshot`` can be serialized with cloudpickle (the implicit
  requirement of ``ray.put``).
- ``_build_snapshot`` packs config from the GSM's in-memory state (providers,
  individuals, MCP servers, tool configs, custom toolsets, per-scope system
  tools).
- ``_publish_snapshot`` increments the version monotonically and refreshes the
  ObjectRef.
- Write paths (``add_individual`` / ``add_provider_wrap`` / ...) automatically
  publish a new snapshot.
- ``fetch_snapshot`` client: served from the local cache when the version
  matches; re-fetched when the version differs or when ``use_cache=False``.
- ``build_tools_for_scope`` client helper: expands toolsets into a flat list
  of callables.
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import pickle
|
||
from types import SimpleNamespace
|
||
from unittest.mock import AsyncMock, MagicMock
|
||
|
||
import pytest
|
||
|
||
# cloudpickle is a transitive dependency of ray and is not listed in pyproject directly — import the copy the ray package re-exports (ray.cloudpickle).
|
||
from ray import cloudpickle
|
||
|
||
from kilostar.core.global_state_machine.gsm_snapshot import (
|
||
GSMSnapshot,
|
||
fetch_snapshot,
|
||
reset_local_cache,
|
||
)
|
||
|
||
|
||
def test_empty_snapshot_can_cloudpickle_roundtrip():
    """A default-constructed snapshot survives a cloudpickle round-trip unchanged.

    This is the minimum contract that ``ray.put`` relies on.
    """
    restored: GSMSnapshot = cloudpickle.loads(cloudpickle.dumps(GSMSnapshot()))

    assert restored.version == 0
    assert restored.providers == {}
    assert restored.individuals == {}
|
||
|
||
|
||
def test_snapshot_with_real_data_roundtrip():
    """A snapshot holding a real Provider, a function and plain dicts round-trips.

    Exercises the kinds of values the GSM publishes: a ``Provider`` object, an
    ``individuals`` dict and a callable inside ``tool_funcs``.
    """
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    def _sample_tool(query: str) -> str:
        return f"echo:{query}"

    snap = GSMSnapshot(
        version=42,
        providers={
            "p1": Provider(
                provider_title="p1",
                provider_url="http://x",
                provider_apikey="sk-x",
                provider_models=["gpt-4o"],
                provider_type="openai",
            ),
        },
        individuals={"agent-a": {"agent_id": "agent-a", "model_id": "gpt-4o"}},
        tool_funcs={"echo": _sample_tool},
    )

    blob = cloudpickle.dumps(snap)
    restored: GSMSnapshot = cloudpickle.loads(blob)

    assert restored.version == 42
    assert restored.providers["p1"].provider_title == "p1"
    assert restored.individuals["agent-a"]["model_id"] == "gpt-4o"

    # _sample_tool is nested inside this test (NOT a module-level function),
    # so the stdlib pickle could not serialize it by reference. cloudpickle
    # serializes it by value (bytecode included), so the restored callable is
    # a brand-new function object that must still behave identically.
    restored_tool = restored.tool_funcs["echo"]
    assert restored_tool is not _sample_tool
    assert restored_tool("hi") == "echo:hi"
|
||
|
||
|
||
# ─── GSM actor 集成(绕过 @ray.remote 直接构造) ────────────────────
|
||
|
||
|
||
@pytest.fixture
def gsm_instance(monkeypatch):
    """Build a bare ``GlobalStateMachine`` without going through ``@ray.remote``.

    The underlying actor class is instantiated with ``__new__`` so ``__init__``
    never runs; the attributes it would normally set are assigned by hand.
    ``ray.put`` (as seen from the GSM module) is replaced by a stub that hands
    out sequential string sentinels ("fake-ref-1", "fake-ref-2", ...).
    """
    import kilostar.core.global_state_machine.global_state_machine as gsm_mod
    from kilostar.core.global_state_machine.global_state_machine import (
        GlobalStateMachine,
    )
    from kilostar.core.global_state_machine.individual_manager import (
        GlobalIndividualManager,
    )
    from kilostar.core.global_state_machine.provider_manager import ProviderManager
    from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
    from kilostar.core.global_state_machine.tool_manager import GlobalToolManager

    actor_cls = GlobalStateMachine.__ray_actor_class__
    gsm = actor_cls.__new__(actor_cls)

    # Recreate, by hand, the state __init__ would have set up.
    gsm._global_provider_manager = ProviderManager(postgres=None)
    gsm._global_tool_manager = GlobalToolManager()
    gsm._global_skill_manager = GlobalSkillManager()
    gsm._global_individual_manager = GlobalIndividualManager()
    gsm._mcp_servers = {}
    gsm._tool_configs = {}
    gsm._custom_toolsets = {}
    gsm._config_version = 0
    gsm._current_ref = None
    gsm.postgres_database = MagicMock()

    # In the test sandbox ray.put fails on a psutil PID check. These tests only
    # care about the _publish_snapshot flow, not about storing objects in
    # plasma, so each "put" just returns the next sentinel string.
    issued_refs: list[str] = []

    def _fake_put(snapshot):
        issued_refs.append(f"fake-ref-{len(issued_refs) + 1}")
        return issued_refs[-1]

    monkeypatch.setattr(gsm_mod.ray, "put", _fake_put)
    return gsm
|
||
|
||
|
||
def test_build_snapshot_picks_up_all_six_categories(gsm_instance):
    """``_build_snapshot`` carries every seeded category of in-memory GSM state.

    NOTE(review): despite the name, only five categories are seeded and checked
    here (providers, individuals, MCP servers, tool configs, custom toolsets);
    the remaining one is presumably tools/skills — confirm what "six" refers to.
    """
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    gsm = gsm_instance
    gsm._global_provider_manager.provider_register["p1"] = Provider(
        provider_title="p1",
        provider_url="http://x",
        provider_apikey="k",
        provider_models=[],
        provider_type="openai",
    )
    gsm._global_individual_manager._individuals["a1"] = {"agent_id": "a1"}
    gsm._mcp_servers["s1"] = {"server_id": "s1"}
    gsm._tool_configs["t1"] = {"key": "v"}
    gsm._custom_toolsets["ts1"] = {"toolset_id": "ts1"}

    snap = gsm._build_snapshot()

    # snapshot field name -> key seeded above that must show up in that field
    expected_keys = {
        "providers": "p1",
        "individuals": "a1",
        "mcp_servers": "s1",
        "tool_configs": "t1",
        "custom_toolsets": "ts1",
    }
    for field_name, seeded_key in expected_keys.items():
        assert seeded_key in getattr(snap, field_name), field_name
|
||
|
||
|
||
def test_build_snapshot_exposes_system_tools_by_scope(gsm_instance):
    """Per-scope system tool names must be published with the snapshot.

    Clients rebuild their toolsets from ``system_tools_by_scope``, so each scope
    bucket must list its tool names, and ``tool_funcs`` must hold the flattened
    union of all buckets.
    """

    def _f1():
        return "f1"

    def _f2():
        return "f2"

    # Seed the tool manager: "default" owns file_reader, "control_node" owns
    # approval. Indexing scopes right after clear() relies on _tool_funcs
    # auto-creating missing buckets (defaultdict-like) — NOTE(review): confirm
    # against GlobalToolManager.
    registry = gsm_instance._global_tool_manager._tool_funcs
    registry.clear()
    registry["default"]["file_reader"] = _f1
    registry["control_node"]["approval"] = _f2

    snap = gsm_instance._build_snapshot()
    by_scope = snap.system_tools_by_scope

    assert by_scope.get("default") == ["file_reader"]
    assert by_scope.get("control_node") == ["approval"]
    # Flattened across scopes, both tools must be present.
    assert set(snap.tool_funcs) == {"file_reader", "approval"}
|
||
|
||
|
||
def test_publish_snapshot_increments_version(gsm_instance):
    """Each ``_publish_snapshot`` call bumps the version by one and stores a new ref."""
    assert gsm_instance._config_version == 0
    assert gsm_instance._current_ref is None

    gsm_instance._publish_snapshot()
    first_ref = gsm_instance._current_ref
    assert gsm_instance._config_version == 1
    assert first_ref is not None

    gsm_instance._publish_snapshot()
    assert gsm_instance._config_version == 2
    # A new put must yield a different ref. Compare by value: an identity
    # check (`is not`) on the fixture's string sentinels passes for any freshly
    # built string, even one equal to the previous ref.
    assert gsm_instance._current_ref != first_ref
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_current_config_ref_lazy_publishes_when_empty(gsm_instance):
    """With nothing published yet, ``current_config_ref`` publishes once.

    It must return version 1 and a real ref rather than ``None``.
    """
    published_version, published_ref = await gsm_instance.current_config_ref()

    assert published_version == 1
    assert published_ref is not None
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_current_version_is_lightweight(gsm_instance):
    """``current_version`` returns just the version number, bumped once per publish."""
    for _ in range(2):
        gsm_instance._publish_snapshot()

    version = await gsm_instance.current_version()
    assert version == 2
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_individual_publishes_new_snapshot(gsm_instance):
    """The ``add_individual`` write path bumps the config version by exactly one."""
    version_before = gsm_instance._config_version

    gsm_instance.add_individual("agent-x", {"model_id": "gpt-4o"})

    assert gsm_instance._config_version - version_before == 1
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_add_provider_wrap_publishes_new_snapshot(gsm_instance, monkeypatch):
    """``add_provider_wrap`` publishes exactly one new snapshot, even with a mocked adapter.

    The provider adapter's ``create_provider`` and the DB write
    ``add_provider_db.remote`` are replaced with async mocks.
    """
    from kilostar.core.global_state_machine.model_provider.base_provider import (
        Provider,
    )

    fake_provider = Provider(
        provider_title="my-openai",
        provider_url="http://x",
        provider_apikey="k",
        provider_models=[],
        provider_type="openai",
    )

    # Install the mocked adapter via monkeypatch so the mapping entry is
    # restored after the test. A direct item assignment would leak the mock
    # into later tests if provider_mapper turns out to be shared (class- or
    # module-level) rather than per-instance — NOTE(review): not visible here.
    fake_adapter = MagicMock()
    fake_adapter.create_provider = AsyncMock(return_value=fake_provider)
    monkeypatch.setitem(
        gsm_instance._global_provider_manager.provider_mapper,
        "openai",
        fake_adapter,
    )
    # postgres_database is the fixture's per-instance MagicMock; no restore needed.
    gsm_instance.postgres_database.add_provider_db = MagicMock()
    gsm_instance.postgres_database.add_provider_db.remote = AsyncMock()

    before = gsm_instance._config_version
    await gsm_instance.add_provider_wrap(
        provider_type="openai",
        provider_title="my-openai",
        provider_url="http://x",
        provider_apikey="k",
        provider_owner="alice",
    )
    after = gsm_instance._config_version

    assert after == before + 1
|
||
|
||
|
||
# ─── fetch_snapshot 客户端缓存 ────────────────────────────────────
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_fetch_snapshot_uses_local_cache_when_version_matches(monkeypatch):
    """When the cached version equals the GSM version, the cached snapshot is returned.

    Only ``current_version`` may be consulted; ``current_config_ref`` and
    ``ray.get`` are instrumented so that any call to them fails the test.
    """
    from kilostar.core.global_state_machine import gsm_snapshot as snap_mod

    reset_local_cache()
    snap = GSMSnapshot(version=5, providers={"p": MagicMock()})

    # Fake GSM handle: current_version reports v5, matching the cache below.
    fake_gsm = MagicMock()
    fake_gsm.current_version = MagicMock()
    fake_gsm.current_version.remote = AsyncMock(return_value=5)
    # current_config_ref must not be hit on a cache hit. The side effect makes
    # an accidental call loud; assert_not_awaited below still catches it even
    # if fetch_snapshot were to swallow the exception.
    fake_gsm.current_config_ref = MagicMock()
    fake_gsm.current_config_ref.remote = AsyncMock(
        side_effect=AssertionError("不应触发:缓存版本一致时不应调 current_config_ref")
    )

    # Record any ray.get call; a cache hit must not pull from the object store.
    ray_get_calls = []
    monkeypatch.setattr(snap_mod.ray, "get", lambda ref: ray_get_calls.append(ref))

    # Pre-warm the cache to v5, as if an earlier fetch had already happened.
    snap_mod._local_cache["version"] = 5
    snap_mod._local_cache["snapshot"] = snap
    try:
        result = await fetch_snapshot(gsm_actor=fake_gsm)
    finally:
        # Do not leak the pre-warmed module-level cache into other tests.
        reset_local_cache()

    assert result is snap
    fake_gsm.current_version.remote.assert_awaited_once()
    fake_gsm.current_config_ref.remote.assert_not_awaited()
    assert ray_get_calls == []
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_fetch_snapshot_refetches_when_version_changes(monkeypatch):
    """A stale cached version triggers ``current_config_ref`` + ``ray.get`` for the new snapshot.

    The cache is pre-warmed at v9 while the GSM reports v10, so this covers an
    actual version change rather than only a cold cache.
    """
    import kilostar.core.global_state_machine.gsm_snapshot as snap_mod

    reset_local_cache()
    stale_snap = GSMSnapshot(version=9)
    new_snap = GSMSnapshot(version=10)

    fake_gsm = MagicMock()
    fake_gsm.current_version = MagicMock()
    fake_gsm.current_version.remote = AsyncMock(return_value=10)
    fake_gsm.current_config_ref = MagicMock()
    fake_gsm.current_config_ref.remote = AsyncMock(return_value=(10, "fake-ref"))

    # Stub ray.get: hand back the prepared snapshot and record the ref it got.
    fetched_refs = []

    def _fake_get(ref):
        fetched_refs.append(ref)
        return new_snap

    monkeypatch.setattr(snap_mod.ray, "get", _fake_get)

    # Pre-warm the cache with an older version so the fetch sees a change.
    snap_mod._local_cache["version"] = 9
    snap_mod._local_cache["snapshot"] = stale_snap
    try:
        result = await fetch_snapshot(gsm_actor=fake_gsm)
        # The cache must have moved on to v10.
        assert snap_mod._local_cache["version"] == 10
        assert snap_mod._local_cache["snapshot"] is new_snap
    finally:
        # Do not leak module-level cache state into other tests.
        reset_local_cache()

    assert result is new_snap
    fake_gsm.current_config_ref.remote.assert_awaited_once()
    assert fetched_refs == ["fake-ref"]
|
||
|
||
|
||
@pytest.mark.asyncio
async def test_fetch_snapshot_use_cache_false_skips_cache(monkeypatch):
    """``use_cache=False`` goes straight to ``current_config_ref`` and ignores the local cache.

    The cache is pre-warmed with a *different* object at the same version, so
    returning that object would mean the cache was consulted.
    """
    import kilostar.core.global_state_machine.gsm_snapshot as snap_mod

    reset_local_cache()
    cached = GSMSnapshot(version=1)
    fresh = GSMSnapshot(version=1)

    # current_version is deliberately left un-mocked: its .remote is a plain
    # MagicMock, so awaiting it would raise and fail the test.
    fake_gsm = MagicMock()
    fake_gsm.current_config_ref = MagicMock()
    fake_gsm.current_config_ref.remote = AsyncMock(return_value=(1, "ref"))

    monkeypatch.setattr(snap_mod.ray, "get", lambda ref: fresh)

    snap_mod._local_cache["version"] = 1
    snap_mod._local_cache["snapshot"] = cached
    try:
        result = await fetch_snapshot(gsm_actor=fake_gsm, use_cache=False)
    finally:
        # Do not leak module-level cache state into other tests.
        reset_local_cache()

    assert result is fresh
    fake_gsm.current_config_ref.remote.assert_awaited_once()
|
||
|
||
|
||
# ─── build_tools_for_scope 客户端 helper ────────────────────────
|
||
|
||
|
||
def test_build_tools_for_scope_assembles_system_and_custom():
    """The client flattens the snapshot's toolsets into a plain list of callables.

    Each tool name listed in ``custom_toolsets`` is resolved through
    ``all_funcs``; the expected list follows toolset order, then tool order
    within each toolset.
    """
    from kilostar.core.global_state_machine.gsm_snapshot import (
        build_tools_for_scope,
    )

    def _sys_default():
        return "d"

    def _sys_scope():
        return "s"

    def _tp_a():
        return "a"

    funcs = {"sys_default": _sys_default, "sys_scope": _sys_scope, "tp_a": _tp_a}
    toolsets = {
        "system_basic": {
            "toolset_id": "system_basic",
            "tools": ["sys_default", "sys_scope"],
        },
        "grp": {"toolset_id": "grp", "tools": ["tp_a"]},
    }
    snap = GSMSnapshot(all_funcs=funcs, custom_toolsets=toolsets)

    tools = build_tools_for_scope(snap, "control_node")

    expected = [_sys_default, _sys_scope, _tp_a]
    assert len(tools) == len(expected)
    assert tools == expected
|
||
|
||
|
||
def test_build_tools_for_scope_skips_empty_buckets():
    """A snapshot with no functions and no toolsets yields no tools for the scope."""
    from kilostar.core.global_state_machine.gsm_snapshot import (
        build_tools_for_scope,
    )

    empty = GSMSnapshot(all_funcs={}, custom_toolsets={})
    tools = build_tools_for_scope(empty, "control_node")

    assert tools == []
|