"""``AgentFactory.create_agent`` 的协议分发与异常分支。 为避免真的实例化 OpenAI / Anthropic / Google 等 provider,所有 ``provider_class`` 与 ``model_class`` 都被替换为 spy 类,仅记录构造参数。 """ from __future__ import annotations from typing import Any, Dict from unittest.mock import MagicMock import pytest from kilostar.adapter.model_adapter import agent_factory as af_mod from kilostar.adapter.model_adapter.agent_factory import AgentFactory from kilostar.core.global_state_machine.model_provider.base_provider import Provider from kilostar.utils.agent_model import DepsModel, ResponseModel from kilostar.utils.error import ModelNotExistError class _SpyProvider: last_init: Dict[str, Any] = {} def __init__(self, **kwargs): type(self).last_init = kwargs class _SpyModel: last_init: Dict[str, Any] = {} def __init__(self, *args, **kwargs): type(self).last_init = {"args": args, "kwargs": kwargs} @pytest.fixture def factory(monkeypatch): """构造 AgentFactory 并将 provider/model class 全部替换成 spy 类。""" af = AgentFactory() for proto in af._models_mapping: af._models_mapping[proto]["model_class"] = type( f"_Model_{proto}", (_SpyModel,), {} ) af._models_mapping[proto]["provider_class"] = type( f"_Provider_{proto}", (_SpyProvider,), {} ) fake_agent = MagicMock(name="Agent") fake_agent_class = MagicMock(name="Agent_class", return_value=fake_agent) monkeypatch.setattr(af_mod, "Agent", fake_agent_class) return af, fake_agent_class, fake_agent def _provider(provider_type: str = "openai") -> Provider: return Provider( provider_title="t", provider_url="https://example.com", provider_apikey="sk-123", provider_models=["m1", "m2"], provider_type=provider_type, ) def test_create_agent_returns_agent_instance(factory): af, agent_cls, fake_agent = factory result = af.create_agent( provider=_provider("openai"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) assert result is fake_agent agent_cls.assert_called_once() kwargs = agent_cls.call_args.kwargs assert kwargs["name"] == "hello" assert kwargs["system_prompt"] == "sys" assert kwargs["tools"] == [] assert kwargs["toolsets"] == [] def test_create_agent_passes_through_tools_and_toolsets(factory): af, agent_cls, _ = factory tools = [lambda: 1] toolsets = [object()] af.create_agent( provider=_provider("openai"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", tools=tools, toolsets=toolsets, ) kwargs = agent_cls.call_args.kwargs assert kwargs["tools"] is tools assert kwargs["toolsets"] is toolsets def test_create_agent_unknown_model_raises(factory): af, _, _ = factory with pytest.raises(ModelNotExistError): af.create_agent( provider=_provider("openai"), model_id="not-in-list", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) def test_create_agent_unknown_protocol_raises(factory): af, _, _ = factory bad = _provider("openai") bad.provider_type = "weird-protocol" with pytest.raises(ValueError): af.create_agent( provider=bad, model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) def test_openai_protocol_passes_api_key_and_base_url(factory): af, _, _ = factory af.create_agent( provider=_provider("openai"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) init = af._models_mapping["openai"]["provider_class"].last_init assert init["api_key"] == "sk-123" assert init["base_url"] == "https://example.com" def test_claude_protocol_passes_api_key_only(factory): af, _, _ = factory af.create_agent( provider=_provider("claude"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) init = af._models_mapping["claude"]["provider_class"].last_init assert init["api_key"] == "sk-123" assert "base_url" not in init def test_gemini_protocol_uses_kwarg_model_name(factory): af, _, _ = factory af.create_agent( provider=_provider("gemini"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) init = af._models_mapping["gemini"]["model_class"].last_init assert init["kwargs"].get("model_name") == "m1" assert "provider" in init["kwargs"] def test_non_gemini_uses_positional_model_id(factory): af, _, _ = factory af.create_agent( provider=_provider("deepseek"), model_id="m1", output_type=ResponseModel, system_prompt="sys", deps_type=DepsModel, agent_name="hello", ) init = af._models_mapping["deepseek"]["model_class"].last_init assert init["args"] == ("m1",)