# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
from typing import Dict

from fastapi import FastAPI, WebSocket, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from ray import serve

from .agent import agent_router
from .auth import auth_router
from .system import system_router
from .platform.frontend import client_router
from .platform.onebot import onebot_router
from .provider import provider_router
from .resource import resource_router
from .workflow import workflow_router
from .chat import chat_router
from kilostar.utils.error import (
    KiloStarError,
    BusinessError,
    InfraError,
)
from kilostar.utils.logger import get_logger
from kilostar.utils.request_context import (
    bind_request_id,
    new_request_id,
    reset_request_id,
)
from kilostar.utils.i18n import t

_api_logger = get_logger("api")


def _get_locale(request: Request) -> str | None:
    """Return the raw ``Accept-Language`` header of *request*, or ``None``.

    A missing or empty header yields ``None``. The value is forwarded as-is
    to ``t(..., accept_language=...)`` by the handlers below.
    """
    return request.headers.get("accept-language") or None


app = FastAPI()

# --- CORS -------------------------------------------------------------------
# ``KILOSTAR_CORS_ORIGINS`` is a comma-separated origin list. When it is unset
# or empty the fallback depends on ``KILOSTAR_ENV``: a dev environment allows
# every origin ("*"), anything else only allows http://localhost:8000.
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "")
_is_dev = os.environ.get("KILOSTAR_ENV", "production").lower() in (
    "dev",
    "development",
)
if not _cors_origins_env and _is_dev:
    _cors_origins_env = "*"
elif not _cors_origins_env:
    _cors_origins_env = "http://localhost:8000"
_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()]
# Credentials are only enabled when no wildcard origin is configured.
_allow_credentials = "*" not in _cors_origins

app.add_middleware(
    CORSMiddleware,
    allow_origins=_cors_origins,
    allow_credentials=_allow_credentials,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Bind a request-scoped ``request_id`` around every HTTP request.

    A non-blank incoming ``X-Request-Id`` header is reused (so gateways and
    frontends can correlate a call chain); otherwise ``new_request_id()``
    supplies a fresh one. The ID is bound with ``bind_request_id`` for the
    duration of ``call_next`` and echoed back in the ``X-Request-Id``
    response header so clients can reconcile their logs.

    If ``call_next`` raises, the binding is still reset and the exception
    propagates; no header is written in that case.

    NOTE(review): Starlette appears to run handlers registered for plain
    ``Exception`` in its outermost error middleware, i.e. after this
    middleware has already reset the ID, so that handler's log line may not
    carry the request ID -- confirm.
    """
    incoming = request.headers.get("X-Request-Id", "").strip()
    request_id = incoming or new_request_id()
    token = bind_request_id(request_id)
    try:
        response = await call_next(request)
    finally:
        reset_request_id(token)
    response.headers["X-Request-Id"] = request_id
    return response


app.include_router(system_router)  # health probes + system info
app.include_router(client_router)  # client routes
app.include_router(onebot_router)  # OneBot v11 routes
app.include_router(auth_router)  # user routes
app.include_router(provider_router)  # provider routes
app.include_router(resource_router)  # resource routes
app.include_router(agent_router)  # agent routes
app.include_router(workflow_router)  # workflow routes
app.include_router(chat_router)  # chat routes


@app.exception_handler(BusinessError)
async def business_error_handler(request: Request, exc: BusinessError):
    """Handle expected business errors.

    Responds with ``exc.http_status`` and a JSON body holding ``exc.code``
    plus the exception message (falling back to the code when the message
    is empty). Nothing is logged.
    """
    return JSONResponse(
        status_code=exc.http_status,
        content={"code": exc.code, "message": str(exc) or exc.code},
    )


@app.exception_handler(InfraError)
async def infra_error_handler(request: Request, exc: InfraError):
    """Handle infrastructure failures.

    Logs the exception with its traceback, then responds with
    ``exc.http_status`` and a generic localized message so internal details
    are not exposed to the client.
    """
    _api_logger.exception(
        f"InfraError on {request.method} {request.url.path}: {exc}"
    )
    loc = _get_locale(request)
    return JSONResponse(
        status_code=exc.http_status,
        content={
            "code": exc.code,
            "message": t("internal_error", accept_language=loc),
        },
    )


@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Catch-all for unexpected exceptions.

    Logs the exception with its traceback and responds with a generic
    localized 500 so no traceback leaks to the client.
    """
    _api_logger.exception(
        f"Unhandled exception on {request.method} {request.url.path}: {exc}"
    )
    loc = _get_locale(request)
    return JSONResponse(
        status_code=500,
        content={
            "code": "internal_error",
            "message": t("internal_error", accept_language=loc),
        },
    )


# --- Frontend (single-page app) ---------------------------------------------
# The built frontend is looked up at <three directories above this file>/
# frontend/dist. Its routes are only registered when that directory exists.
base_dir = os.path.dirname(
    os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
)
frontend_dir = os.path.join(base_dir, "frontend", "dist")
_frontend_logger = logging.getLogger("kilostar")

if os.path.exists(frontend_dir):
    _assets_dir = os.path.join(frontend_dir, "assets")
    # StaticFiles raises at construction time when its directory is missing,
    # which would abort the import of this module. Degrade to a warning
    # instead, matching how a missing dist folder is handled below.
    if os.path.isdir(_assets_dir):
        app.mount(
            "/assets",
            StaticFiles(directory=_assets_dir),
            name="assets",
        )
    else:
        _frontend_logger.warning(
            f"Frontend assets folder not found at {_assets_dir}. "
            "Skipping /assets mount."
        )

    @app.get("/favicon.svg", include_in_schema=False)
    async def serve_favicon():
        """Serve ``favicon.svg`` from the frontend dist folder."""
        return FileResponse(os.path.join(frontend_dir, "favicon.svg"))

    @app.get("/icons.svg", include_in_schema=False)
    async def serve_icons():
        """Serve ``icons.svg`` from the frontend dist folder."""
        return FileResponse(os.path.join(frontend_dir, "icons.svg"))

    @app.get("/{full_path:path}", include_in_schema=False)
    async def serve_frontend(full_path: str, request: Request):
        """Catch-all route that serves the frontend's ``index.html``.

        Paths starting with ``api/`` get a JSON 404 instead, and a missing
        ``index.html`` yields a localized JSON 404.
        """
        # Important security fix: do not swallow unknown API routes. A call
        # to a non-existent /api/ endpoint must return 404 rather than the
        # frontend page.
        if full_path.startswith("api/"):
            return JSONResponse(
                status_code=404, content={"detail": "API endpoint not found"}
            )

        index_path = os.path.join(frontend_dir, "index.html")
        if os.path.exists(index_path):
            return FileResponse(index_path)

        # Localize with the caller's Accept-Language, like the 5xx handlers.
        loc = _get_locale(request)
        return JSONResponse(
            status_code=404,
            content={"detail": t("frontend_not_found", accept_language=loc)},
        )

else:
    _frontend_logger.warning(
        f"Frontend dist folder not found at {frontend_dir}. Skipping frontend mount."
    )


@serve.deployment
@serve.ingress(app)
class KiloStarGateway:
    """Ray Serve deployment that uses the FastAPI ``app`` as its ingress."""

    # String-keyed registry of WebSocket objects; starts empty.
    # NOTE(review): what the keys identify is not visible here -- confirm.
    gateway: Dict[str, WebSocket]

    def __init__(self):
        self.gateway = {}