feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
+82
-49
@@ -16,6 +16,7 @@ import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from ray import serve
|
||||
@@ -23,26 +24,68 @@ from ray import serve
|
||||
from .agent import agent_router
|
||||
from .auth import auth_router
|
||||
from .cluster import cluster_router
|
||||
from .health import health_router
|
||||
from .platform.frontend import client_router
|
||||
from .platform.onebot import onebot_router
|
||||
from .provider import provider_router
|
||||
from .resource import resource_router
|
||||
from .workflow import workflow_router
|
||||
from .chat import chat_router
|
||||
from kilostar.utils.error import (
|
||||
DemandError,
|
||||
ModelNotExistError,
|
||||
UserError,
|
||||
UserNotExistError,
|
||||
UserPasswordError,
|
||||
ProviderError,
|
||||
ProviderNotExistError,
|
||||
WorkflowError,
|
||||
WorkflowExit,
|
||||
KiloStarError,
|
||||
BusinessError,
|
||||
InfraError,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
from kilostar.utils.request_context import (
|
||||
bind_request_id,
|
||||
new_request_id,
|
||||
reset_request_id,
|
||||
)
|
||||
from kilostar.utils.i18n import t
|
||||
|
||||
_api_logger = get_logger("api")
|
||||
|
||||
|
||||
def _get_locale(request: Request) -> str | None:
|
||||
"""从请求头解析首选语言,供异常 handler 使用。"""
|
||||
return request.headers.get("accept-language") or None
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "*")
|
||||
_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()]
|
||||
_allow_credentials = "*" not in _cors_origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=_allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
|
||||
async def request_id_middleware(request: Request, call_next):
|
||||
"""请求级 ``request_id`` 注入。
|
||||
|
||||
入口策略:``X-Request-Id`` 头存在则继承(便于网关/前端串联调用链),
|
||||
否则生成新的 UUID。退出时把它写到响应头,方便客户端日志对账。
|
||||
contextvars 让同一请求生命周期内所有协程的日志都自动带上这个 ID。
|
||||
"""
|
||||
incoming = request.headers.get("X-Request-Id", "").strip()
|
||||
request_id = incoming or new_request_id()
|
||||
token = bind_request_id(request_id)
|
||||
try:
|
||||
response = await call_next(request)
|
||||
finally:
|
||||
reset_request_id(token)
|
||||
response.headers["X-Request-Id"] = request_id
|
||||
return response
|
||||
|
||||
app.include_router(health_router) # 健康检查
|
||||
app.include_router(client_router) # 客户端路径
|
||||
app.include_router(onebot_router) # OneBot v11 路径
|
||||
app.include_router(auth_router) # 用户路径
|
||||
app.include_router(provider_router) # 供应商路径
|
||||
app.include_router(resource_router) # 资源路径
|
||||
@@ -52,49 +95,39 @@ app.include_router(workflow_router) # workflow路径
|
||||
app.include_router(chat_router) # chat路径
|
||||
|
||||
|
||||
@app.exception_handler(UserNotExistError)
|
||||
async def user_not_exist_handler(request: Request, exc: UserNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "用户不存在"})
|
||||
@app.exception_handler(BusinessError)
|
||||
async def business_error_handler(request: Request, exc: BusinessError):
|
||||
"""业务可预期错误:按 ``http_status`` 返回 4xx,附 ``code`` + 异常消息。"""
|
||||
return JSONResponse(
|
||||
status_code=exc.http_status,
|
||||
content={"code": exc.code, "message": str(exc) or exc.code},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(UserPasswordError)
|
||||
async def user_password_handler(request: Request, exc: UserPasswordError):
|
||||
return JSONResponse(status_code=401, content={"message": "密码错误"})
|
||||
@app.exception_handler(InfraError)
|
||||
async def infra_error_handler(request: Request, exc: InfraError):
|
||||
"""系统失败错误:落日志后返回脱敏的 5xx。"""
|
||||
_api_logger.exception(
|
||||
f"InfraError on {request.method} {request.url.path}: {exc}"
|
||||
)
|
||||
loc = _get_locale(request)
|
||||
return JSONResponse(
|
||||
status_code=exc.http_status,
|
||||
content={"code": exc.code, "message": t("internal_error", accept_language=loc)},
|
||||
)
|
||||
|
||||
|
||||
@app.exception_handler(UserError)
|
||||
async def user_error_handler(request: Request, exc: UserError):
|
||||
return JSONResponse(status_code=400, content={"message": "用户相关错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderNotExistError)
|
||||
async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "服务提供商不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
return JSONResponse(status_code=400, content={"message": "服务提供商错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ModelNotExistError)
|
||||
async def model_not_exist_handler(request: Request, exc: ModelNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "模型不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(DemandError)
|
||||
async def demand_error_handler(request: Request, exc: DemandError):
|
||||
return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowExit)
|
||||
async def workflow_exit_handler(request: Request, exc: WorkflowExit):
|
||||
return JSONResponse(status_code=400, content={"message": "工作流已退出"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowError)
|
||||
async def workflow_error_handler(request: Request, exc: WorkflowError):
|
||||
return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
|
||||
@app.exception_handler(Exception)
|
||||
async def unhandled_exception_handler(request: Request, exc: Exception):
|
||||
"""全局兜底:未预期的异常落日志后返回脱敏的 500,避免泄露 traceback。"""
|
||||
_api_logger.exception(
|
||||
f"Unhandled exception on {request.method} {request.url.path}: {exc}"
|
||||
)
|
||||
loc = _get_locale(request)
|
||||
return JSONResponse(
|
||||
status_code=500,
|
||||
content={"code": "internal_error", "message": t("internal_error", accept_language=loc)},
|
||||
)
|
||||
|
||||
|
||||
base_dir = os.path.dirname(
|
||||
@@ -129,7 +162,7 @@ if os.path.exists(frontend_dir):
|
||||
if os.path.exists(index_path):
|
||||
return FileResponse(index_path)
|
||||
return JSONResponse(
|
||||
status_code=404, content={"detail": "Frontend build not found"}
|
||||
status_code=404, content={"detail": t("frontend_not_found")}
|
||||
)
|
||||
else:
|
||||
import logging
|
||||
|
||||
+15
-2
@@ -15,7 +15,7 @@
|
||||
|
||||
from typing import Union
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.core.postgres_database.model import AgentType
|
||||
@@ -23,6 +23,8 @@ from fastapi import HTTPException
|
||||
from typing import Optional, List, Dict
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.mcp_helper import get_all_toolsets_for_scope
|
||||
from kilostar.utils.i18n import t
|
||||
|
||||
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
||||
|
||||
@@ -57,11 +59,13 @@ async def get_system_nodes(
|
||||
@agent_router.post("")
|
||||
async def load_agent(
|
||||
agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||
request: Request,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""加载/重载某个系统节点的 Agent:先持久化配置,再调用对应节点 Actor 的 ``create_agent``。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
accept_lang = request.headers.get("accept-language", "")
|
||||
|
||||
if isinstance(agent_register, AgentLocalRegister):
|
||||
pass
|
||||
@@ -75,7 +79,10 @@ async def load_agent(
|
||||
agent_register.tools,
|
||||
)
|
||||
|
||||
match agent_register.individual_name:
|
||||
scope = agent_register.individual_name
|
||||
toolsets = await get_all_toolsets_for_scope(scope)
|
||||
|
||||
match scope:
|
||||
case "regulatory_node":
|
||||
node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
await node.create_agent.remote(
|
||||
@@ -83,6 +90,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case "consciousness_node":
|
||||
node = ray_actor_hook("consciousness_node").consciousness_node
|
||||
@@ -91,6 +100,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case "control_node":
|
||||
node = ray_actor_hook("control_node").control_node
|
||||
@@ -99,6 +110,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
+35
-6
@@ -16,10 +16,40 @@ from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
|
||||
chat_router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
|
||||
|
||||
|
||||
def _extract_reply(resp: MessageResponse | None) -> str | None:
|
||||
"""从 RegulatoryNode.working 的输出里取出对用户的回复文本。
|
||||
|
||||
RegulatoryNode 现在的 output_type 只剩 ``MessageResponse``(聊天/简单任务/汇报),
|
||||
没有则视为节点降级为静默——上层不写回 chat history。
|
||||
"""
|
||||
if resp is None:
|
||||
return None
|
||||
return resp.reply_message
|
||||
|
||||
|
||||
async def _ask_regulatory(
|
||||
*, user_id: str, chat_id: str, message: str
|
||||
) -> str | None:
|
||||
"""统一封装 chat 入口对 RegulatoryNode 的调用。"""
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
payload = MessageRequest(
|
||||
platform="client",
|
||||
user_name=user_id,
|
||||
platform_id=chat_id,
|
||||
message=message,
|
||||
)
|
||||
resp: MessageResponse | None = await regulatory_node.working.remote(payload)
|
||||
return _extract_reply(resp)
|
||||
|
||||
|
||||
class CreateChatRequest(BaseModel):
|
||||
title: str = "新对话"
|
||||
initial_message: str
|
||||
@@ -45,9 +75,7 @@ async def create_chat_session(
|
||||
)
|
||||
|
||||
# 调用监管节点处理简单任务/交流
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
# 在此发起任务并等待或异步返回结果
|
||||
response_msg = await regulatory_node.handle_chat_message.remote(
|
||||
response_msg = await _ask_regulatory(
|
||||
user_id=token_data.user_id,
|
||||
chat_id=chat.chat_id,
|
||||
message=request.initial_message,
|
||||
@@ -95,9 +123,10 @@ async def send_chat_message(
|
||||
)
|
||||
|
||||
# 调用监管节点
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
response_msg = await regulatory_node.handle_chat_message.remote(
|
||||
user_id=token_data.user_id, chat_id=chat_id, message=request.message
|
||||
response_msg = await _ask_regulatory(
|
||||
user_id=token_data.user_id,
|
||||
chat_id=chat_id,
|
||||
message=request.message,
|
||||
)
|
||||
|
||||
# 存回复
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""健康检查端点:用于容器存活/就绪探针。"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
health_router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@health_router.get("/health/live", include_in_schema=True)
|
||||
async def liveness():
|
||||
"""存活探针:进程能响应即视为存活。"""
|
||||
return {"status": "alive"}
|
||||
|
||||
|
||||
@health_router.get("/health/ready", include_in_schema=True)
|
||||
async def readiness():
|
||||
"""就绪探针:检查关键依赖(Postgres / GSM Actor)是否可达。"""
|
||||
checks = {"postgres": False, "global_state_machine": False}
|
||||
|
||||
try:
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.ping.remote()
|
||||
checks["postgres"] = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
gsm = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await gsm.get_skill_list.remote()
|
||||
checks["global_state_machine"] = True
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
all_ok = all(checks.values())
|
||||
return JSONResponse(
|
||||
status_code=200 if all_ok else 503,
|
||||
content={"status": "ready" if all_ok else "not_ready", "checks": checks},
|
||||
)
|
||||
@@ -13,5 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .frontend import client_router
|
||||
from .onebot import onebot_router
|
||||
|
||||
__all__ = ["client_router"]
|
||||
__all__ = ["client_router", "onebot_router"]
|
||||
|
||||
@@ -16,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
import os
|
||||
import anyio
|
||||
from kilostar.utils.logger import get_logger
|
||||
@@ -39,12 +43,18 @@ async def create_message(
|
||||
logger.info("收到消息,来源:客户端")
|
||||
logger.debug(f"消息内容:{message.message}")
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
reply = await regulatory_node.handle_client_message.remote(
|
||||
user_id=token_data.user_id,
|
||||
msg_request = MessageRequest(
|
||||
platform="client",
|
||||
user_name=token_data.username,
|
||||
platform_id=token_data.user_id,
|
||||
message=message.message,
|
||||
)
|
||||
return {"message": reply}
|
||||
result = await regulatory_node.working.remote(msg_request)
|
||||
if isinstance(result, MessageResponse):
|
||||
return {"message": result.reply_message}
|
||||
if isinstance(result, str):
|
||||
return {"message": result}
|
||||
return {"message": ""}
|
||||
|
||||
|
||||
@client_router.post("/upload")
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""OneBot v11 协议适配器。
|
||||
|
||||
接收来自 OneBot 实现端(NapCat / go-cqhttp / Lagrange.OneBot 等)的事件上报,
|
||||
把消息事件翻译成 ``MessageRequest`` 投递给 RegulatoryNode。同时支持两种连接方式:
|
||||
|
||||
- HTTP 上报(POST ``/api/v1/adapter/onebot/event``):实现端把事件 POST 过来,
|
||||
通过返回体里的 ``reply`` 走 v11 "快速操作" 自动回包。
|
||||
- 反向 WebSocket(WS ``/api/v1/adapter/onebot/ws``):实现端主动建立长连接,
|
||||
服务端按 OneBot v11 反向 WS 规范返回 ``send_msg`` 等 action 主动回包。
|
||||
|
||||
模块还提供 ``send_message`` 工具函数,用 OneBot v11 HTTP API 主动给指定会话发消息。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
logger = get_logger("onebot")
|
||||
|
||||
onebot_router = APIRouter(prefix="/api/v1/adapter/onebot", tags=["onebot"])
|
||||
|
||||
|
||||
def _verify_token(token_from_header: Optional[str]) -> None:
|
||||
"""校验 OneBot 实现端在 ``Authorization`` 头里携带的 access_token。
|
||||
|
||||
若环境变量 ``ONEBOT_ACCESS_TOKEN`` 未设置则跳过校验。OneBot v11 规范要求
|
||||
格式为 ``Bearer <token>``,这里同时容忍只填 token 字符串本身的写法。
|
||||
"""
|
||||
expected = os.environ.get("ONEBOT_ACCESS_TOKEN")
|
||||
if not expected:
|
||||
return
|
||||
if not token_from_header:
|
||||
raise HTTPException(status_code=401, detail="Missing access_token")
|
||||
raw = token_from_header.removeprefix("Bearer ").removeprefix("Token ").strip()
|
||||
if raw != expected:
|
||||
raise HTTPException(status_code=401, detail="Invalid access_token")
|
||||
|
||||
|
||||
def _extract_plain_text(message: Any) -> str:
|
||||
"""把 OneBot 消息字段(字符串或 segment 数组)展平成纯文本。
|
||||
|
||||
OneBot v11 既支持 CQ 码字符串,也支持消息段数组形式;这里只抽取 ``text``
|
||||
段,其它段(图片/at/表情等)暂时丢弃。
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
if isinstance(message, list):
|
||||
parts = []
|
||||
for seg in message:
|
||||
if isinstance(seg, dict) and seg.get("type") == "text":
|
||||
parts.append(seg.get("data", {}).get("text", ""))
|
||||
return "".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
async def _dispatch_event(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""把一次 OneBot 事件交给 RegulatoryNode 处理,返回快速操作字典或 ``None``。
|
||||
|
||||
仅处理 ``post_type == "message"`` 的私聊与群聊;元事件、通知、请求事件
|
||||
一律忽略。返回结果遵循 OneBot v11 "快速操作" 约定:
|
||||
|
||||
- 群聊:``{"reply": "...", "at_sender": False, "_target": {...}}``
|
||||
- 私聊:``{"reply": "...", "_target": {...}}``
|
||||
|
||||
其中 ``_target`` 仅供 ``_dispatch_via_ws`` 决定 send_msg 的入参;HTTP 模式下
|
||||
会被剔除后再返回给实现端。
|
||||
"""
|
||||
if payload.get("post_type") != "message":
|
||||
return None
|
||||
|
||||
message_type = payload.get("message_type") # private | group
|
||||
user_id = str(payload.get("user_id", ""))
|
||||
group_id = payload.get("group_id")
|
||||
raw_text = _extract_plain_text(payload.get("message", ""))
|
||||
sender = payload.get("sender") or {}
|
||||
user_name = (
|
||||
sender.get("card") or sender.get("nickname") or user_id or "onebot_user"
|
||||
)
|
||||
platform_id = (
|
||||
f"group:{group_id}" if message_type == "group" else f"private:{user_id}"
|
||||
)
|
||||
|
||||
if not raw_text.strip():
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[OneBot] {message_type} 消息 from {user_name}({user_id}) -> {raw_text!r}"
|
||||
)
|
||||
|
||||
msg_request = MessageRequest(
|
||||
platform="onebot",
|
||||
user_name=user_name,
|
||||
platform_id=platform_id,
|
||||
message=raw_text,
|
||||
)
|
||||
|
||||
try:
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
result = await regulatory_node.working.remote(msg_request)
|
||||
except Exception as e:
|
||||
logger.exception(f"[OneBot] RegulatoryNode 调用失败: {e}")
|
||||
return None
|
||||
|
||||
reply_text = ""
|
||||
if isinstance(result, MessageResponse):
|
||||
reply_text = result.reply_message or ""
|
||||
elif isinstance(result, str):
|
||||
reply_text = result
|
||||
|
||||
if not reply_text:
|
||||
return None
|
||||
|
||||
quick = {
|
||||
"reply": reply_text,
|
||||
"_target": {
|
||||
"message_type": message_type,
|
||||
"user_id": int(user_id) if user_id.isdigit() else user_id,
|
||||
"group_id": group_id,
|
||||
},
|
||||
}
|
||||
if message_type == "group":
|
||||
quick["at_sender"] = False
|
||||
return quick
|
||||
|
||||
|
||||
@onebot_router.post("/event")
|
||||
async def receive_event(
|
||||
request: Request,
|
||||
authorization: Optional[str] = Header(None),
|
||||
):
|
||||
"""HTTP 上报入口:接收 OneBot v11 事件并触发 RegulatoryNode。
|
||||
|
||||
若 RegulatoryNode 给出回复,会按 v11 "快速操作" 约定写到响应体里,由实现端
|
||||
自动发送。若不需要回复则返回 ``{"status": "ok"}``。
|
||||
"""
|
||||
_verify_token(authorization)
|
||||
payload: Dict[str, Any] = await request.json()
|
||||
quick = await _dispatch_event(payload)
|
||||
if not quick:
|
||||
return {"status": "ok"}
|
||||
quick.pop("_target", None)
|
||||
return quick
|
||||
|
||||
|
||||
async def _ws_call_action(
|
||||
ws: WebSocket, action: str, params: Dict[str, Any]
|
||||
) -> None:
|
||||
"""通过反向 WS 给实现端发送一次 action 调用,不等待响应。"""
|
||||
echo = uuid.uuid4().hex
|
||||
frame = {"action": action, "params": params, "echo": echo}
|
||||
await ws.send_text(json.dumps(frame, ensure_ascii=False))
|
||||
|
||||
|
||||
@onebot_router.websocket("/ws")
|
||||
async def reverse_websocket(
|
||||
websocket: WebSocket,
|
||||
authorization: Optional[str] = Header(None),
|
||||
x_self_id: Optional[str] = Header(None),
|
||||
):
|
||||
"""反向 WebSocket 入口:接受 OneBot 实现端主动建立的长连接。
|
||||
|
||||
握手时校验 ``Authorization`` 头;之后循环读 JSON 帧。带 ``post_type`` 的
|
||||
视为事件上报,调用 RegulatoryNode 处理后通过 ``send_msg`` action 主动回包;
|
||||
带 ``echo`` 的视为 action 响应,目前直接丢弃(后续若需可在此处认领 future)。
|
||||
"""
|
||||
try:
|
||||
_verify_token(authorization)
|
||||
except HTTPException:
|
||||
await websocket.close(code=4401)
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
logger.info(f"[OneBot] reverse WS connected (self_id={x_self_id})")
|
||||
|
||||
try:
|
||||
while True:
|
||||
text = await websocket.receive_text()
|
||||
try:
|
||||
payload = json.loads(text)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning(f"[OneBot] invalid JSON frame: {text[:200]}")
|
||||
continue
|
||||
|
||||
# action 响应帧(含 echo 而无 post_type),目前忽略
|
||||
if "post_type" not in payload and "echo" in payload:
|
||||
continue
|
||||
|
||||
quick = await _dispatch_event(payload)
|
||||
if not quick:
|
||||
continue
|
||||
|
||||
target = quick.get("_target", {})
|
||||
params: Dict[str, Any] = {"message": quick["reply"]}
|
||||
if target.get("message_type") == "group" and target.get("group_id"):
|
||||
params["group_id"] = target["group_id"]
|
||||
action = "send_group_msg"
|
||||
else:
|
||||
params["user_id"] = target.get("user_id")
|
||||
action = "send_private_msg"
|
||||
|
||||
asyncio.create_task(_ws_call_action(websocket, action, params))
|
||||
except WebSocketDisconnect:
|
||||
logger.info("[OneBot] reverse WS disconnected")
|
||||
except Exception as e:
|
||||
logger.exception(f"[OneBot] reverse WS error: {e}")
|
||||
try:
|
||||
await websocket.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def send_message(
|
||||
user_id: Optional[int] = None,
|
||||
group_id: Optional[int] = None,
|
||||
message: str = "",
|
||||
*,
|
||||
base_url: Optional[str] = None,
|
||||
access_token: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""通过 OneBot v11 HTTP API 主动给私聊或群聊发送一条消息。
|
||||
|
||||
Args:
|
||||
user_id: 目标 QQ 号;与 ``group_id`` 二选一。
|
||||
group_id: 目标群号;与 ``user_id`` 二选一。
|
||||
message: 要发送的消息文本。
|
||||
base_url: OneBot 实现端的 HTTP API 地址;默认读取 ``ONEBOT_HTTP_URL``。
|
||||
access_token: 鉴权 token;默认读取 ``ONEBOT_ACCESS_TOKEN``。
|
||||
|
||||
Returns:
|
||||
OneBot HTTP API 的原始响应 JSON。
|
||||
"""
|
||||
if not user_id and not group_id:
|
||||
raise ValueError("必须指定 user_id 或 group_id 之一")
|
||||
|
||||
base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700")
|
||||
token = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN")
|
||||
|
||||
if group_id:
|
||||
action = "send_group_msg"
|
||||
body = {"group_id": int(group_id), "message": message}
|
||||
else:
|
||||
action = "send_private_msg"
|
||||
body = {"user_id": int(user_id), "message": message}
|
||||
|
||||
headers: Dict[str, str] = {}
|
||||
if token:
|
||||
headers["Authorization"] = f"Bearer {token}"
|
||||
|
||||
url = f"{base.rstrip('/')}/{action}"
|
||||
async with httpx.AsyncClient(timeout=15.0) as client:
|
||||
resp = await client.post(url, json=body, headers=headers)
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
@@ -14,11 +14,10 @@
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
from typing import Any, Dict, Literal
|
||||
from kilostar.utils.access import TokenData, Accessor
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from typing import Dict
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
@@ -50,16 +49,27 @@ async def create_provider(
|
||||
)
|
||||
|
||||
|
||||
def _mask_apikey(key: str) -> str:
|
||||
if not key or len(key) <= 8:
|
||||
return "***"
|
||||
return key[:4] + "***" + key[-4:]
|
||||
|
||||
|
||||
@provider_router.get("/list")
|
||||
async def get_provider_list(
|
||||
_: TokenData = Depends(Accessor.get_current_user),
|
||||
) -> Dict[str, Dict[str, Provider]]:
|
||||
"""返回当前所有已注册的 Provider,前端用以展示模型清单。"""
|
||||
) -> Dict[str, Any]:
|
||||
"""返回当前所有已注册的 Provider,前端用以展示模型清单。apikey 脱敏。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
provider_list: Dict[
|
||||
str, Provider
|
||||
] = await global_state_machine.get_provider_list.remote()
|
||||
return {"provider_list": provider_list}
|
||||
masked = {}
|
||||
for title, p in provider_list.items():
|
||||
d = p.model_dump() if hasattr(p, "model_dump") else dict(p)
|
||||
d["provider_apikey"] = _mask_apikey(d.get("provider_apikey", ""))
|
||||
masked[title] = d
|
||||
return {"provider_list": masked}
|
||||
|
||||
|
||||
@provider_router.delete("/{provider_title}")
|
||||
|
||||
+274
-5
@@ -12,13 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
import viceroy
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from kilostar.utils.access import TokenData
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.mcp_helper import list_mcp_tools_from_gsm
|
||||
|
||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||
|
||||
@@ -30,13 +32,24 @@ class Skill(BaseModel):
|
||||
path: str | None
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
|
||||
"""``POST /mcp`` 入参:MCP 服务器配置。"""
|
||||
|
||||
name: str
|
||||
transport: str = "stdio" # stdio | sse | http
|
||||
command: str | None = None
|
||||
args: list[str] | None = None
|
||||
url: str | None = None
|
||||
tool_prefix: str | None = None
|
||||
env: Dict[str, str] | None = None
|
||||
|
||||
|
||||
@resource_router.post("/skill")
|
||||
async def install_skill(
|
||||
skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))
|
||||
):
|
||||
"""通过 viceroy 把 skill 仓库克隆到 ``plugin/skill``,并在状态机中登记。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# noinspection PyUnresolvedReferences
|
||||
import os
|
||||
|
||||
skill_output_dir = os.path.abspath(
|
||||
@@ -73,19 +86,275 @@ async def delete_skill(
|
||||
):
|
||||
"""从状态机中移除 skill 注册项;不会删除磁盘上的代码文件。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# Note: this only removes it from the state machine manager.
|
||||
await global_state_machine.remove_skill.remote(skill_name)
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
# ─── MCP Server Management ───
|
||||
|
||||
@resource_router.post("/mcp")
|
||||
async def add_mcp_server(
|
||||
config: MCPServerConfig,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""注册一个 MCP 服务器到全局状态机。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
import uuid
|
||||
|
||||
server_id = str(uuid.uuid4())[:8]
|
||||
cfg_dict = config.model_dump(exclude_none=True)
|
||||
await global_state_machine.add_mcp_server.remote(server_id, cfg_dict)
|
||||
return {"server_id": server_id, "message": "MCP server registered"}
|
||||
|
||||
|
||||
@resource_router.get("/mcp")
|
||||
async def list_mcp_servers(
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""返回已注册的全部 MCP 服务器配置;env 中的敏感字段脱敏。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
servers = await global_state_machine.list_mcp_servers.remote()
|
||||
for s in servers:
|
||||
if "env" in s and isinstance(s["env"], dict):
|
||||
s["env"] = _mask_config(s["env"])
|
||||
return {"servers": servers}
|
||||
|
||||
|
||||
@resource_router.delete("/mcp/{server_id}")
|
||||
async def delete_mcp_server(
|
||||
server_id: str,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""从状态机中移除一个 MCP 服务器配置。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
ok = await global_state_machine.delete_mcp_server.remote(server_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
# ─── Tool Management ───
|
||||
|
||||
@resource_router.get("/tool")
|
||||
async def get_tools(
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""汇总各作用域 tool_mapper,返回去重后的工具名称列表。"""
|
||||
"""返回按分类聚合的工具信息(包含系统工具、搜索工具、MCP 工具等)。
|
||||
|
||||
其中 ``mcp_servers`` 会现场尝试连接每个已注册的 MCP 服务器并列出它们暴露的
|
||||
工具名,便于前端展示;任意一台 MCP server 不可达不影响其他工具的返回。
|
||||
"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
tool_mapper = await global_state_machine.get_tool_mapper.remote()
|
||||
categories = await global_state_machine.get_tool_categories.remote()
|
||||
|
||||
all_tool_names = set()
|
||||
for scope_tools in tool_mapper.values():
|
||||
all_tool_names.update(scope_tools.keys())
|
||||
return {"tools": list(all_tool_names)}
|
||||
|
||||
mcp_servers = await list_mcp_tools_from_gsm()
|
||||
|
||||
return {
|
||||
"tools": list(all_tool_names),
|
||||
"categories": categories,
|
||||
"mcp_servers": mcp_servers,
|
||||
}
|
||||
|
||||
|
||||
# ─── Tool Config Management(Tavily API key 等运行期配置)───
|
||||
|
||||
|
||||
def _mask_secret(value: Any) -> Any:
|
||||
"""对像 ``api_key`` / ``token`` / ``secret`` 这种敏感字段做简单脱敏。"""
|
||||
if not isinstance(value, str) or not value:
|
||||
return value
|
||||
if len(value) <= 8:
|
||||
return "***"
|
||||
return value[:4] + "***" + value[-4:]
|
||||
|
||||
|
||||
def _mask_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
masked: Dict[str, Any] = {}
|
||||
for k, v in config.items():
|
||||
if any(s in k.lower() for s in ("key", "token", "secret", "password")):
|
||||
masked[k] = _mask_secret(v)
|
||||
else:
|
||||
masked[k] = v
|
||||
return masked
|
||||
|
||||
|
||||
class ToolConfigUpdate(BaseModel):
|
||||
"""``PUT /tool/config/{tool_name}`` 入参:要写入的工具配置 KV。"""
|
||||
|
||||
config: Dict[str, Any]
|
||||
|
||||
|
||||
@resource_router.get("/tool/config")
|
||||
async def list_tool_configs(
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""列出所有工具运行期配置;敏感字段会被脱敏。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
raw = await global_state_machine.list_tool_configs.remote()
|
||||
return {
|
||||
"configs": {name: _mask_config(cfg) for name, cfg in raw.items()},
|
||||
}
|
||||
|
||||
|
||||
@resource_router.get("/tool/config/{tool_name}")
|
||||
async def get_tool_config(
|
||||
tool_name: str,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""按工具名取出脱敏后的配置。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
raw = await global_state_machine.get_tool_config.remote(tool_name)
|
||||
return {"tool_name": tool_name, "config": _mask_config(raw)}
|
||||
|
||||
|
||||
@resource_router.put("/tool/config/{tool_name}")
|
||||
async def set_tool_config(
|
||||
tool_name: str,
|
||||
body: ToolConfigUpdate,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""写入/覆盖某工具的运行期配置(如 ``tavily_search`` 的 ``api_key``)。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
await global_state_machine.set_tool_config.remote(tool_name, body.config)
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
@resource_router.delete("/tool/config/{tool_name}")
|
||||
async def delete_tool_config(
|
||||
tool_name: str,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
|
||||
):
|
||||
"""删除某工具的运行期配置。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
ok = await global_state_machine.delete_tool_config.remote(tool_name)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Tool config not found")
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
# ─── Custom Toolset Management ───
|
||||
|
||||
|
||||
class CustomToolsetCreate(BaseModel):
|
||||
name: str
|
||||
tools: List[str]
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
class CustomToolsetUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
tools: Optional[List[str]] = None
|
||||
description: Optional[str] = None
|
||||
|
||||
|
||||
async def _assert_toolset_owner_or_admin(
|
||||
toolset: Dict[str, Any], token_data: TokenData
|
||||
) -> None:
|
||||
"""校验 toolset 归属:非 owner 且非管理员则抛 403。"""
|
||||
from kilostar.utils.check_user.role_check import get_authority
|
||||
|
||||
if toolset.get("owner_id") == token_data.user_id:
|
||||
return
|
||||
authority = await get_authority(token_data.user_id)
|
||||
if authority >= UserAuthority.ADMINISTRATOR:
|
||||
return
|
||||
raise HTTPException(status_code=403, detail="无权访问此自定义工具组")
|
||||
|
||||
|
||||
@resource_router.post("/custom-toolset")
|
||||
async def create_custom_toolset(
|
||||
body: CustomToolsetCreate,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
import uuid
|
||||
|
||||
toolset_id = str(uuid.uuid4())[:8]
|
||||
try:
|
||||
saved = await global_state_machine.add_custom_toolset.remote(
|
||||
toolset_id=toolset_id,
|
||||
name=body.name,
|
||||
tools=body.tools,
|
||||
description=body.description,
|
||||
owner_id=token_data.user_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"toolset_id": toolset_id, "toolset": saved}
|
||||
|
||||
|
||||
@resource_router.get("/custom-toolset")
|
||||
async def list_custom_toolsets(
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""列出工具组:USER 只能看到自己的;ADMIN 及以上可看全部。"""
|
||||
from kilostar.utils.check_user.role_check import get_authority
|
||||
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
toolsets = await global_state_machine.list_custom_toolsets.remote()
|
||||
authority = await get_authority(token_data.user_id)
|
||||
if authority < UserAuthority.ADMINISTRATOR:
|
||||
toolsets = [t for t in toolsets if t.get("owner_id") == token_data.user_id]
|
||||
return {"toolsets": toolsets}
|
||||
|
||||
|
||||
@resource_router.get("/custom-toolset/{toolset_id}")
|
||||
async def get_custom_toolset(
|
||||
toolset_id: str,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
ts = await global_state_machine.get_custom_toolset.remote(toolset_id)
|
||||
if not ts:
|
||||
raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
await _assert_toolset_owner_or_admin(ts, token_data)
|
||||
return ts
|
||||
|
||||
|
||||
@resource_router.put("/custom-toolset/{toolset_id}")
|
||||
async def update_custom_toolset(
|
||||
toolset_id: str,
|
||||
body: CustomToolsetUpdate,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
existing = await global_state_machine.get_custom_toolset.remote(toolset_id)
|
||||
if not existing:
|
||||
raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
await _assert_toolset_owner_or_admin(existing, token_data)
|
||||
name = body.name if body.name is not None else existing["name"]
|
||||
tools = body.tools if body.tools is not None else existing["tools"]
|
||||
description = body.description if body.description is not None else existing.get("description")
|
||||
try:
|
||||
saved = await global_state_machine.add_custom_toolset.remote(
|
||||
toolset_id=toolset_id,
|
||||
name=name,
|
||||
tools=tools,
|
||||
description=description,
|
||||
owner_id=existing.get("owner_id", token_data.user_id),
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
return {"toolset": saved}
|
||||
|
||||
|
||||
@resource_router.delete("/custom-toolset/{toolset_id}")
|
||||
async def delete_custom_toolset(
|
||||
toolset_id: str,
|
||||
token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""删除工具组:USER 只能删自己的;ADMIN 及以上可删任意。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
existing = await global_state_machine.get_custom_toolset.remote(toolset_id)
|
||||
if not existing:
|
||||
raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
await _assert_toolset_owner_or_admin(existing, token_data)
|
||||
ok = await global_state_machine.delete_custom_toolset.remote(toolset_id)
|
||||
if not ok:
|
||||
raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
return {"message": "success"}
|
||||
|
||||
@@ -119,3 +119,89 @@ async def get_workflow_detail(
|
||||
"steps": steps,
|
||||
"context_blackboard": context.blackboard if context else {},
|
||||
}
|
||||
|
||||
|
||||
@workflow_router.post("/{trace_id}/resume")
|
||||
async def resume_workflow(
|
||||
trace_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user),
|
||||
):
|
||||
"""从 ``workflow_graph_state`` 持久化恢复一个被中断/挂起的工作流。
|
||||
|
||||
新 fire 一个 ray task,task 入口的 ``hydrate`` 检查会自动走 resume 路径
|
||||
把剩余节点跑完。
|
||||
"""
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
wf = await postgres_database.get_workflow.remote(trace_id)
|
||||
if not wf:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if getattr(wf, "user_id", None) != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
record = await postgres_database.get_workflow_graph_state.remote(trace_id)
|
||||
if record is None:
|
||||
raise HTTPException(
|
||||
status_code=409, detail="该工作流没有可恢复的图持久化记录"
|
||||
)
|
||||
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.create_trace.remote(trace_id)
|
||||
|
||||
from kilostar.core.work.workflow.workflow_engine import run_workflow_task
|
||||
|
||||
# workflow_data 在 resume 路径上不会被使用(hydrate 会走 resume 分支),
|
||||
# 这里给个空 dict 占位即可
|
||||
run_workflow_task.remote({}, trace_id)
|
||||
return {"trace_id": trace_id, "status": "resuming"}
|
||||
|
||||
|
||||
@workflow_router.get("/{trace_id}/graph")
|
||||
async def get_workflow_graph_mermaid(
|
||||
trace_id: str,
|
||||
token_data: TokenData = Depends(Accessor.get_current_user),
|
||||
):
|
||||
"""返回当前 workflow 引擎的 mermaid 图源码(节点拓扑)。
|
||||
|
||||
拓扑本身对所有 trace 是同一份;但如果该 trace 已经有 ``workflow_graph_state``
|
||||
持久化记录,会读出 history 里"已经成功跑过的节点"作为 ``highlighted_nodes``
|
||||
传给 mermaid,前端拿到的 mermaid 源码会自带 visited 节点高亮。
|
||||
"""
|
||||
from kilostar.core.work.workflow.workflow_engine import workflow_graph
|
||||
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
wf = await postgres_database.get_workflow.remote(trace_id)
|
||||
if not wf:
|
||||
raise HTTPException(status_code=404, detail="Workflow not found")
|
||||
if getattr(wf, "user_id", None) != token_data.user_id:
|
||||
raise HTTPException(status_code=403, detail="Forbidden")
|
||||
|
||||
visited: list[str] = []
|
||||
record = await postgres_database.get_workflow_graph_state.remote(trace_id)
|
||||
if record is not None:
|
||||
history = getattr(record, "history", None) or []
|
||||
# history 里每条 NodeSnapshot.id 形如 "ClassName:hash",截前缀作为 NodeIdent
|
||||
# 只取 status == "success" 的节点(避免 "created" / "running" 带噪声)
|
||||
seen: set[str] = set()
|
||||
for entry in history:
|
||||
if not isinstance(entry, dict):
|
||||
continue
|
||||
if entry.get("kind") != "node":
|
||||
continue
|
||||
if entry.get("status") != "success":
|
||||
continue
|
||||
sid = entry.get("id") or ""
|
||||
cls_name = sid.split(":", 1)[0] if sid else ""
|
||||
if cls_name and cls_name not in seen:
|
||||
seen.add(cls_name)
|
||||
visited.append(cls_name)
|
||||
|
||||
try:
|
||||
if visited:
|
||||
mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited)
|
||||
else:
|
||||
mermaid = workflow_graph.mermaid_code()
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"mermaid 生成失败: {e}")
|
||||
return {"trace_id": trace_id, "mermaid": mermaid, "visited": visited}
|
||||
|
||||
Reference in New Issue
Block a user