99520c69d7
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
186 lines
5.4 KiB
Python
186 lines
5.4 KiB
Python
"""``AgentFactory.create_agent`` 的协议分发与异常分支。

为避免真的实例化 OpenAI / Anthropic / Google 等 provider,所有 ``provider_class``
与 ``model_class`` 都被替换为 spy 类,仅记录构造参数。
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from typing import Any, Dict
|
|
from unittest.mock import MagicMock
|
|
|
|
import pytest
|
|
|
|
from kilostar.adapter.model_adapter import agent_factory as af_mod
|
|
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
|
|
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
|
from kilostar.utils.agent_model import DepsModel, ResponseModel
|
|
from kilostar.utils.error import ModelNotExistError
|
|
|
|
|
|
class _SpyProvider:
|
|
last_init: Dict[str, Any] = {}
|
|
|
|
def __init__(self, **kwargs):
|
|
type(self).last_init = kwargs
|
|
|
|
|
|
class _SpyModel:
|
|
last_init: Dict[str, Any] = {}
|
|
|
|
def __init__(self, *args, **kwargs):
|
|
type(self).last_init = {"args": args, "kwargs": kwargs}
|
|
|
|
|
|
@pytest.fixture
def factory(monkeypatch):
    """Build an ``AgentFactory`` whose provider/model classes are spy classes.

    Each protocol key in ``_models_mapping`` gets fresh ``_SpyModel`` and
    ``_SpyProvider`` subclasses, so each protocol records its own
    ``last_init``. ``Agent`` in the ``agent_factory`` module is replaced by a
    ``MagicMock`` whose return value is another ``MagicMock``.

    Returns:
        A tuple ``(factory, agent_class_mock, agent_instance_mock)``.
    """
    af = AgentFactory()
    for proto, entry in af._models_mapping.items():
        # monkeypatch.setitem (rather than plain assignment) restores the real
        # classes at teardown. NOTE(review): this keeps the spies from leaking
        # into other tests if ``_models_mapping`` or its per-protocol dicts are
        # shared at class/module level instead of being rebuilt per instance
        # (not visible from this file).
        model_spy = type(f"_Model_{proto}", (_SpyModel,), {})
        provider_spy = type(f"_Provider_{proto}", (_SpyProvider,), {})
        monkeypatch.setitem(entry, "model_class", model_spy)
        monkeypatch.setitem(entry, "provider_class", provider_spy)

    fake_agent = MagicMock(name="Agent")
    fake_agent_class = MagicMock(name="Agent_class", return_value=fake_agent)
    monkeypatch.setattr(af_mod, "Agent", fake_agent_class)
    return af, fake_agent_class, fake_agent
|
|
|
|
|
|
def _provider(provider_type: str = "openai") -> Provider:
    """Return a test ``Provider`` for the given protocol type.

    All other fields are fixed: title ``t``, URL ``https://example.com``,
    API key ``sk-123``, and models ``m1``/``m2``.
    """
    fields = {
        "provider_title": "t",
        "provider_url": "https://example.com",
        "provider_apikey": "sk-123",
        "provider_models": ["m1", "m2"],
        "provider_type": provider_type,
    }
    return Provider(**fields)
|
|
|
|
|
|
def test_create_agent_returns_agent_instance(factory):
    """create_agent returns the object built by ``Agent(...)``.

    When tools/toolsets are omitted, empty lists are forwarded to ``Agent``.
    """
    agent_factory, agent_cls, fake_agent = factory

    call = dict(
        provider=_provider("openai"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
    )
    result = agent_factory.create_agent(**call)

    assert result is fake_agent
    agent_cls.assert_called_once()
    passed = agent_cls.call_args.kwargs
    expected = {"name": "hello", "system_prompt": "sys", "tools": [], "toolsets": []}
    assert {key: passed[key] for key in expected} == expected
|
|
|
|
|
|
def test_create_agent_passes_through_tools_and_toolsets(factory):
    """Caller-supplied ``tools``/``toolsets`` reach ``Agent`` as the same objects."""
    agent_factory, agent_cls, _ = factory
    my_tools = [lambda: 1]
    my_toolsets = [object()]

    agent_factory.create_agent(
        provider=_provider("openai"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
        tools=my_tools,
        toolsets=my_toolsets,
    )

    forwarded = agent_cls.call_args.kwargs
    # Identity rather than equality: the lists must be passed through uncopied.
    assert forwarded["tools"] is my_tools
    assert forwarded["toolsets"] is my_toolsets
|
|
|
|
|
|
def test_create_agent_unknown_model_raises(factory):
    """Requesting a model the provider does not list raises ``ModelNotExistError``."""
    agent_factory, _agent_cls, _agent = factory
    missing_model = "not-in-list"  # _provider() lists only "m1" and "m2"

    with pytest.raises(ModelNotExistError):
        agent_factory.create_agent(
            provider=_provider("openai"),
            model_id=missing_model,
            output_type=ResponseModel,
            system_prompt="sys",
            deps_type=DepsModel,
            agent_name="hello",
        )
|
|
|
|
|
|
def test_create_agent_unknown_protocol_raises(factory):
    """An unrecognised ``provider_type`` makes ``create_agent`` raise ``ValueError``."""
    agent_factory, _agent_cls, _agent = factory

    # Build a valid provider, then overwrite its type. Presumably this is done
    # because Provider would reject the bogus value at construction time
    # (TODO confirm).
    unknown = _provider("openai")
    unknown.provider_type = "weird-protocol"

    with pytest.raises(ValueError):
        agent_factory.create_agent(
            provider=unknown,
            model_id="m1",
            output_type=ResponseModel,
            system_prompt="sys",
            deps_type=DepsModel,
            agent_name="hello",
        )
|
|
|
|
|
|
def test_openai_protocol_passes_api_key_and_base_url(factory):
    """The ``openai`` provider class is built with both the API key and the URL."""
    agent_factory, _agent_cls, _agent = factory

    agent_factory.create_agent(
        provider=_provider("openai"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
    )

    spy = agent_factory._models_mapping["openai"]["provider_class"]
    recorded = spy.last_init
    assert (recorded["api_key"], recorded["base_url"]) == (
        "sk-123",
        "https://example.com",
    )
|
|
|
|
|
|
def test_claude_protocol_passes_api_key_only(factory):
    """The ``claude`` provider class gets the API key but no ``base_url``."""
    agent_factory, _agent_cls, _agent = factory

    agent_factory.create_agent(
        provider=_provider("claude"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
    )

    spy = agent_factory._models_mapping["claude"]["provider_class"]
    recorded = spy.last_init
    assert recorded["api_key"] == "sk-123"
    assert "base_url" not in recorded
|
|
|
|
|
|
def test_gemini_protocol_uses_kwarg_model_name(factory):
    """The ``gemini`` model class receives ``model_name=`` and ``provider=`` by keyword."""
    agent_factory, _agent_cls, _agent = factory

    agent_factory.create_agent(
        provider=_provider("gemini"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
    )

    spy = agent_factory._models_mapping["gemini"]["model_class"]
    keyword_args = spy.last_init["kwargs"]
    assert keyword_args.get("model_name") == "m1"
    assert "provider" in keyword_args
|
|
|
|
|
|
def test_non_gemini_uses_positional_model_id(factory):
    """Non-gemini protocols (here ``deepseek``) get the model id as the sole positional arg."""
    agent_factory, _agent_cls, _agent = factory

    agent_factory.create_agent(
        provider=_provider("deepseek"),
        model_id="m1",
        output_type=ResponseModel,
        system_prompt="sys",
        deps_type=DepsModel,
        agent_name="hello",
    )

    spy = agent_factory._models_mapping["deepseek"]["model_class"]
    positional = spy.last_init["args"]
    assert positional == ("m1",)
|