Files
KiloStar/kilostar/api/provider.py
T

76 lines
2.9 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import Literal
from kilostar.utils.access import TokenData, Accessor
from kilostar.utils.check_user.role_check import RoleChecker
from kilostar.core.postgres_database.model import UserAuthority
from typing import Dict
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.ray_hook import ray_actor_hook
# Router for the provider-management endpoints defined below. Every route is
# mounted under the ``/api/v1/provider`` prefix and grouped under the
# "provider" tag in the generated OpenAPI docs.
provider_router = APIRouter(
    tags=["provider"],
    prefix="/api/v1/provider",
)
class ProviderRegister(BaseModel):
    """Request body for ``POST /api/v1/provider``.

    The minimal set of fields required to register a model provider. The
    owner is not part of the body; the endpoint takes it from the caller's
    token.
    """

    # Provider family; pydantic rejects any value outside these three.
    provider_type: Literal["openai", "claude", "deepseek"]
    # Human-readable title. The DELETE route identifies providers by this value.
    provider_title: str
    # URL of the provider's API, kept as a plain string (not validated here).
    provider_url: str
    # Credential for the provider's API, forwarded as-is on registration.
    provider_apikey: str
@provider_router.post("")
async def create_provider(
    provider_register: ProviderRegister,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
) -> None:
    """Register a new model provider owned by the calling user.

    Access is gated by ``RoleChecker`` with ``UserAuthority.USER``. The
    submitted fields are forwarded to the ``global_state_machine`` actor's
    ``add_provider_wrap`` method, with ``provider_owner`` set to the caller's
    ``user_id``. Nothing is returned in the response body.
    """
    state_machine = ray_actor_hook("global_state_machine").global_state_machine
    # The owner always comes from the authenticated token, never from the
    # request body, so a caller cannot register a provider for someone else.
    registration = {
        "provider_type": provider_register.provider_type,
        "provider_title": provider_register.provider_title,
        "provider_url": provider_register.provider_url,
        "provider_apikey": provider_register.provider_apikey,
        "provider_owner": token_data.user_id,
    }
    # Wait for the actor call to finish before the request completes.
    await state_machine.add_provider_wrap.remote(**registration)
@provider_router.get("/list")
async def get_provider_list(
    _: TokenData = Depends(Accessor.get_current_user),
) -> Dict[str, Dict[str, Provider]]:
    """List all registered providers so the frontend can show the model catalogue.

    The only dependency is ``Accessor.get_current_user`` (no role check), and
    the token value itself is unused. The actor's provider mapping is
    returned wrapped under a single ``"provider_list"`` key.
    """
    # NOTE(review): no owner filter is passed to the actor, so every provider
    # is returned to any authenticated caller. If ``Provider`` retains the API
    # key supplied at registration, it may end up in this response; confirm
    # against the Provider model / response schema.
    state_machine = ray_actor_hook("global_state_machine").global_state_machine
    providers: Dict[str, Provider] = await state_machine.get_provider_list.remote()
    return {"provider_list": providers}
@provider_router.delete("/{provider_title}")
async def delete_provider(
    provider_title: str,
    _: TokenData = Depends(
        RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)
    ),
) -> dict:
    """Delete the provider registered under ``provider_title``.

    Restricted to ``UserAuthority.SUPER_ADMINISTRATOR`` via ``RoleChecker``.
    After the ``global_state_machine`` actor's ``delete_provider`` call
    completes, a fixed success message is returned.
    """
    actor = ray_actor_hook("global_state_machine").global_state_machine
    pending = actor.delete_provider.remote(provider_title=provider_title)
    # Wait for the removal to finish before reporting success.
    await pending
    return {"message": "success"}