From 6f6879dfab3eee5fad992eede776bf7cb3a1cef2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=9D=E5=A4=95?= Date: Fri, 15 May 2026 08:06:10 +0800 Subject: [PATCH] =?UTF-8?q?feat(provider):=E5=A2=9E=E5=8A=A0google,anthrop?= =?UTF-8?q?ic=E4=BE=9B=E5=BA=94=E5=95=86?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 1.增加更多的模型供应商 --- .env | 6 ++ .../adapter/model_adapter/agent_factory.py | 85 +++++++++++------ .../model_provider/__init__.py | 4 + .../model_provider/gemini_provider.py | 91 +++++++++++++++++++ .../global_state_machine/provider_manager.py | 2 + pyproject.toml | 7 ++ uv.lock | 2 +- 7 files changed, 166 insertions(+), 31 deletions(-) create mode 100644 .env create mode 100644 kilostar/core/global_state_machine/model_provider/gemini_provider.py diff --git a/.env b/.env new file mode 100644 index 0000000..6187fae --- /dev/null +++ b/.env @@ -0,0 +1,6 @@ +POSTGRES_USER=postgres +POSTGRES_PASSWORD=postgrespassword +POSTGRES_HOST=127.0.0.1 +POSTGRES_PORT=5432 +POSTGRES_DB=kilostar +SECRET_KEY=mysecretkey123456789 \ No newline at end of file diff --git a/kilostar/adapter/model_adapter/agent_factory.py b/kilostar/adapter/model_adapter/agent_factory.py index 1d2d3e6..fcb9cfe 100644 --- a/kilostar/adapter/model_adapter/agent_factory.py +++ b/kilostar/adapter/model_adapter/agent_factory.py @@ -15,9 +15,11 @@ from pydantic_ai import Agent from pydantic_ai.models.openai import OpenAIChatModel from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.models.gemini import GeminiModel from pydantic_ai.providers.openai import OpenAIProvider from pydantic_ai.providers.anthropic import AnthropicProvider -from kilostar.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent +from pydantic_ai.providers.deepseek import DeepSeekProvider +from pydantic_ai.providers.google import GoogleProvider from kilostar.core.global_state_machine.model_provider import Provider from kilostar.utils.agent_model import ResponseModel, DepsModel from kilostar.utils.error import ModelNotExistError @@ -29,9 +31,26 @@ class AgentFactory: def __init__(self): self._models_mapping = { - "openai": (OpenAIChatModel, OpenAIProvider), - "claude": (AnthropicModel, AnthropicProvider), - "deepseek": (OpenAIChatModel, OpenAIProvider), + "openai": { + "model_class": OpenAIChatModel, + "provider_class": OpenAIProvider, + "provider_kwargs": {"base_url": True, "api_key": True}, + }, + "claude": { + "model_class": AnthropicModel, + "provider_class": AnthropicProvider, + "provider_kwargs": {"api_key": True}, + }, + "deepseek": { + "model_class": OpenAIChatModel, + "provider_class": DeepSeekProvider, + "provider_kwargs": {"api_key": True}, + }, + "gemini": { + "model_class": GeminiModel, + "provider_class": GoogleProvider, + "provider_kwargs": {"api_key": True}, + }, } def create_agent( @@ -63,31 +82,37 @@ class AgentFactory: raise ModelNotExistError("模型不存在") if provider.provider_type not in self._models_mapping: raise ValueError(f"不支持的协议类型: {provider.provider_type}") - model_class, provider_class = self._models_mapping[provider.provider_type] - model = model_class( - model_id, - provider=provider_class( - api_key=provider.provider_apikey, base_url=provider.provider_url - ), + + config = self._models_mapping[provider.provider_type] + model_class = config["model_class"] + provider_class = config["provider_class"] + provider_kwargs = config["provider_kwargs"] + + # 构建 provider 实例化参数 + init_kwargs = {} + if provider_kwargs.get("api_key"): + init_kwargs["api_key"] = provider.provider_apikey + if provider_kwargs.get("base_url"): + init_kwargs["base_url"] = provider.provider_url + + model_provider = provider_class(**init_kwargs) + + # 对于 Gemini,provider 需要传递给 model + if provider.provider_type == "gemini": + model = model_class( + model_name=model_id, + provider=model_provider, + ) + else: + model = model_class(model_id, provider=model_provider) + + # 创建 Agent + agent = Agent( + model=model, + name=agent_name, + system_prompt=system_prompt, + output_type=output_type, + deps_type=deps_type, + tools=tools, ) - match provider.provider_type: - case "deepseek": - agent = DeepSeekReasonerAgent( - model=model, - name=agent_name, - output_type=output_type, - deps_type=deps_type, - system_prompt=system_prompt, - tools=tools, - retries=3, - ) - case _: - agent = Agent( - model=model, - name=agent_name, - system_prompt=system_prompt, - output_type=output_type, - deps_type=deps_type, - tools=tools, - ) return agent diff --git a/kilostar/core/global_state_machine/model_provider/__init__.py b/kilostar/core/global_state_machine/model_provider/__init__.py index 3b5109c..9602cb6 100644 --- a/kilostar/core/global_state_machine/model_provider/__init__.py +++ b/kilostar/core/global_state_machine/model_provider/__init__.py @@ -25,6 +25,9 @@ from kilostar.core.global_state_machine.model_provider.claude_provider import ( from kilostar.core.global_state_machine.model_provider.deepseek_provider import ( DeepseekProvider, ) +from kilostar.core.global_state_machine.model_provider.gemini_provider import ( + GeminiProvider, +) __all__ = [ "Provider", @@ -32,4 +35,5 @@ __all__ = [ "OpenAIProvider", "ClaudeProvider", "DeepseekProvider", + "GeminiProvider", ] diff --git a/kilostar/core/global_state_machine/model_provider/gemini_provider.py b/kilostar/core/global_state_machine/model_provider/gemini_provider.py new file mode 100644 index 0000000..48c4620 --- /dev/null +++ b/kilostar/core/global_state_machine/model_provider/gemini_provider.py @@ -0,0 +1,91 @@ +# Copyright 2026 zhaoxi826 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from kilostar.utils.retry import retry_on_retryable_error +from kilostar.core.global_state_machine.model_provider.base_provider import ( + BaseProvider, + Provider, + ProviderArgs, +) +import httpx +from typing import List + + +class GeminiProvider(BaseProvider): + """GeminiProvider 核心组件类。 + 这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 Google Gemini)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。""" + + @staticmethod + async def create_provider(provider_args: ProviderArgs) -> Provider: + """创建并持久化新的 provider 实体。 + 接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。 + Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + provider_models: List[str] = await GeminiProvider._load_models(provider_args) + provider: Provider = GeminiProvider._return_provider( + provider_args, provider_models + ) + return provider + + @staticmethod + @retry_on_retryable_error() + async def _load_models(provider_args: ProviderArgs) -> List[str]: + """执行与 load models 相关的核心业务流转操作。 + 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 + Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 + Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。""" + headers = { + "Authorization": f"Bearer {provider_args.provider_apikey}", + "Content-Type": "application/json", + } + url = f"{provider_args.provider_url}/v1beta/models" + + try: + async with httpx.AsyncClient(timeout=10.0) as client: + response = await client.get(url, headers=headers) + if response.status_code != 200: + print( + f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}" + ) + return [] + data = response.json() + raw_models = data.get("models", []) + model_ids = [m["name"].replace("models/", "") for m in raw_models] + return sorted(model_ids) + except httpx.RequestError as e: + from kilostar.utils.error import RetryableError + + print(f"[{provider_args.provider_title}] 网络请求异常: {e}") + raise RetryableError( + f"[{provider_args.provider_title}] 网络请求异常: {e}" + ) from e + except Exception as e: + print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}") + return [] + + @staticmethod + def _return_provider( + provider_args: ProviderArgs, provider_models: List[str] + ) -> Provider: + """执行与 return provider 相关的核心业务流转操作。 + 该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。 + Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。 + Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。""" + return Provider( + provider_title=provider_args.provider_title, + provider_apikey=provider_args.provider_apikey, + provider_url=provider_args.provider_url, + provider_models=provider_models, + provider_type="gemini", + ) diff --git a/kilostar/core/global_state_machine/provider_manager.py b/kilostar/core/global_state_machine/provider_manager.py index ce9c509..24e1d35 100644 --- a/kilostar/core/global_state_machine/provider_manager.py +++ b/kilostar/core/global_state_machine/provider_manager.py @@ -17,6 +17,7 @@ from kilostar.core.global_state_machine.model_provider import ( OpenAIProvider, ClaudeProvider, DeepseekProvider, + GeminiProvider, ) from typing import Dict, Type @@ -39,6 +40,7 @@ class ProviderManager: "openai": OpenAIProvider, "claude": ClaudeProvider, "deepseek": DeepseekProvider, + "gemini": GeminiProvider, } self.provider_register = {} diff --git a/pyproject.toml b/pyproject.toml index 4647460..2ef5de9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,3 +1,10 @@ +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +[tool.hatch.build.targets.wheel] +packages = ["kilostar"] + [project] name = "kilostar" version = "0.1.0" diff --git a/uv.lock b/uv.lock index b3b4210..041ae88 100644 --- a/uv.lock +++ b/uv.lock @@ -2161,7 +2161,7 @@ wheels = [ [[package]] name = "kilostar" version = "0.1.0" -source = { virtual = "." } +source = { editable = "." } dependencies = [ { name = "asyncpg" }, { name = "docker-py" },