feat(provider):增加google,anthropic供应商
1.增加更多的模型供应商
This commit is contained in:
@@ -0,0 +1,6 @@
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgrespassword
|
||||
POSTGRES_HOST=127.0.0.1
|
||||
POSTGRES_PORT=5432
|
||||
POSTGRES_DB=kilostar
|
||||
SECRET_KEY=mysecretkey123456789
|
||||
@@ -15,9 +15,11 @@
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.models.anthropic import AnthropicModel
|
||||
from pydantic_ai.models.gemini import GeminiModel
|
||||
from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from kilostar.adapter.model_adapter.deepseek_reasoner import DeepSeekReasonerAgent
|
||||
from pydantic_ai.providers.deepseek import DeepSeekProvider
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from kilostar.core.global_state_machine.model_provider import Provider
|
||||
from kilostar.utils.agent_model import ResponseModel, DepsModel
|
||||
from kilostar.utils.error import ModelNotExistError
|
||||
@@ -29,9 +31,26 @@ class AgentFactory:
|
||||
|
||||
def __init__(self):
|
||||
self._models_mapping = {
|
||||
"openai": (OpenAIChatModel, OpenAIProvider),
|
||||
"claude": (AnthropicModel, AnthropicProvider),
|
||||
"deepseek": (OpenAIChatModel, OpenAIProvider),
|
||||
"openai": {
|
||||
"model_class": OpenAIChatModel,
|
||||
"provider_class": OpenAIProvider,
|
||||
"provider_kwargs": {"base_url": True, "api_key": True},
|
||||
},
|
||||
"claude": {
|
||||
"model_class": AnthropicModel,
|
||||
"provider_class": AnthropicProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
"deepseek": {
|
||||
"model_class": OpenAIChatModel,
|
||||
"provider_class": DeepSeekProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
"gemini": {
|
||||
"model_class": GeminiModel,
|
||||
"provider_class": GoogleProvider,
|
||||
"provider_kwargs": {"api_key": True},
|
||||
},
|
||||
}
|
||||
|
||||
def create_agent(
|
||||
@@ -63,31 +82,37 @@ class AgentFactory:
|
||||
raise ModelNotExistError("模型不存在")
|
||||
if provider.provider_type not in self._models_mapping:
|
||||
raise ValueError(f"不支持的协议类型: {provider.provider_type}")
|
||||
model_class, provider_class = self._models_mapping[provider.provider_type]
|
||||
model = model_class(
|
||||
model_id,
|
||||
provider=provider_class(
|
||||
api_key=provider.provider_apikey, base_url=provider.provider_url
|
||||
),
|
||||
|
||||
config = self._models_mapping[provider.provider_type]
|
||||
model_class = config["model_class"]
|
||||
provider_class = config["provider_class"]
|
||||
provider_kwargs = config["provider_kwargs"]
|
||||
|
||||
# 构建 provider 实例化参数
|
||||
init_kwargs = {}
|
||||
if provider_kwargs.get("api_key"):
|
||||
init_kwargs["api_key"] = provider.provider_apikey
|
||||
if provider_kwargs.get("base_url"):
|
||||
init_kwargs["base_url"] = provider.provider_url
|
||||
|
||||
model_provider = provider_class(**init_kwargs)
|
||||
|
||||
# 对于 Gemini,provider 需要传递给 model
|
||||
if provider.provider_type == "gemini":
|
||||
model = model_class(
|
||||
model_name=model_id,
|
||||
provider=model_provider,
|
||||
)
|
||||
else:
|
||||
model = model_class(model_id, provider=model_provider)
|
||||
|
||||
# 创建 Agent
|
||||
agent = Agent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools,
|
||||
)
|
||||
match provider.provider_type:
|
||||
case "deepseek":
|
||||
agent = DeepSeekReasonerAgent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
system_prompt=system_prompt,
|
||||
tools=tools,
|
||||
retries=3,
|
||||
)
|
||||
case _:
|
||||
agent = Agent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools,
|
||||
)
|
||||
return agent
|
||||
|
||||
@@ -25,6 +25,9 @@ from kilostar.core.global_state_machine.model_provider.claude_provider import (
|
||||
from kilostar.core.global_state_machine.model_provider.deepseek_provider import (
|
||||
DeepseekProvider,
|
||||
)
|
||||
from kilostar.core.global_state_machine.model_provider.gemini_provider import (
|
||||
GeminiProvider,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Provider",
|
||||
@@ -32,4 +35,5 @@ __all__ = [
|
||||
"OpenAIProvider",
|
||||
"ClaudeProvider",
|
||||
"DeepseekProvider",
|
||||
"GeminiProvider",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,91 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from kilostar.utils.retry import retry_on_retryable_error
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import (
|
||||
BaseProvider,
|
||||
Provider,
|
||||
ProviderArgs,
|
||||
)
|
||||
import httpx
|
||||
from typing import List
|
||||
|
||||
|
||||
class GeminiProvider(BaseProvider):
|
||||
"""GeminiProvider 核心组件类。
|
||||
这是一个模型/服务提供商适配器类,屏蔽了外部不同供应商(如 Google Gemini)的底层 API 差异。它负责标准化参数组装、网络请求发送、鉴权处理以及响应结构的反序列化。"""
|
||||
|
||||
@staticmethod
|
||||
async def create_provider(provider_args: ProviderArgs) -> Provider:
|
||||
"""创建并持久化新的 provider 实体。
|
||||
接收构建参数,执行必要的数据校验与默认值填充后,将新记录安全地写入底层存储或系统注册表中。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
provider_models: List[str] = await GeminiProvider._load_models(provider_args)
|
||||
provider: Provider = GeminiProvider._return_provider(
|
||||
provider_args, provider_models
|
||||
)
|
||||
return provider
|
||||
|
||||
@staticmethod
|
||||
@retry_on_retryable_error()
|
||||
async def _load_models(provider_args: ProviderArgs) -> List[str]:
|
||||
"""执行与 load models 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。
|
||||
Returns: (List[str]): 经过筛选、排序或分页处理后的实体对象列表集合。"""
|
||||
headers = {
|
||||
"Authorization": f"Bearer {provider_args.provider_apikey}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
url = f"{provider_args.provider_url}/v1beta/models"
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||
response = await client.get(url, headers=headers)
|
||||
if response.status_code != 200:
|
||||
print(
|
||||
f"[{provider_args.provider_title}] 获取模型失败: {response.status_code}"
|
||||
)
|
||||
return []
|
||||
data = response.json()
|
||||
raw_models = data.get("models", [])
|
||||
model_ids = [m["name"].replace("models/", "") for m in raw_models]
|
||||
return sorted(model_ids)
|
||||
except httpx.RequestError as e:
|
||||
from kilostar.utils.error import RetryableError
|
||||
|
||||
print(f"[{provider_args.provider_title}] 网络请求异常: {e}")
|
||||
raise RetryableError(
|
||||
f"[{provider_args.provider_title}] 网络请求异常: {e}"
|
||||
) from e
|
||||
except Exception as e:
|
||||
print(f"[{provider_args.provider_title}] 解析模型列表时发生错误: {e}")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _return_provider(
|
||||
provider_args: ProviderArgs, provider_models: List[str]
|
||||
) -> Provider:
|
||||
"""执行与 return provider 相关的核心业务流转操作。
|
||||
该方法封装了具体的算法策略或状态控制逻辑,确保操作能够在事务上下文中被原子且一致地执行。
|
||||
Args: provider_args (ProviderArgs): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_args 实例。 provider_models (List[str]): 目标对象的唯一全局标识符 (UUID/ULID),用于在数据库表或缓存结构中精准匹配该 provider_models 实例。
|
||||
Returns: (Provider): 经由当前业务模型加工处理后所输出的具体数据实例或领域模型对象。"""
|
||||
return Provider(
|
||||
provider_title=provider_args.provider_title,
|
||||
provider_apikey=provider_args.provider_apikey,
|
||||
provider_url=provider_args.provider_url,
|
||||
provider_models=provider_models,
|
||||
provider_type="gemini",
|
||||
)
|
||||
@@ -17,6 +17,7 @@ from kilostar.core.global_state_machine.model_provider import (
|
||||
OpenAIProvider,
|
||||
ClaudeProvider,
|
||||
DeepseekProvider,
|
||||
GeminiProvider,
|
||||
)
|
||||
from typing import Dict, Type
|
||||
|
||||
@@ -39,6 +40,7 @@ class ProviderManager:
|
||||
"openai": OpenAIProvider,
|
||||
"claude": ClaudeProvider,
|
||||
"deepseek": DeepseekProvider,
|
||||
"gemini": GeminiProvider,
|
||||
}
|
||||
self.provider_register = {}
|
||||
|
||||
|
||||
@@ -1,3 +1,10 @@
|
||||
[build-system]
|
||||
requires = ["hatchling"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.build.targets.wheel]
|
||||
packages = ["kilostar"]
|
||||
|
||||
[project]
|
||||
name = "kilostar"
|
||||
version = "0.1.0"
|
||||
|
||||
Reference in New Issue
Block a user