# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Any, Dict, List, Literal, Optional
from urllib.parse import urlsplit

from fastapi import APIRouter, Depends
from pydantic import BaseModel, Field

from kilostar.utils.access import TokenData, Accessor, RoleChecker
from kilostar.core.postgres_database.model import UserAuthority
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.ray_hook import ray_actor_hook

provider_router = APIRouter(prefix="/api/v1/provider", tags=["provider"])


class ProviderRegister(BaseModel):
    """Request body for ``POST /provider``: the minimal fields needed to register a model Provider."""

    provider_type: Literal["openai", "claude", "deepseek", "gemini"]
    provider_title: str
    provider_url: str
    provider_apikey: str
    model_settings: Optional[Dict[str, Dict[str, Any]]] = Field(default=None)


@provider_router.post("")
async def create_provider(
    provider_register: ProviderRegister,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
) -> None:
    """Register a Provider whose owner is the ``user_id`` of the current user."""
    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    await global_state_machine.add_provider_wrap.remote(
        provider_type=provider_register.provider_type,
        provider_title=provider_register.provider_title,
        provider_url=provider_register.provider_url,
        provider_apikey=provider_register.provider_apikey,
        provider_owner=token_data.user_id,
        # The state machine receives a dict even when the client omitted settings.
        model_settings=provider_register.model_settings or {},
    )


def _mask_apikey(key: str) -> str:
    """Return a redacted form of *key* that is safe to send to the frontend.

    Keys of 8 characters or fewer (and empty/falsy keys) are fully hidden as
    ``"***"``. Longer keys keep only their first and last 4 characters.
    """
    if not key or len(key) <= 8:
        return "***"
    return key[:4] + "***" + key[-4:]


@provider_router.get("/list")
async def get_provider_list(
    _: TokenData = Depends(Accessor.get_current_user),
) -> Dict[str, Any]:
    """Return every registered Provider for the frontend model list, with API keys masked."""
    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    provider_list: Dict[
        str, Provider
    ] = await global_state_machine.get_provider_list.remote()
    masked = {}
    for title, p in provider_list.items():
        # Support both pydantic v2 models (model_dump) and mapping-like objects.
        d = p.model_dump() if hasattr(p, "model_dump") else dict(p)
        d["provider_apikey"] = _mask_apikey(d.get("provider_apikey", ""))
        masked[title] = d
    return {"provider_list": masked}


def _models_endpoint(base_url: str) -> str:
    """Build the ``.../models`` URL for an OpenAI- or Anthropic-style base URL.

    ``/v1`` is inserted unless the URL *path* already contains ``/v1``. Both
    ``https://api.openai.com`` and ``https://proxy.example/openai/v1`` therefore
    resolve correctly. Only the path is inspected, so a host such as
    ``v1.example.com`` is not mistaken for an already-versioned base URL.
    Trailing slashes are stripped so the result never contains ``//models``.
    """
    base = base_url.rstrip("/")
    if "/v1" in urlsplit(base).path:
        return f"{base}/models"
    return f"{base}/v1/models"


def _extract_model_ids(
    payload: Any, list_key: str, id_key: str, strip_prefix: str = ""
) -> List[str]:
    """Pull model identifiers out of a provider's "list models" JSON body.

    Args:
        payload: the decoded JSON response body.
        list_key: the key holding the list of model entries (``"data"`` or ``"models"``).
        id_key: the key inside each entry holding the model identifier.
        strip_prefix: a prefix removed from each identifier (e.g. ``"models/"``).

    Returns:
        The sorted identifiers. Entries that are not objects, or whose
        identifier is missing or empty, are skipped rather than failing the
        whole listing.

    Raises:
        ValueError: if *payload* is not a JSON object.
    """
    if not isinstance(payload, dict):
        raise ValueError("unexpected response body: expected a JSON object")
    model_ids = []
    for entry in payload.get(list_key) or []:
        raw_id = entry.get(id_key) if isinstance(entry, dict) else None
        if not raw_id:
            continue
        model_id = str(raw_id).removeprefix(strip_prefix)
        if model_id:
            model_ids.append(model_id)
    return sorted(model_ids)


@provider_router.post("/test")
async def test_provider_connection(
    provider_register: ProviderRegister,
    _: TokenData = Depends(Accessor.get_current_user),
) -> Dict[str, Any]:
    """Test a Provider's connectivity by listing its models using the protocol matching ``provider_type``.

    Returns ``{"success": True, "models": [...], "model_count": n}`` on HTTP
    200. Otherwise it returns ``{"success": False, "error": ..., "models": []}``.
    """
    import httpx

    # NOTE(review): the server issues a request to a caller-supplied URL, so
    # any authenticated user can probe hosts reachable from this server
    # (SSRF). Consider an allow-list or blocking private address ranges if
    # that is not intended.
    ptype = provider_register.provider_type
    url = provider_register.provider_url.rstrip("/")
    apikey = provider_register.provider_apikey

    headers: Optional[Dict[str, str]] = None
    params: Optional[Dict[str, str]] = None
    if ptype == "claude":
        endpoint = _models_endpoint(url)
        headers = {
            "x-api-key": apikey,
            "anthropic-version": "2023-06-01",
        }
        list_key, id_key, id_prefix = "data", "id", ""
    elif ptype == "gemini":
        endpoint = f"{url}/models"
        params = {"key": apikey}
        # Gemini returns names like "models/gemini-pro"; strip the namespace.
        list_key, id_key, id_prefix = "models", "name", "models/"
    else:
        # openai / deepseek and other OpenAI-compatible endpoints.
        endpoint = _models_endpoint(url)
        headers = {
            "Authorization": f"Bearer {apikey}",
            "Content-Type": "application/json",
        }
        list_key, id_key, id_prefix = "data", "id", ""

    try:
        async with httpx.AsyncClient(timeout=10.0) as client:
            response = await client.get(endpoint, headers=headers, params=params)
            if response.status_code != 200:
                return {"success": False, "error": f"HTTP {response.status_code}", "models": []}
            models = _extract_model_ids(response.json(), list_key, id_key, id_prefix)
    except Exception as e:
        # Deliberate best-effort boundary: any network/JSON/format failure is
        # reported to the caller instead of surfacing as an HTTP 500.
        return {"success": False, "error": str(e), "models": []}
    return {"success": True, "models": models, "model_count": len(models)}


@provider_router.delete("/{provider_title}")
async def delete_provider(
    provider_title: str,
    _: TokenData = Depends(
        RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
    ),
) -> dict:
    """Delete the Provider named ``provider_title``; restricted to super administrators."""
    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    await global_state_machine.delete_provider.remote(provider_title=provider_title)
    return {"message": "success"}