feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -12,6 +12,8 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Sequence, Any
|
||||
|
||||
from pydantic_ai import Agent
|
||||
from pydantic_ai.models.openai import OpenAIChatModel
|
||||
from pydantic_ai.models.anthropic import AnthropicModel
|
||||
@@ -20,6 +22,8 @@ from pydantic_ai.providers.openai import OpenAIProvider
|
||||
from pydantic_ai.providers.anthropic import AnthropicProvider
|
||||
from pydantic_ai.providers.deepseek import DeepSeekProvider
|
||||
from pydantic_ai.providers.google import GoogleProvider
|
||||
from pydantic_ai.toolsets import AbstractToolset
|
||||
|
||||
from kilostar.core.global_state_machine.model_provider import Provider
|
||||
from kilostar.utils.agent_model import ResponseModel, DepsModel
|
||||
from kilostar.utils.error import ModelNotExistError
|
||||
@@ -30,6 +34,7 @@ class AgentFactory:
|
||||
|
||||
支持 openai / claude / deepseek / gemini 四类后端,差异通过
|
||||
``_models_mapping`` 中的 ``model_class`` + ``provider_class`` 键值对屏蔽。
|
||||
同时支持传入本地工具(tools)和外部工具集(toolsets),包括 MCP 服务器。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
@@ -65,21 +70,22 @@ class AgentFactory:
|
||||
deps_type: DepsModel,
|
||||
agent_name: str,
|
||||
tools: list = None,
|
||||
toolsets: Sequence[AbstractToolset[Any]] = None,
|
||||
) -> Agent:
|
||||
"""
|
||||
create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象
|
||||
"""将输入的 provider 对象实例化为一个 pydantic-ai 的 agent 对象。
|
||||
|
||||
Args:
|
||||
provider: Provider对象,从global_state_machine中获取
|
||||
provider: Provider 对象,从 global_state_machine 中获取
|
||||
model_id: 模型名
|
||||
output_type: 输出格式
|
||||
system_prompt: 系统提示词
|
||||
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
|
||||
agent_name: agent的名字
|
||||
tools: 工具列表
|
||||
deps_type: 依赖类型,在 agent 运行时动态输入的格式化消息
|
||||
agent_name: agent 的名字
|
||||
tools: 本地工具函数列表
|
||||
toolsets: 外部工具集列表(包括 MCP 服务器等 AbstractToolset 实例)
|
||||
|
||||
Returns:
|
||||
返回被实例化的pydantic-ai的Agent对象
|
||||
被实例化的 pydantic-ai 的 Agent 对象
|
||||
"""
|
||||
if model_id not in provider.provider_models:
|
||||
raise ModelNotExistError("模型不存在")
|
||||
@@ -109,13 +115,14 @@ class AgentFactory:
|
||||
else:
|
||||
model = model_class(model_id, provider=model_provider)
|
||||
|
||||
# 创建 Agent
|
||||
# 创建 Agent,同时传入 tools 和 toolsets
|
||||
agent = Agent(
|
||||
model=model,
|
||||
name=agent_name,
|
||||
system_prompt=system_prompt,
|
||||
output_type=output_type,
|
||||
deps_type=deps_type,
|
||||
tools=tools,
|
||||
tools=tools or [],
|
||||
toolsets=toolsets or [],
|
||||
)
|
||||
return agent
|
||||
|
||||
+82
-49
@@ -16,6 +16,7 @@ import os
|
||||
from typing import Dict
|
||||
|
||||
from fastapi import FastAPI, WebSocket, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from ray import serve
|
||||
@@ -23,26 +24,68 @@ from ray import serve
|
||||
from .agent import agent_router
|
||||
from .auth import auth_router
|
||||
from .cluster import cluster_router
|
||||
from .health import health_router
|
||||
from .platform.frontend import client_router
|
||||
from .platform.onebot import onebot_router
|
||||
from .provider import provider_router
|
||||
from .resource import resource_router
|
||||
from .workflow import workflow_router
|
||||
from .chat import chat_router
|
||||
from kilostar.utils.error import (
|
||||
DemandError,
|
||||
ModelNotExistError,
|
||||
UserError,
|
||||
UserNotExistError,
|
||||
UserPasswordError,
|
||||
ProviderError,
|
||||
ProviderNotExistError,
|
||||
WorkflowError,
|
||||
WorkflowExit,
|
||||
KiloStarError,
|
||||
BusinessError,
|
||||
InfraError,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
from kilostar.utils.request_context import (
|
||||
bind_request_id,
|
||||
new_request_id,
|
||||
reset_request_id,
|
||||
)
|
||||
from kilostar.utils.i18n import t
|
||||
|
||||
_api_logger = get_logger("api")
|
||||
|
||||
|
||||
def _get_locale(request: Request) -> str | None:
    """Return the raw ``Accept-Language`` header for the exception handlers.

    Args:
        request: The incoming request whose headers are inspected.

    Returns:
        The header value as sent by the client, or ``None`` when the header
        is missing or empty.
    """
    accept_language = request.headers.get("accept-language")
    if not accept_language:
        return None
    return accept_language
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "*")
|
||||
_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()]
|
||||
_allow_credentials = "*" not in _cors_origins
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=_cors_origins,
|
||||
allow_credentials=_allow_credentials,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Attach a request id to every HTTP request/response pair.

    An incoming ``X-Request-Id`` header is reused (after stripping whitespace)
    so that gateways and frontends can correlate a call chain; otherwise a new
    id is generated with ``new_request_id()``. The id is bound through
    ``bind_request_id`` while the downstream handler runs and is echoed back
    on the response in the ``X-Request-Id`` header.

    Args:
        request: The incoming HTTP request.
        call_next: Callable that forwards the request down the stack.

    Returns:
        The downstream response with the ``X-Request-Id`` header set.
    """
    request_id = request.headers.get("X-Request-Id", "").strip()
    if not request_id:
        request_id = new_request_id()

    ctx_token = bind_request_id(request_id)
    try:
        response = await call_next(request)
    finally:
        # Always undo the binding, even when the downstream handler raised.
        reset_request_id(ctx_token)

    response.headers["X-Request-Id"] = request_id
    return response
|
||||
|
||||
app.include_router(health_router) # 健康检查
|
||||
app.include_router(client_router) # 客户端路径
|
||||
app.include_router(onebot_router) # OneBot v11 路径
|
||||
app.include_router(auth_router) # 用户路径
|
||||
app.include_router(provider_router) # 供应商路径
|
||||
app.include_router(resource_router) # 资源路径
|
||||
@@ -52,49 +95,39 @@ app.include_router(workflow_router) # workflow路径
|
||||
app.include_router(chat_router) # chat路径
|
||||
|
||||
|
||||
@app.exception_handler(UserNotExistError)
|
||||
async def user_not_exist_handler(request: Request, exc: UserNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "用户不存在"})
|
||||
@app.exception_handler(BusinessError)
async def business_error_handler(request: Request, exc: BusinessError):
    """Translate an expected business error into a JSON response.

    The status code comes from ``exc.http_status``. The body carries
    ``exc.code`` plus the exception message, falling back to the code itself
    when the exception has no message.
    """
    message = str(exc)
    if not message:
        message = exc.code
    body = {"code": exc.code, "message": message}
    return JSONResponse(status_code=exc.http_status, content=body)
|
||||
|
||||
|
||||
@app.exception_handler(UserPasswordError)
|
||||
async def user_password_handler(request: Request, exc: UserPasswordError):
|
||||
return JSONResponse(status_code=401, content={"message": "密码错误"})
|
||||
@app.exception_handler(InfraError)
async def infra_error_handler(request: Request, exc: InfraError):
    """Log an infrastructure failure and answer with a sanitized error body.

    The exception details go to the API logger only. The client receives
    ``exc.code`` and a localized generic ``internal_error`` message, selected
    from the request's ``Accept-Language`` header.
    """
    _api_logger.exception(
        f"InfraError on {request.method} {request.url.path}: {exc}"
    )
    localized = t("internal_error", accept_language=_get_locale(request))
    body = {"code": exc.code, "message": localized}
    return JSONResponse(status_code=exc.http_status, content=body)
|
||||
|
||||
|
||||
@app.exception_handler(UserError)
|
||||
async def user_error_handler(request: Request, exc: UserError):
|
||||
return JSONResponse(status_code=400, content={"message": "用户相关错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderNotExistError)
|
||||
async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "服务提供商不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(ProviderError)
|
||||
async def provider_error_handler(request: Request, exc: ProviderError):
|
||||
return JSONResponse(status_code=400, content={"message": "服务提供商错误"})
|
||||
|
||||
|
||||
@app.exception_handler(ModelNotExistError)
|
||||
async def model_not_exist_handler(request: Request, exc: ModelNotExistError):
|
||||
return JSONResponse(status_code=404, content={"message": "模型不存在"})
|
||||
|
||||
|
||||
@app.exception_handler(DemandError)
|
||||
async def demand_error_handler(request: Request, exc: DemandError):
|
||||
return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowExit)
|
||||
async def workflow_exit_handler(request: Request, exc: WorkflowExit):
|
||||
return JSONResponse(status_code=400, content={"message": "工作流已退出"})
|
||||
|
||||
|
||||
@app.exception_handler(WorkflowError)
|
||||
async def workflow_error_handler(request: Request, exc: WorkflowError):
|
||||
return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
|
||||
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Catch-all handler for exceptions no other handler claimed.

    The exception is logged, and the client receives a generic, localized
    HTTP 500 body so that no exception detail is exposed in the response.
    """
    _api_logger.exception(
        f"Unhandled exception on {request.method} {request.url.path}: {exc}"
    )
    localized = t("internal_error", accept_language=_get_locale(request))
    body = {"code": "internal_error", "message": localized}
    return JSONResponse(status_code=500, content=body)
|
||||
|
||||
|
||||
base_dir = os.path.dirname(
|
||||
@@ -129,7 +162,7 @@ if os.path.exists(frontend_dir):
|
||||
if os.path.exists(index_path):
|
||||
return FileResponse(index_path)
|
||||
return JSONResponse(
|
||||
status_code=404, content={"detail": "Frontend build not found"}
|
||||
status_code=404, content={"detail": t("frontend_not_found")}
|
||||
)
|
||||
else:
|
||||
import logging
|
||||
|
||||
+15
-2
@@ -15,7 +15,7 @@
|
||||
|
||||
from typing import Union
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.core.postgres_database.model import AgentType
|
||||
@@ -23,6 +23,8 @@ from fastapi import HTTPException
|
||||
from typing import Optional, List, Dict
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.mcp_helper import get_all_toolsets_for_scope
|
||||
from kilostar.utils.i18n import t
|
||||
|
||||
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
|
||||
|
||||
@@ -57,11 +59,13 @@ async def get_system_nodes(
|
||||
@agent_router.post("")
|
||||
async def load_agent(
|
||||
agent_register: Union[AgentRegister, AgentLocalRegister],
|
||||
request: Request,
|
||||
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
|
||||
):
|
||||
"""加载/重载某个系统节点的 Agent:先持久化配置,再调用对应节点 Actor 的 ``create_agent``。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
accept_lang = request.headers.get("accept-language", "")
|
||||
|
||||
if isinstance(agent_register, AgentLocalRegister):
|
||||
pass
|
||||
@@ -75,7 +79,10 @@ async def load_agent(
|
||||
agent_register.tools,
|
||||
)
|
||||
|
||||
match agent_register.individual_name:
|
||||
scope = agent_register.individual_name
|
||||
toolsets = await get_all_toolsets_for_scope(scope)
|
||||
|
||||
match scope:
|
||||
case "regulatory_node":
|
||||
node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
await node.create_agent.remote(
|
||||
@@ -83,6 +90,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case "consciousness_node":
|
||||
node = ray_actor_hook("consciousness_node").consciousness_node
|
||||
@@ -91,6 +100,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case "control_node":
|
||||
node = ray_actor_hook("control_node").control_node
|
||||
@@ -99,6 +110,8 @@ async def load_agent(
|
||||
agent_register.provider_title,
|
||||
agent_register.model_id,
|
||||
agent_register.tools,
|
||||
toolsets,
|
||||
accept_lang,
|
||||
)
|
||||
case _:
|
||||
pass
|
||||
|
||||
+35
-6
@@ -16,10 +16,40 @@ from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
|
||||
chat_router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
|
||||
|
||||
|
||||
def _extract_reply(resp: MessageResponse | None) -> str | None:
    """Pull the user-facing reply text out of a RegulatoryNode result.

    Args:
        resp: The value returned by ``RegulatoryNode.working``; ``None`` means
            the node produced no reply.

    Returns:
        ``resp.reply_message``, or ``None`` when there is no response, in
        which case the caller is expected not to write to the chat history.
    """
    return None if resp is None else resp.reply_message
|
||||
|
||||
|
||||
async def _ask_regulatory(
    *, user_id: str, chat_id: str, message: str
) -> str | None:
    """Send one chat message to the RegulatoryNode and return its reply text.

    Args:
        user_id: Id of the calling user; forwarded as ``user_name``.
        chat_id: Id of the chat session; forwarded as ``platform_id``.
        message: The text to hand to the node.

    Returns:
        The reply text, or ``None`` when the node returned nothing.
    """
    node = ray_actor_hook("regulatory_node").regulatory_node
    request_body = MessageRequest(
        platform="client",
        user_name=user_id,
        platform_id=chat_id,
        message=message,
    )
    node_result: MessageResponse | None = await node.working.remote(request_body)
    return _extract_reply(node_result)
|
||||
|
||||
|
||||
class CreateChatRequest(BaseModel):
|
||||
title: str = "新对话"
|
||||
initial_message: str
|
||||
@@ -45,9 +75,7 @@ async def create_chat_session(
|
||||
)
|
||||
|
||||
# 调用监管节点处理简单任务/交流
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
# 在此发起任务并等待或异步返回结果
|
||||
response_msg = await regulatory_node.handle_chat_message.remote(
|
||||
response_msg = await _ask_regulatory(
|
||||
user_id=token_data.user_id,
|
||||
chat_id=chat.chat_id,
|
||||
message=request.initial_message,
|
||||
@@ -95,9 +123,10 @@ async def send_chat_message(
|
||||
)
|
||||
|
||||
# 调用监管节点
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
response_msg = await regulatory_node.handle_chat_message.remote(
|
||||
user_id=token_data.user_id, chat_id=chat_id, message=request.message
|
||||
response_msg = await _ask_regulatory(
|
||||
user_id=token_data.user_id,
|
||||
chat_id=chat_id,
|
||||
message=request.message,
|
||||
)
|
||||
|
||||
# 存回复
|
||||
|
||||
@@ -0,0 +1,54 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""健康检查端点:用于容器存活/就绪探针。"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
health_router = APIRouter(tags=["health"])
|
||||
|
||||
|
||||
@health_router.get("/health/live", include_in_schema=True)
async def liveness():
    """Liveness probe: answering at all means the process is alive."""
    status_body = {"status": "alive"}
    return status_body
|
||||
|
||||
|
||||
@health_router.get("/health/ready", include_in_schema=True)
async def readiness():
    """Readiness probe: check that the key dependencies can be reached.

    Two Ray actors are probed: ``postgres_database`` (via ``ping``) and
    ``global_state_machine`` (via ``get_skill_list``). A failing probe never
    raises; it is logged and reported as ``False`` in the response.

    Returns:
        A JSON response with status 200 and ``"ready"`` when every check
        passed, otherwise status 503 and ``"not_ready"``. The per-dependency
        results are returned under ``checks``.
    """
    import logging

    probe_logger = logging.getLogger(__name__)
    checks = {"postgres": False, "global_state_machine": False}

    try:
        postgres_database = ray_actor_hook("postgres_database").postgres_database
        await postgres_database.ping.remote()
        checks["postgres"] = True
    except Exception:
        # Best-effort probe: record the cause instead of discarding it.
        probe_logger.warning("readiness probe failed: postgres", exc_info=True)

    try:
        gsm = ray_actor_hook("global_state_machine").global_state_machine
        await gsm.get_skill_list.remote()
        checks["global_state_machine"] = True
    except Exception:
        probe_logger.warning(
            "readiness probe failed: global_state_machine", exc_info=True
        )

    all_ok = all(checks.values())
    return JSONResponse(
        status_code=200 if all_ok else 503,
        content={"status": "ready" if all_ok else "not_ready", "checks": checks},
    )
|
||||
@@ -13,5 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .frontend import client_router
|
||||
from .onebot import onebot_router
|
||||
|
||||
__all__ = ["client_router"]
|
||||
__all__ = ["client_router", "onebot_router"]
|
||||
|
||||
@@ -16,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
|
||||
from pydantic import BaseModel
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
import os
|
||||
import anyio
|
||||
from kilostar.utils.logger import get_logger
|
||||
@@ -39,12 +43,18 @@ async def create_message(
|
||||
logger.info("收到消息,来源:客户端")
|
||||
logger.debug(f"消息内容:{message.message}")
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
reply = await regulatory_node.handle_client_message.remote(
|
||||
user_id=token_data.user_id,
|
||||
msg_request = MessageRequest(
|
||||
platform="client",
|
||||
user_name=token_data.username,
|
||||
platform_id=token_data.user_id,
|
||||
message=message.message,
|
||||
)
|
||||
return {"message": reply}
|
||||
result = await regulatory_node.working.remote(msg_request)
|
||||
if isinstance(result, MessageResponse):
|
||||
return {"message": result.reply_message}
|
||||
if isinstance(result, str):
|
||||
return {"message": result}
|
||||
return {"message": ""}
|
||||
|
||||
|
||||
@client_router.post("/upload")
|
||||
|
||||
@@ -0,0 +1,279 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""OneBot v11 协议适配器。
|
||||
|
||||
接收来自 OneBot 实现端(NapCat / go-cqhttp / Lagrange.OneBot 等)的事件上报,
|
||||
把消息事件翻译成 ``MessageRequest`` 投递给 RegulatoryNode。同时支持两种连接方式:
|
||||
|
||||
- HTTP 上报(POST ``/api/v1/adapter/onebot/event``):实现端把事件 POST 过来,
|
||||
通过返回体里的 ``reply`` 走 v11 "快速操作" 自动回包。
|
||||
- 反向 WebSocket(WS ``/api/v1/adapter/onebot/ws``):实现端主动建立长连接,
|
||||
服务端按 OneBot v11 反向 WS 规范返回 ``send_msg`` 等 action 主动回包。
|
||||
|
||||
模块还提供 ``send_message`` 工具函数,用 OneBot v11 HTTP API 主动给指定会话发消息。
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import httpx
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, WebSocket, WebSocketDisconnect
|
||||
|
||||
from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageRequest,
|
||||
MessageResponse,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
logger = get_logger("onebot")
|
||||
|
||||
onebot_router = APIRouter(prefix="/api/v1/adapter/onebot", tags=["onebot"])
|
||||
|
||||
|
||||
def _verify_token(token_from_header: Optional[str]) -> None:
|
||||
"""校验 OneBot 实现端在 ``Authorization`` 头里携带的 access_token。
|
||||
|
||||
若环境变量 ``ONEBOT_ACCESS_TOKEN`` 未设置则跳过校验。OneBot v11 规范要求
|
||||
格式为 ``Bearer <token>``,这里同时容忍只填 token 字符串本身的写法。
|
||||
"""
|
||||
expected = os.environ.get("ONEBOT_ACCESS_TOKEN")
|
||||
if not expected:
|
||||
return
|
||||
if not token_from_header:
|
||||
raise HTTPException(status_code=401, detail="Missing access_token")
|
||||
raw = token_from_header.removeprefix("Bearer ").removeprefix("Token ").strip()
|
||||
if raw != expected:
|
||||
raise HTTPException(status_code=401, detail="Invalid access_token")
|
||||
|
||||
|
||||
def _extract_plain_text(message: Any) -> str:
|
||||
"""把 OneBot 消息字段(字符串或 segment 数组)展平成纯文本。
|
||||
|
||||
OneBot v11 既支持 CQ 码字符串,也支持消息段数组形式;这里只抽取 ``text``
|
||||
段,其它段(图片/at/表情等)暂时丢弃。
|
||||
"""
|
||||
if isinstance(message, str):
|
||||
return message
|
||||
if isinstance(message, list):
|
||||
parts = []
|
||||
for seg in message:
|
||||
if isinstance(seg, dict) and seg.get("type") == "text":
|
||||
parts.append(seg.get("data", {}).get("text", ""))
|
||||
return "".join(parts)
|
||||
return ""
|
||||
|
||||
|
||||
async def _dispatch_event(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
|
||||
"""把一次 OneBot 事件交给 RegulatoryNode 处理,返回快速操作字典或 ``None``。
|
||||
|
||||
仅处理 ``post_type == "message"`` 的私聊与群聊;元事件、通知、请求事件
|
||||
一律忽略。返回结果遵循 OneBot v11 "快速操作" 约定:
|
||||
|
||||
- 群聊:``{"reply": "...", "at_sender": False, "_target": {...}}``
|
||||
- 私聊:``{"reply": "...", "_target": {...}}``
|
||||
|
||||
其中 ``_target`` 仅供 ``_dispatch_via_ws`` 决定 send_msg 的入参;HTTP 模式下
|
||||
会被剔除后再返回给实现端。
|
||||
"""
|
||||
if payload.get("post_type") != "message":
|
||||
return None
|
||||
|
||||
message_type = payload.get("message_type") # private | group
|
||||
user_id = str(payload.get("user_id", ""))
|
||||
group_id = payload.get("group_id")
|
||||
raw_text = _extract_plain_text(payload.get("message", ""))
|
||||
sender = payload.get("sender") or {}
|
||||
user_name = (
|
||||
sender.get("card") or sender.get("nickname") or user_id or "onebot_user"
|
||||
)
|
||||
platform_id = (
|
||||
f"group:{group_id}" if message_type == "group" else f"private:{user_id}"
|
||||
)
|
||||
|
||||
if not raw_text.strip():
|
||||
return None
|
||||
|
||||
logger.info(
|
||||
f"[OneBot] {message_type} 消息 from {user_name}({user_id}) -> {raw_text!r}"
|
||||
)
|
||||
|
||||
msg_request = MessageRequest(
|
||||
platform="onebot",
|
||||
user_name=user_name,
|
||||
platform_id=platform_id,
|
||||
message=raw_text,
|
||||
)
|
||||
|
||||
try:
|
||||
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||
result = await regulatory_node.working.remote(msg_request)
|
||||
except Exception as e:
|
||||
logger.exception(f"[OneBot] RegulatoryNode 调用失败: {e}")
|
||||
return None
|
||||
|
||||
reply_text = ""
|
||||
if isinstance(result, MessageResponse):
|
||||
reply_text = result.reply_message or ""
|
||||
elif isinstance(result, str):
|
||||
reply_text = result
|
||||
|
||||
if not reply_text:
|
||||
return None
|
||||
|
||||
quick = {
|
||||
"reply": reply_text,
|
||||
"_target": {
|
||||
"message_type": message_type,
|
||||
"user_id": int(user_id) if user_id.isdigit() else user_id,
|
||||
"group_id": group_id,
|
||||
},
|
||||
}
|
||||
if message_type == "group":
|
||||
quick["at_sender"] = False
|
||||
return quick
|
||||
|
||||
|
||||
@onebot_router.post("/event")
async def receive_event(
    request: Request,
    authorization: Optional[str] = Header(None),
):
    """HTTP report endpoint: accept one OneBot v11 event and dispatch it.

    When the RegulatoryNode produces a reply, it is returned in the response
    body in the v11 "quick operation" shape, with the internal ``_target``
    key removed. Otherwise ``{"status": "ok"}`` is returned.

    Raises:
        HTTPException: 401 when token verification fails; 400 when the body
            is not valid JSON or is not a JSON object.
    """
    _verify_token(authorization)

    try:
        payload = await request.json()
    except ValueError as exc:
        # Malformed JSON or undecodable bytes; both are ValueError subclasses.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(payload, dict):
        raise HTTPException(
            status_code=400, detail="Event payload must be a JSON object"
        )

    quick = await _dispatch_event(payload)
    if not quick:
        return {"status": "ok"}
    quick.pop("_target", None)
    return quick
|
||||
|
||||
|
||||
async def _ws_call_action(
    ws: WebSocket, action: str, params: Dict[str, Any]
) -> None:
    """Send one action call over the reverse WebSocket without awaiting a reply.

    Args:
        ws: The connection to write to.
        action: Action name, e.g. ``send_group_msg``.
        params: Parameters for the action.
    """
    call_frame = json.dumps(
        # A fresh random ``echo`` value tags each call.
        {"action": action, "params": params, "echo": uuid.uuid4().hex},
        ensure_ascii=False,
    )
    await ws.send_text(call_frame)
|
||||
|
||||
|
||||
@onebot_router.websocket("/ws")
async def reverse_websocket(
    websocket: WebSocket,
    authorization: Optional[str] = Header(None),
    x_self_id: Optional[str] = Header(None),
):
    """Reverse WebSocket endpoint for connections opened by the OneBot side.

    The ``Authorization`` header is verified before the connection is
    accepted; on failure the socket is closed with code 4401. Afterwards JSON
    frames are read in a loop:

    - frames that are not valid JSON, or not JSON objects, are logged and
      skipped;
    - frames with ``echo`` but no ``post_type`` are action responses and are
      currently ignored;
    - every other frame goes through ``_dispatch_event``, and any reply is
      sent back with ``send_group_msg`` / ``send_private_msg`` in a
      background task.
    """
    try:
        _verify_token(authorization)
    except HTTPException:
        await websocket.close(code=4401)
        return

    await websocket.accept()
    logger.info(f"[OneBot] reverse WS connected (self_id={x_self_id})")

    # The event loop keeps only weak references to tasks, so hold strong ones
    # until each send finishes.
    pending_sends: set = set()

    def _on_send_done(task) -> None:
        """Drop a finished send task and report its failure, if any."""
        pending_sends.discard(task)
        if task.cancelled():
            return
        error = task.exception()
        if error is not None:
            logger.warning(f"[OneBot] reverse WS send failed: {error!r}")

    try:
        while True:
            text = await websocket.receive_text()
            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                logger.warning(f"[OneBot] invalid JSON frame: {text[:200]}")
                continue

            if not isinstance(payload, dict):
                # Valid JSON that is not an object cannot be an event.
                logger.warning(f"[OneBot] non-object JSON frame: {text[:200]}")
                continue

            # Action response frame (has echo, no post_type): ignored for now.
            if "post_type" not in payload and "echo" in payload:
                continue

            quick = await _dispatch_event(payload)
            if not quick:
                continue

            target = quick.get("_target", {})
            params: Dict[str, Any] = {"message": quick["reply"]}
            if target.get("message_type") == "group" and target.get("group_id"):
                params["group_id"] = target["group_id"]
                action = "send_group_msg"
            else:
                params["user_id"] = target.get("user_id")
                action = "send_private_msg"

            send_task = asyncio.create_task(
                _ws_call_action(websocket, action, params)
            )
            pending_sends.add(send_task)
            send_task.add_done_callback(_on_send_done)
    except WebSocketDisconnect:
        logger.info("[OneBot] reverse WS disconnected")
    except Exception as e:
        logger.exception(f"[OneBot] reverse WS error: {e}")
        try:
            await websocket.close()
        except Exception:
            # Best effort: the socket may already be gone.
            pass
    finally:
        # The connection is finished; stop any sends still in flight.
        for send_task in tuple(pending_sends):
            send_task.cancel()
|
||||
|
||||
|
||||
async def send_message(
    user_id: Optional[int] = None,
    group_id: Optional[int] = None,
    message: str = "",
    *,
    base_url: Optional[str] = None,
    access_token: Optional[str] = None,
    timeout: float = 15.0,
) -> Dict[str, Any]:
    """Actively send one message to a private or group chat over HTTP.

    Args:
        user_id: Target user id; used only when ``group_id`` is not given.
        group_id: Target group id; takes precedence over ``user_id`` when
            both are supplied.
        message: The message text to send.
        base_url: HTTP API address of the OneBot implementation; defaults to
            ``ONEBOT_HTTP_URL``, then ``http://127.0.0.1:5700``.
        access_token: Auth token; defaults to ``ONEBOT_ACCESS_TOKEN``. When
            set it is sent as ``Authorization: Bearer <token>``.
        timeout: Timeout in seconds passed to the HTTP client.

    Returns:
        The decoded JSON body of the HTTP API response.

    Raises:
        ValueError: When neither ``user_id`` nor ``group_id`` is given.
        httpx.HTTPStatusError: When the API answers with an error status.
    """
    if not user_id and not group_id:
        raise ValueError("必须指定 user_id 或 group_id 之一")

    base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700")
    token = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN")

    if group_id:
        action = "send_group_msg"
        body = {"group_id": int(group_id), "message": message}
    else:
        action = "send_private_msg"
        body = {"user_id": int(user_id), "message": message}

    headers: Dict[str, str] = {}
    if token:
        headers["Authorization"] = f"Bearer {token}"

    url = f"{base.rstrip('/')}/{action}"
    async with httpx.AsyncClient(timeout=timeout) as client:
        resp = await client.post(url, json=body, headers=headers)
        resp.raise_for_status()
        return resp.json()
|
||||
@@ -14,11 +14,10 @@
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from typing import Literal
|
||||
from typing import Any, Dict, Literal
|
||||
from kilostar.utils.access import TokenData, Accessor
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from typing import Dict
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
|
||||
@@ -50,16 +49,27 @@ async def create_provider(
|
||||
)
|
||||
|
||||
|
||||
def _mask_apikey(key: str) -> str:
|
||||
if not key or len(key) <= 8:
|
||||
return "***"
|
||||
return key[:4] + "***" + key[-4:]
|
||||
|
||||
|
||||
@provider_router.get("/list")
|
||||
async def get_provider_list(
|
||||
_: TokenData = Depends(Accessor.get_current_user),
|
||||
) -> Dict[str, Dict[str, Provider]]:
|
||||
"""返回当前所有已注册的 Provider,前端用以展示模型清单。"""
|
||||
) -> Dict[str, Any]:
|
||||
"""返回当前所有已注册的 Provider,前端用以展示模型清单。apikey 脱敏。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
provider_list: Dict[
|
||||
str, Provider
|
||||
] = await global_state_machine.get_provider_list.remote()
|
||||
return {"provider_list": provider_list}
|
||||
masked = {}
|
||||
for title, p in provider_list.items():
|
||||
d = p.model_dump() if hasattr(p, "model_dump") else dict(p)
|
||||
d["provider_apikey"] = _mask_apikey(d.get("provider_apikey", ""))
|
||||
masked[title] = d
|
||||
return {"provider_list": masked}
|
||||
|
||||
|
||||
@provider_router.delete("/{provider_title}")
|
||||
|
||||
+274
-5
@@ -12,13 +12,15 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any, Dict, List, Optional
|
||||
from pydantic import BaseModel
|
||||
import viceroy
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from fastapi import APIRouter, Depends
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from kilostar.utils.access import TokenData
|
||||
from kilostar.utils.check_user.role_check import RoleChecker
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.mcp_helper import list_mcp_tools_from_gsm
|
||||
|
||||
resource_router = APIRouter(prefix="/api/v1/resource")
|
||||
|
||||
@@ -30,13 +32,24 @@ class Skill(BaseModel):
|
||||
path: str | None
|
||||
|
||||
|
||||
class MCPServerConfig(BaseModel):
    """Request body of ``POST /mcp``: the configuration of one MCP server."""

    # Name of the server.
    name: str
    # Transport kind: stdio | sse | http
    transport: str = "stdio"
    # NOTE(review): command/args look like the stdio launch command and
    # url like the sse/http endpoint; confirm against the consumer.
    command: Optional[str] = None
    args: Optional[List[str]] = None
    url: Optional[str] = None
    tool_prefix: Optional[str] = None
    # Environment variables for the server; may hold secrets.
    env: Optional[Dict[str, str]] = None
|
||||
|
||||
|
||||
@resource_router.post("/skill")
|
||||
async def install_skill(
|
||||
skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))
|
||||
):
|
||||
"""通过 viceroy 把 skill 仓库克隆到 ``plugin/skill``,并在状态机中登记。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# noinspection PyUnresolvedReferences
|
||||
import os
|
||||
|
||||
skill_output_dir = os.path.abspath(
|
||||
@@ -73,19 +86,275 @@ async def delete_skill(
|
||||
):
|
||||
"""从状态机中移除 skill 注册项;不会删除磁盘上的代码文件。"""
|
||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
||||
# Note: this only removes it from the state machine manager.
|
||||
await global_state_machine.remove_skill.remote(skill_name)
|
||||
return {"message": "success"}
|
||||
|
||||
|
||||
# ─── MCP Server Management ───
|
||||
|
||||
@resource_router.post("/mcp")
async def add_mcp_server(
    config: MCPServerConfig,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Register an MCP server in the global state machine.

    Args:
        config: The server configuration; unset (``None``) fields are left
            out of the stored dict.

    Returns:
        A dict with the generated 8-character ``server_id`` and a
        confirmation message.
    """
    import uuid

    gsm = ray_actor_hook("global_state_machine").global_state_machine

    # First eight hex characters of a random UUID.
    server_id = uuid.uuid4().hex[:8]
    stored_config = config.model_dump(exclude_none=True)
    await gsm.add_mcp_server.remote(server_id, stored_config)
    return {"server_id": server_id, "message": "MCP server registered"}
|
||||
|
||||
|
||||
@resource_router.get("/mcp")
async def list_mcp_servers(
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Return every registered MCP server configuration.

    Each entry's ``env`` dict is passed through ``_mask_config`` before it is
    returned (presumably to hide sensitive values; the helper is defined
    outside this view).
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    servers = await gsm.list_mcp_servers.remote()
    for entry in servers:
        if "env" not in entry:
            continue
        env = entry["env"]
        if isinstance(env, dict):
            entry["env"] = _mask_config(env)
    return {"servers": servers}
|
||||
|
||||
|
||||
@resource_router.delete("/mcp/{server_id}")
async def delete_mcp_server(
    server_id: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Remove one MCP server config from the state machine.

    Raises:
        HTTPException: 404 when the state machine reports nothing was deleted.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    removed = await gsm.delete_mcp_server.remote(server_id)
    if removed:
        return {"message": "success"}
    raise HTTPException(status_code=404, detail="MCP server not found")
|
||||
|
||||
|
||||
# ─── Tool Management ───
|
||||
|
||||
@resource_router.get("/tool")
async def get_tools(
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Return tool information aggregated by category.

    The response carries three keys:

    - ``tools``: de-duplicated tool names collected from every scope of the
      state machine's tool mapper (sorted so the output is stable).
    - ``categories``: whatever ``get_tool_categories`` reports.
    - ``mcp_servers``: the result of ``list_mcp_tools_from_gsm()``.
      NOTE(review): per the original note this probes each registered MCP
      server live and tolerates unreachable ones — that behavior lives in the
      helper and is not visible here.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    tool_mapper = await gsm.get_tool_mapper.remote()
    categories = await gsm.get_tool_categories.remote()

    all_tool_names = set()
    for scope_tools in tool_mapper.values():
        all_tool_names.update(scope_tools.keys())

    # The previous early ``return {"tools": ...}`` made everything below dead
    # code; the aggregated payload is the single return path now.
    mcp_servers = await list_mcp_tools_from_gsm()

    return {
        "tools": sorted(all_tool_names),
        "categories": categories,
        "mcp_servers": mcp_servers,
    }
|
||||
|
||||
|
||||
# ─── Tool Config Management(Tavily API key 等运行期配置)───
|
||||
|
||||
|
||||
def _mask_secret(value: Any) -> Any:
|
||||
"""对像 ``api_key`` / ``token`` / ``secret`` 这种敏感字段做简单脱敏。"""
|
||||
if not isinstance(value, str) or not value:
|
||||
return value
|
||||
if len(value) <= 8:
|
||||
return "***"
|
||||
return value[:4] + "***" + value[-4:]
|
||||
|
||||
|
||||
def _mask_config(config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
masked: Dict[str, Any] = {}
|
||||
for k, v in config.items():
|
||||
if any(s in k.lower() for s in ("key", "token", "secret", "password")):
|
||||
masked[k] = _mask_secret(v)
|
||||
else:
|
||||
masked[k] = v
|
||||
return masked
|
||||
|
||||
|
||||
class ToolConfigUpdate(BaseModel):
    """Request body of ``PUT /tool/config/{tool_name}``: the key/value config to store."""

    # Free-form configuration mapping written for the tool.
    config: Dict[str, Any]
|
||||
|
||||
|
||||
@resource_router.get("/tool/config")
async def list_tool_configs(
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """List the runtime config of every tool, with sensitive fields masked."""
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    stored = await gsm.list_tool_configs.remote()
    masked = {}
    for tool_name, tool_cfg in stored.items():
        masked[tool_name] = _mask_config(tool_cfg)
    return {"configs": masked}
|
||||
|
||||
|
||||
@resource_router.get("/tool/config/{tool_name}")
async def get_tool_config(
    tool_name: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Fetch one tool's runtime config and return it with sensitive fields masked."""
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    stored = await gsm.get_tool_config.remote(tool_name)
    masked = _mask_config(stored)
    return {"tool_name": tool_name, "config": masked}
|
||||
|
||||
|
||||
@resource_router.put("/tool/config/{tool_name}")
async def set_tool_config(
    tool_name: str,
    body: ToolConfigUpdate,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Write (or overwrite) a tool's runtime config, e.g. the ``api_key`` of ``tavily_search``."""
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    new_config = body.config
    await gsm.set_tool_config.remote(tool_name, new_config)
    return {"message": "success"}
|
||||
|
||||
|
||||
@resource_router.delete("/tool/config/{tool_name}")
async def delete_tool_config(
    tool_name: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Delete a tool's runtime config.

    Raises:
        HTTPException: 404 when the state machine reports nothing was deleted.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    removed = await gsm.delete_tool_config.remote(tool_name)
    if removed:
        return {"message": "success"}
    raise HTTPException(status_code=404, detail="Tool config not found")
|
||||
|
||||
|
||||
# ─── Custom Toolset Management ───
|
||||
|
||||
|
||||
class CustomToolsetCreate(BaseModel):
    """Request body of ``POST /custom-toolset``."""

    # Name of the toolset.
    name: str
    # Names of the tools grouped by this toolset.
    tools: List[str]
    # Optional free-text description.
    description: Optional[str] = None
|
||||
|
||||
|
||||
class CustomToolsetUpdate(BaseModel):
    """Request body of ``PUT /custom-toolset/{toolset_id}``; every field is optional.

    The update handler keeps the stored value for any field left as ``None``.
    """

    name: Optional[str] = None
    tools: Optional[List[str]] = None
    description: Optional[str] = None
|
||||
|
||||
|
||||
async def _assert_toolset_owner_or_admin(
    toolset: Dict[str, Any], token_data: TokenData
) -> None:
    """Check toolset ownership; reject callers that are neither owner nor admin.

    The owner passes without an authority lookup. Anyone else must have an
    authority of at least ``UserAuthority.ADMINISTRATOR``.

    Raises:
        HTTPException: 403 when the caller is neither the owner nor an administrator.
    """
    from kilostar.utils.check_user.role_check import get_authority

    caller_id = token_data.user_id
    if toolset.get("owner_id") != caller_id:
        caller_authority = await get_authority(caller_id)
        if not (caller_authority >= UserAuthority.ADMINISTRATOR):
            raise HTTPException(status_code=403, detail="无权访问此自定义工具组")
|
||||
|
||||
|
||||
@resource_router.post("/custom-toolset")
async def create_custom_toolset(
    body: CustomToolsetCreate,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Create a custom toolset owned by the calling user.

    The toolset id is the first 8 characters of a freshly generated UUID4.

    Returns:
        ``{"toolset_id": ..., "toolset": <stored record>}``.

    Raises:
        HTTPException: 400 when the state machine rejects the request with a
            ``ValueError``; the original error is chained as the cause.
    """
    import uuid

    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    toolset_id = str(uuid.uuid4())[:8]
    try:
        saved = await global_state_machine.add_custom_toolset.remote(
            toolset_id=toolset_id,
            name=body.name,
            tools=body.tools,
            description=body.description,
            owner_id=token_data.user_id,
        )
    except ValueError as e:
        # Chain the cause so the validation failure stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"toolset_id": toolset_id, "toolset": saved}
|
||||
|
||||
|
||||
@resource_router.get("/custom-toolset")
async def list_custom_toolsets(
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """List custom toolsets: plain users see only their own, administrators and above see all."""
    from kilostar.utils.check_user.role_check import get_authority

    gsm = ray_actor_hook("global_state_machine").global_state_machine
    visible = await gsm.list_custom_toolsets.remote()
    caller_authority = await get_authority(token_data.user_id)
    if caller_authority < UserAuthority.ADMINISTRATOR:
        own = []
        for record in visible:
            if record.get("owner_id") == token_data.user_id:
                own.append(record)
        visible = own
    return {"toolsets": visible}
|
||||
|
||||
|
||||
@resource_router.get("/custom-toolset/{toolset_id}")
async def get_custom_toolset(
    toolset_id: str,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Return one custom toolset after checking the caller may access it.

    Raises:
        HTTPException: 404 when the toolset does not exist; 403 (from the
            ownership check) when the caller is neither owner nor admin.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    record = await gsm.get_custom_toolset.remote(toolset_id)
    if record:
        await _assert_toolset_owner_or_admin(record, token_data)
        return record
    raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
|
||||
|
||||
@resource_router.put("/custom-toolset/{toolset_id}")
async def update_custom_toolset(
    toolset_id: str,
    body: CustomToolsetUpdate,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Update an existing custom toolset; fields left as ``None`` keep their stored value.

    The stored ``owner_id`` is preserved; the caller's id is used only when the
    existing record has no ``owner_id`` key.

    Raises:
        HTTPException: 404 when the toolset does not exist, 403 when the caller
            is neither owner nor admin, 400 when the state machine rejects the
            new content with a ``ValueError`` (chained as the cause).
    """
    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    existing = await global_state_machine.get_custom_toolset.remote(toolset_id)
    if not existing:
        raise HTTPException(status_code=404, detail="Custom toolset not found")
    await _assert_toolset_owner_or_admin(existing, token_data)

    name = body.name if body.name is not None else existing["name"]
    tools = body.tools if body.tools is not None else existing["tools"]
    description = (
        body.description if body.description is not None else existing.get("description")
    )
    try:
        saved = await global_state_machine.add_custom_toolset.remote(
            toolset_id=toolset_id,
            name=name,
            tools=tools,
            description=description,
            owner_id=existing.get("owner_id", token_data.user_id),
        )
    except ValueError as e:
        # Chain the cause so the validation failure stays visible in tracebacks.
        raise HTTPException(status_code=400, detail=str(e)) from e
    return {"toolset": saved}
|
||||
|
||||
|
||||
@resource_router.delete("/custom-toolset/{toolset_id}")
async def delete_custom_toolset(
    toolset_id: str,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Delete a custom toolset: users may delete their own, administrators and above any.

    Raises:
        HTTPException: 404 when the toolset is missing (before or at deletion),
            403 when the caller is neither owner nor admin.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    record = await gsm.get_custom_toolset.remote(toolset_id)
    if not record:
        raise HTTPException(status_code=404, detail="Custom toolset not found")
    await _assert_toolset_owner_or_admin(record, token_data)

    removed = await gsm.delete_custom_toolset.remote(toolset_id)
    if removed:
        return {"message": "success"}
    raise HTTPException(status_code=404, detail="Custom toolset not found")
|
||||
|
||||
@@ -119,3 +119,89 @@ async def get_workflow_detail(
|
||||
"steps": steps,
|
||||
"context_blackboard": context.blackboard if context else {},
|
||||
}
|
||||
|
||||
|
||||
@workflow_router.post("/{trace_id}/resume")
async def resume_workflow(
    trace_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Resume an interrupted/suspended workflow from its persisted ``workflow_graph_state``.

    After the ownership and persistence checks a new Ray task is fired.
    NOTE(review): per the original note the task's ``hydrate`` step takes the
    resume path and runs the remaining nodes — that logic is not visible here.

    Raises:
        HTTPException: 404 when the workflow is unknown, 403 when it belongs to
            another user, 409 when no graph state has been persisted for it.
    """
    db = ray_actor_hook("postgres_database").postgres_database

    workflow = await db.get_workflow.remote(trace_id)
    if not workflow:
        raise HTTPException(status_code=404, detail="Workflow not found")
    owner_id = getattr(workflow, "user_id", None)
    if owner_id != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    graph_state = await db.get_workflow_graph_state.remote(trace_id)
    if graph_state is None:
        raise HTTPException(
            status_code=409, detail="该工作流没有可恢复的图持久化记录"
        )

    workflow_manager = ray_actor_hook("global_workflow_manager").global_workflow_manager
    await workflow_manager.create_trace.remote(trace_id)

    from kilostar.core.work.workflow.workflow_engine import run_workflow_task

    # The workflow_data argument is said to be unused on the resume path,
    # so an empty dict is passed as a placeholder. The task is not awaited.
    run_workflow_task.remote({}, trace_id)
    return {"trace_id": trace_id, "status": "resuming"}
|
||||
|
||||
|
||||
@workflow_router.get("/{trace_id}/graph")
async def get_workflow_graph_mermaid(
    trace_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Return the mermaid source of the workflow engine graph for ``trace_id``.

    When the trace has a persisted ``workflow_graph_state``, its history is
    scanned for node entries with status ``"success"``; their class names
    (the part of the snapshot id before the first ``:``) are passed to
    mermaid as ``highlighted_nodes`` and also returned as ``visited``.

    Raises:
        HTTPException: 404 when the workflow is unknown, 403 when it belongs
            to another user, 500 when mermaid generation fails (the original
            error is chained as the cause).
    """
    from kilostar.core.work.workflow.workflow_engine import workflow_graph

    postgres_database = ray_actor_hook("postgres_database").postgres_database
    wf = await postgres_database.get_workflow.remote(trace_id)
    if not wf:
        raise HTTPException(status_code=404, detail="Workflow not found")
    if getattr(wf, "user_id", None) != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    visited: list[str] = []
    record = await postgres_database.get_workflow_graph_state.remote(trace_id)
    if record is not None:
        history = getattr(record, "history", None) or []
        seen: set[str] = set()
        for entry in history:
            if not isinstance(entry, dict):
                continue
            # Keep only successfully finished nodes; other kinds/statuses are noise.
            if entry.get("kind") != "node" or entry.get("status") != "success":
                continue
            sid = entry.get("id")
            # A malformed (non-string or empty) id cannot name a node; skip it
            # instead of crashing on ``.split``.
            if not isinstance(sid, str) or not sid:
                continue
            cls_name = sid.split(":", 1)[0]
            if cls_name and cls_name not in seen:
                seen.add(cls_name)
                visited.append(cls_name)

    try:
        if visited:
            mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited)
        else:
            mermaid = workflow_graph.mermaid_code()
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"mermaid 生成失败: {e}") from e
    return {"trace_id": trace_id, "mermaid": mermaid, "visited": visited}
|
||||
|
||||
@@ -13,47 +13,129 @@
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from kilostar.core.postgres_database import PostgresDatabase
|
||||
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from kilostar.core.global_state_machine.individual_manager import (
|
||||
GlobalIndividualManager,
|
||||
)
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
|
||||
from kilostar.core.postgres_database import PostgresDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual 四个注册表。
|
||||
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual/MCP/CustomToolset 注册表。
|
||||
|
||||
其它 Actor 通过 ``ray.get_actor("global_state_machine")`` 拿到本实例,
|
||||
再调用本类暴露的方法来读写各注册表,避免每个 Actor 各自维护一份状态。
|
||||
所有持久化都走 PostgresDatabase;启动时由 ``init_state_machine`` 一次性把
|
||||
Provider / MCP / Tool config / Custom toolset 拉到内存,保证后续读操作零等待。
|
||||
"""
|
||||
|
||||
def __init__(self, postgres_database: PostgresDatabase):
    """Build the managers and empty in-memory registries; no data is loaded here.

    Progress markers are printed to stderr after each step.
    """
    import sys

    def _trace(message: str) -> None:
        # Flushed immediately so startup progress is visible right away.
        print(message, file=sys.stderr, flush=True)

    _trace("GSM __init__ START")
    _trace(" event_dict done")
    self._global_provider_manager = ProviderManager(postgres_database)
    _trace(" provider_manager done")
    self._global_tool_manager = GlobalToolManager()
    _trace(" tool_manager done")
    self._global_skill_manager = GlobalSkillManager()
    _trace(" skill_manager done")
    self._global_individual_manager = GlobalIndividualManager()
    _trace(" individual_manager done")

    # In-memory registries, filled from the database by init_state_machine.
    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
    self._tool_configs: Dict[str, Dict[str, Any]] = {}
    self._custom_toolsets: Dict[str, Dict[str, Any]] = {}

    # Snapshot bookkeeping: _publish_snapshot bumps the version and stores a
    # fresh object-store ref; readers obtain it via current_config_ref.
    self._config_version: int = 0
    self._current_ref: Optional[ray.ObjectRef] = None

    self.postgres_database = postgres_database
    _trace("GSM __init__ DONE")
|
||||
|
||||
async def init_state_machine(self):
    """Load Provider/Individual/MCP/ToolConfig/CustomToolset state into memory at startup.

    Finishes by rebuilding the tool manager's custom toolsets and publishing
    the first snapshot.
    """
    db = self.postgres_database
    await self._global_provider_manager.init_provider_register(db)
    await self._global_individual_manager.init_individual_register(db)

    # MCP servers, keyed by server id.
    mcp_rows = await db.list_mcp_servers_db.remote()
    self._mcp_servers = {record["server_id"]: record for record in mcp_rows}

    # Tool configs, keyed by tool name; only the "config" payload is kept.
    config_rows = await db.list_tool_configs_db.remote()
    self._tool_configs = {record["tool_name"]: record["config"] for record in config_rows}

    # Custom toolsets, keyed by toolset id.
    toolset_rows = await db.list_custom_toolsets.remote()
    self._custom_toolsets = {record["toolset_id"]: record for record in toolset_rows}

    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
|
||||
|
||||
# ─── Snapshot 发布(Object Store 读路径) ────────────────────
|
||||
|
||||
def _build_snapshot(self) -> GSMSnapshot:
    """Pack the current in-memory state into a ``GSMSnapshot``.

    ``tool_funcs`` is flattened to ``{tool_name: callable}`` across all
    scopes; when a name appears in several scopes the later scope wins.
    ``system_tools_by_scope`` keeps the per-scope list of tool names.
    The top-level containers are shallow-copied.
    """
    tool_manager = self._global_tool_manager

    merged_funcs: Dict[str, Any] = {}
    names_by_scope: Dict[str, List[str]] = {}
    for scope_name, funcs in tool_manager._tool_funcs.items():
        names_by_scope[scope_name] = list(funcs)
        merged_funcs.update(funcs)

    return GSMSnapshot(
        version=self._config_version,
        providers=dict(self._global_provider_manager.provider_register),
        individuals=dict(self._global_individual_manager._individuals),
        mcp_servers=dict(self._mcp_servers),
        tool_configs=dict(self._tool_configs),
        custom_toolsets=dict(self._custom_toolsets),
        skills=dict(self._global_skill_manager.skill_mapper),
        tool_metadata=dict(tool_manager.tool_metadata),
        tool_funcs=merged_funcs,
        third_party_funcs=dict(tool_manager._third_party_funcs),
        tool_mapper={
            scope_name: dict(classes)
            for scope_name, classes in tool_manager.tool_mapper.items()
        },
        system_tools_by_scope=names_by_scope,
    )
|
||||
|
||||
def _publish_snapshot(self) -> None:
    """Bump the config version and put a fresh snapshot into the Ray object store.

    The previously held ref is replaced by the new one.
    """
    # The version must be bumped first: _build_snapshot records it in the snapshot.
    self._config_version = self._config_version + 1
    fresh_snapshot = self._build_snapshot()
    self._current_ref = ray.put(fresh_snapshot)
|
||||
|
||||
async def current_config_ref(self) -> Tuple[int, ray.ObjectRef]:
    """Return ``(version, ObjectRef)`` of the current snapshot.

    Callers resolve the ref themselves; the snapshot object is deliberately
    not returned directly. A snapshot is published on demand when none
    exists yet.
    """
    needs_publish = self._current_ref is None
    if needs_publish:
        self._publish_snapshot()
    version, ref = self._config_version, self._current_ref
    return version, ref
|
||||
|
||||
async def current_version(self) -> int:
    """Return only the current config version, so readers can check cache freshness cheaply."""
    version = self._config_version
    return version
|
||||
|
||||
# ─── Provider ──────────────────────────────────────────────
|
||||
|
||||
async def add_provider_wrap(
|
||||
self,
|
||||
@@ -64,7 +146,7 @@ class GlobalStateMachine:
|
||||
provider_owner,
|
||||
):
|
||||
"""新增一个模型 Provider:内存注册 + 数据库持久化一并完成。"""
|
||||
return await self._global_provider_manager.add_provider(
|
||||
result = await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
@@ -72,8 +154,9 @@ class GlobalStateMachine:
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database,
|
||||
)
|
||||
self._publish_snapshot()
|
||||
return result
|
||||
|
||||
# Provider Manager Methods
|
||||
def get_provider_list(self):
    """Return every Provider registered in memory, as reported by the provider manager."""
    providers = self._global_provider_manager.get_provider_list()
    return providers
|
||||
@@ -84,11 +167,14 @@ class GlobalStateMachine:
|
||||
|
||||
async def delete_provider(self, provider_title: str):
    """Delete a Provider via the provider manager, then publish a new snapshot.

    The provider manager is handed the database handle so it can persist the
    deletion itself.

    Returns:
        Whatever the provider manager's ``delete_provider`` returns.
    """
    result = await self._global_provider_manager.delete_provider(
        provider_title, self.postgres_database
    )
    # Publish only after the deletion completed so readers see the new state.
    self._publish_snapshot()
    return result
|
||||
|
||||
# ─── Tool / Toolset ────────────────────────────────────────
|
||||
|
||||
# Tool Manager Methods
|
||||
def get_tool_mapper(self):
    """Return the tool manager's full ``tool_mapper`` (scope -> {tool_name: ...})."""
    mapper = self._global_tool_manager.tool_mapper
    return mapper
|
||||
@@ -96,37 +182,152 @@ class GlobalStateMachine:
|
||||
def get_tool_list(self, agent_name: str):
    """Return the tools available to an agent: ``default`` tools overlaid with its own.

    On a name clash the agent-specific entry wins.
    """
    mapper = self._global_tool_manager.tool_mapper
    combined = dict(mapper.get("default", {}))
    combined.update(mapper.get(agent_name, {}))
    return combined
|
||||
|
||||
def get_tool_categories(self):
    """Return tool information grouped by category.

    Keys: ``system``, ``third_party``, ``by_category`` (system / search /
    mcp / other) and ``all``, each taken from the tool manager.
    """
    tool_manager = self._global_tool_manager
    system_tools = tool_manager.get_system_tools()
    third_party_tools = tool_manager.get_third_party_tools()
    grouped = {
        category: tool_manager.get_tools_by_category(category)
        for category in ("system", "search", "mcp", "other")
    }
    return {
        "system": system_tools,
        "third_party": third_party_tools,
        "by_category": grouped,
        "all": tool_manager.get_all_tools(),
    }
|
||||
|
||||
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the tool manager's toolset list for ``scope``.

    NOTE(review): described by the author as "system + custom toolsets,
    MCP excluded"; the composition happens inside the tool manager.
    """
    toolsets = self._global_tool_manager.get_toolsets_for_scope(scope)
    return toolsets
|
||||
|
||||
# ─── MCP Server Registry ───────────────────────────────────
|
||||
|
||||
async def add_mcp_server(self, server_id: str, config: Dict[str, Any]) -> bool:
    """Register an MCP server config: database first, then memory, then a new snapshot.

    The in-memory entry is the record returned by the database upsert.
    """
    stored_record = await self.postgres_database.upsert_mcp_server.remote(server_id, config)
    self._mcp_servers[server_id] = stored_record
    self._publish_snapshot()
    return True
|
||||
|
||||
def get_mcp_server(self, server_id: str) -> Optional[Dict[str, Any]]:
    """Return the in-memory config of one MCP server, or ``None`` when unknown."""
    registry = self._mcp_servers
    return registry.get(server_id)
|
||||
|
||||
def list_mcp_servers(self) -> List[Dict[str, Any]]:
    """Return all MCP server configs as new dicts, each starting with ``server_id``.

    A ``server_id`` key already present in a stored config overrides the
    registry key's value.
    """
    listing: List[Dict[str, Any]] = []
    for server_id, stored in self._mcp_servers.items():
        entry: Dict[str, Any] = {"server_id": server_id}
        entry.update(stored)
        listing.append(entry)
    return listing
|
||||
|
||||
async def delete_mcp_server(self, server_id: str) -> bool:
    """Delete an MCP server config from the database and from memory.

    The in-memory entry is dropped and a snapshot published regardless of
    the database result; the return value is that result coerced to bool.
    """
    db_result = await self.postgres_database.delete_mcp_server_db.remote(server_id)
    self._mcp_servers.pop(server_id, None)
    self._publish_snapshot()
    return True if db_result else False
|
||||
|
||||
def get_mcp_server_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of the raw MCP server registry (id -> config)."""
    registry_copy = dict(self._mcp_servers)
    return registry_copy
|
||||
|
||||
# ─── Tool Config(Tavily API key 等)─────────────────────
|
||||
|
||||
async def set_tool_config(self, tool_name: str, config: Dict[str, Any]) -> bool:
    """Overwrite a tool's runtime config: database first, then memory, then a new snapshot.

    Memory keeps the ``"config"`` field of the record returned by the upsert.
    NOTE(review): the author notes sensitive fields are encrypted inside the
    DAO — not visible here.
    """
    db = self.postgres_database
    stored_record = await db.upsert_tool_config.remote(tool_name, config)
    self._tool_configs[tool_name] = stored_record["config"]
    self._publish_snapshot()
    return True
|
||||
|
||||
def get_tool_config(self, tool_name: str) -> Dict[str, Any]:
    """Return a copy of one tool's config; an empty dict when none is stored."""
    stored = self._tool_configs.get(tool_name, {})
    return dict(stored)
|
||||
|
||||
def list_tool_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of all tool configs.

    Values are returned unmasked; callers are responsible for masking
    sensitive fields.
    """
    configs_copy = dict(self._tool_configs)
    return configs_copy
|
||||
|
||||
async def delete_tool_config(self, tool_name: str) -> bool:
    """Delete a tool's runtime config from the database and from memory.

    The in-memory entry is dropped and a snapshot published regardless of
    the database result; the return value is that result coerced to bool.
    """
    db_result = await self.postgres_database.delete_tool_config_db.remote(tool_name)
    self._tool_configs.pop(tool_name, None)
    self._publish_snapshot()
    return True if db_result else False
|
||||
|
||||
# ─── Custom Toolset(用户自定义工具组)──────────────────
|
||||
|
||||
async def add_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: List[str],
    description: Optional[str] = None,
    owner_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create or update a custom toolset made only of third-party tools.

    On success the record returned by the database upsert is stored in
    memory, the tool manager rebuilds its custom toolsets, and a new
    snapshot is published.

    Raises:
        ValueError: when any tool fails ``is_third_party_tool``; nothing is
            written in that case.
    """
    invalid = []
    for tool_name in tools:
        if not self._global_tool_manager.is_third_party_tool(tool_name):
            invalid.append(tool_name)
    if invalid:
        raise ValueError(
            f"自定义工具组只允许包含第三方工具,以下不合法:{invalid}"
        )

    stored_record = await self.postgres_database.upsert_custom_toolset.remote(
        toolset_id=toolset_id,
        name=name,
        tools=list(tools),
        description=description,
        owner_id=owner_id,
    )
    self._custom_toolsets[toolset_id] = stored_record
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return stored_record
|
||||
|
||||
def list_custom_toolsets(self) -> List[Dict[str, Any]]:
    """Return every in-memory custom toolset record as a list."""
    records = [record for record in self._custom_toolsets.values()]
    return records
|
||||
|
||||
def get_custom_toolset(self, toolset_id: str) -> Optional[Dict[str, Any]]:
    """Return one in-memory custom toolset record, or ``None`` when unknown."""
    registry = self._custom_toolsets
    return registry.get(toolset_id)
|
||||
|
||||
async def delete_custom_toolset(self, toolset_id: str) -> bool:
    """Delete a custom toolset from the database and from memory.

    The in-memory entry is dropped, the tool manager rebuilds its custom
    toolsets and a snapshot is published regardless of the database result;
    the return value is that result coerced to bool.
    """
    db_result = await self.postgres_database.delete_custom_toolset.remote(toolset_id)
    self._custom_toolsets.pop(toolset_id, None)
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return True if db_result else False
|
||||
|
||||
# ─── Skill ────────────────────────────────────────────────
|
||||
|
||||
# Skill Manager Methods
|
||||
def add_skill(self, skill_name: str):
    """Register a new Skill name in the skill registry, then publish a new snapshot.

    Returns:
        Whatever the skill manager's ``add_skill`` returns.
    """
    result = self._global_skill_manager.add_skill(skill_name)
    # Publish after the registry change so snapshot readers see the new skill.
    self._publish_snapshot()
    return result
|
||||
|
||||
def get_skill_list(self):
    """Return all registered Skill names, as reported by the skill manager."""
    skills = self._global_skill_manager.get_skill_list()
    return skills
|
||||
|
||||
def remove_skill(self, skill_name: str):
    """Remove a Skill from the registry, then publish a new snapshot.

    Returns:
        Whatever the skill manager's ``remove_skill`` returns.
    """
    result = self._global_skill_manager.remove_skill(skill_name)
    # Publish after the registry change so snapshot readers stop seeing the skill.
    self._publish_snapshot()
    return result
|
||||
|
||||
# ─── Individual ───────────────────────────────────────────
|
||||
|
||||
# Individual Manager Methods
|
||||
def add_individual(self, agent_id: str, config):
    """Add a Worker Individual's runtime config to the registry, then publish a new snapshot.

    Returns:
        Whatever the individual manager's ``add_individual`` returns.
    """
    result = self._global_individual_manager.add_individual(agent_id, config)
    # Publish after the registry change so snapshot readers see the new individual.
    self._publish_snapshot()
    return result
|
||||
|
||||
def get_individual(self, agent_id: str):
    """Look up one Worker Individual's config by ``agent_id`` via the individual manager."""
    individual = self._global_individual_manager.get_individual(agent_id)
    return individual
|
||||
|
||||
def remove_individual(self, agent_id: str):
    """Remove a Worker Individual from the registry, then publish a new snapshot.

    Returns:
        Whatever the individual manager's ``remove_individual`` returns.
    """
    result = self._global_individual_manager.remove_individual(agent_id)
    # Publish after the registry change so snapshot readers stop seeing the individual.
    self._publish_snapshot()
    return result
|
||||
|
||||
def list_individuals(self):
    """Return all currently registered Worker Individuals, as reported by the individual manager."""
    individuals = self._global_individual_manager.list_individuals()
    return individuals
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""GSM 快照对象与客户端拉取工具。
|
||||
|
||||
设计动机:把 GSM actor 内存里的"读路径配置"打包成可放进 Ray Object Store
|
||||
的不可变快照。读端不再走 actor RPC,而是 ``fetch_snapshot()`` 一次拿到全量
|
||||
当前配置,亚毫秒级共享内存读,绕开单 actor 的吞吐瓶颈。
|
||||
|
||||
GSM 仍然是 source of truth + 写入串行化器,但读路径解耦:
|
||||
|
||||
- 写入 → GSM 内存更新 → ``ray.put(snapshot)`` 拿到新 ObjectRef → 版本号 +1
|
||||
- 读取 → ``current_config_ref()`` 拿 (version, ref) → ``ray.get(ref)`` 直读
|
||||
|
||||
旧的 ``get_provider / get_individual / ...`` 接口保留不动,是低频路径的兜底;
|
||||
新代码(特别是 skill task 这种高并发热路径)应优先走 ``fetch_snapshot``。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import ray
|
||||
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("gsm_snapshot")
|
||||
|
||||
|
||||
@dataclass
class GSMSnapshot:
    """Snapshot of the GSM configuration, intended to be treated as read-only.

    Every field must be cloudpickle-friendly. ``FunctionToolset`` instances
    are deliberately left out: consumers assemble their own toolsets from
    ``tool_funcs`` and the name lists stored here, which keeps the snapshot
    independent of pydantic-ai's serialization behavior and smaller.
    """

    # Config version recorded when the snapshot was built.
    version: int = 0
    providers: Dict[str, Provider] = field(default_factory=dict)
    individuals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    custom_toolsets: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    skills: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    tool_metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Flat ``{tool_name: callable}`` view across all scopes.
    tool_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    third_party_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    tool_mapper: Dict[str, Dict[str, type]] = field(default_factory=dict)
    # ``{scope: [tool_name, ...]}``: system tool names kept per scope, so a
    # client can rebuild FunctionToolsets locally from names + ``tool_funcs``
    # instead of shipping toolset instances inside the snapshot.
    system_tools_by_scope: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
|
||||
_local_cache: Dict[str, Any] = {"version": -1, "snapshot": None}
|
||||
_cache_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def fetch_snapshot(
    *,
    use_cache: bool = True,
    gsm_actor: Optional[Any] = None,
) -> GSMSnapshot:
    """Fetch the current GSM snapshot.

    With caching enabled the steps are:

    1. Ask the actor for ``current_version`` (one lightweight RPC).
    2. If the process-local cache holds that version, return it.
    3. Otherwise request ``current_config_ref``, resolve the ref, and
       refresh the cache.

    Object refs are resolved with ``await ref`` rather than the blocking
    ``ray.get``, so the event loop is not stalled while the snapshot is
    fetched (the cached path holds ``_cache_lock`` during that time).

    Args:
        use_cache: Use the single-slot process-local cache (default). Turn
            off in tests or diagnostics to force a fresh fetch.
        gsm_actor: Optional GSM actor handle; looked up through
            ``ray_actor_hook`` when omitted.

    Returns:
        The snapshot for the actor's current config version.
    """
    if gsm_actor is None:
        from kilostar.utils.ray_hook import ray_actor_hook

        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine

    if not use_cache:
        _version, ref = await gsm_actor.current_config_ref.remote()
        return await ref

    async with _cache_lock:
        try:
            latest_version = await gsm_actor.current_version.remote()
        except Exception as exc:
            # Best effort: a failed version probe only disables the cache
            # shortcut; the full fetch below still runs (and raises if the
            # actor is really unavailable).
            _logger.warning(f"GSM version check failed, refetching snapshot: {exc}")
            latest_version = None

        cached = _local_cache.get("snapshot")
        if (
            latest_version is not None
            and cached is not None
            and _local_cache.get("version") == latest_version
        ):
            return cached

        version, ref = await gsm_actor.current_config_ref.remote()
        snapshot = await ref
        _local_cache["version"] = version
        _local_cache["snapshot"] = snapshot
        return snapshot
|
||||
|
||||
|
||||
def reset_local_cache() -> None:
    """Clear this process's snapshot cache (intended for tests).

    Puts the cache back into its initial state: version ``-1`` and no cached
    snapshot. The cache dict is updated in place, not replaced.
    """
    _local_cache.update(version=-1, snapshot=None)
|
||||
|
||||
|
||||
# ─── 客户端 helper:从快照重建本地视图 ─────────────────────────────
|
||||
|
||||
|
||||
def build_toolsets_for_scope(
    snapshot: GSMSnapshot, scope: str
) -> List[Any]:
    """Assemble ``FunctionToolset`` objects for ``scope`` from ``snapshot``.

    Mirrors the merge order of ``GlobalToolManager.get_toolsets_for_scope``:

    - system toolsets: one per bucket, ``default`` first and then ``scope``,
      built from ``snapshot.system_tools_by_scope`` and ``snapshot.tool_funcs``;
    - custom toolsets: every entry of ``snapshot.custom_toolsets`` that
      resolves to at least one function in ``snapshot.third_party_funcs``.

    Tool names missing from the function tables are ignored, and a bucket or
    custom toolset that resolves to no functions is skipped. The toolsets are
    built fresh on every call from the function objects carried by the
    snapshot.

    Args:
        snapshot: GSM snapshot holding the tool name lists and functions.
        scope: Scope whose system tools are added on top of ``default``.

    Returns:
        The toolsets in the order described above. The list is empty when
        ``pydantic_ai.toolsets`` cannot be imported.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        _logger.warning("pydantic_ai.toolsets unavailable; cannot build toolsets")
        return []

    result: List[Any] = []

    # ``dict.fromkeys`` removes duplicates while keeping order. Without it a
    # call with scope == "default" built "system::default" twice, handing the
    # agent two toolsets with the same id and the same tools.
    for bucket in dict.fromkeys(("default", scope)):
        names = snapshot.system_tools_by_scope.get(bucket) or []
        funcs = [snapshot.tool_funcs[n] for n in names if n in snapshot.tool_funcs]
        if not funcs:
            continue
        try:
            result.append(FunctionToolset(tools=funcs, id=f"system::{bucket}"))
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build system toolset {bucket} failed: {e}")

    for toolset_id, defn in snapshot.custom_toolsets.items():
        names = defn.get("tools") or []
        funcs = [
            snapshot.third_party_funcs[n]
            for n in names
            if n in snapshot.third_party_funcs
        ]
        if not funcs:
            continue
        try:
            result.append(FunctionToolset(tools=funcs, id=f"custom::{toolset_id}"))
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build custom toolset {toolset_id} failed: {e}")

    return result
|
||||
@@ -1,35 +1,42 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import importlib
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_manager")
|
||||
|
||||
_SYSTEM_BUCKET = "system"
|
||||
|
||||
|
||||
class GlobalToolManager:
|
||||
"""工具注册表:扫描 ``kilostar/plugin/tool_plugin/`` 下所有 BaseToolData 子类,
|
||||
按 ``action_scope`` 分桶到 ``tool_mapper[scope][plugin_name]``;无 scope 的归入 ``default``。"""
|
||||
按 ``action_scope`` 打包成 ``FunctionToolset``。
|
||||
|
||||
三类 toolset:
|
||||
- **system**:``is_system=True`` 的工具,按 scope 分组
|
||||
- **custom**:用户自定义工具组(由 ``rebuild_custom_toolsets`` 动态构建)
|
||||
- **mcp**:由 ``mcp_helper`` 独立管理,不经过本类
|
||||
|
||||
``category="mcp"`` 的工具不会被本类管理。
|
||||
"""
|
||||
|
||||
tool_metadata: Dict[str, Dict[str, Any]]
|
||||
_tool_funcs: Dict[str, Dict[str, Callable]]
|
||||
_system_toolsets: Dict[str, Any]
|
||||
_custom_toolsets: Dict[str, Any]
|
||||
_third_party_funcs: Dict[str, Callable]
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.tool_metadata = {}
|
||||
self._tool_funcs = defaultdict(dict)
|
||||
self._system_toolsets = {}
|
||||
self._custom_toolsets = {}
|
||||
self._third_party_funcs = {}
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = (
|
||||
@@ -39,21 +46,154 @@ class GlobalToolManager:
|
||||
return
|
||||
|
||||
for item in tool_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith("__"):
|
||||
plugin_name = item.name
|
||||
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
|
||||
if not (item.is_dir() and not item.name.startswith("__")):
|
||||
continue
|
||||
plugin_name = item.name
|
||||
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
|
||||
# It's a valid tool class
|
||||
action_scopes = obj.model_fields.get("action_scope").default
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import tool plugin {plugin_name}: {e}")
|
||||
continue
|
||||
|
||||
if not action_scopes:
|
||||
self.tool_mapper["default"][plugin_name] = obj
|
||||
else:
|
||||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
tool_data_cls = self._find_tool_data_class(module)
|
||||
if tool_data_cls is None:
|
||||
continue
|
||||
|
||||
tool_func = getattr(module, plugin_name, None)
|
||||
if not callable(tool_func):
|
||||
logger.warning(
|
||||
f"Tool plugin '{plugin_name}' has no callable named "
|
||||
f"'{plugin_name}' in its module; skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
action_scopes = (
|
||||
tool_data_cls.model_fields.get("action_scope").default or []
|
||||
)
|
||||
is_system = bool(tool_data_cls.model_fields.get("is_system").default)
|
||||
category_field = tool_data_cls.model_fields.get("category")
|
||||
category = (category_field.default if category_field else "other") or "other"
|
||||
|
||||
self.tool_metadata[plugin_name] = {
|
||||
"name": plugin_name,
|
||||
"is_system": is_system,
|
||||
"category": category,
|
||||
"action_scope": list(action_scopes),
|
||||
}
|
||||
|
||||
if category == "mcp":
|
||||
continue
|
||||
|
||||
scopes = [s for s in action_scopes if s] or ["default"]
|
||||
|
||||
if is_system:
|
||||
for scope in scopes:
|
||||
self._tool_funcs[scope][plugin_name] = tool_func
|
||||
self.tool_mapper[scope][plugin_name] = tool_data_cls
|
||||
else:
|
||||
self._third_party_funcs[plugin_name] = tool_func
|
||||
for scope in scopes:
|
||||
self.tool_mapper[scope][plugin_name] = tool_data_cls
|
||||
|
||||
self._build_system_toolsets()
|
||||
|
||||
def _build_system_toolsets(self) -> None:
|
||||
FunctionToolset = self._import_function_toolset()
|
||||
if FunctionToolset is None:
|
||||
return
|
||||
for scope, name_to_func in self._tool_funcs.items():
|
||||
if not name_to_func:
|
||||
continue
|
||||
try:
|
||||
self._system_toolsets[scope] = FunctionToolset(
|
||||
tools=list(name_to_func.values()),
|
||||
id=f"system::{scope}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build system toolset {scope}: {e}")
|
||||
|
||||
def rebuild_custom_toolsets(self, custom_defs: Dict[str, Dict[str, Any]]) -> None:
    """Rebuild the custom ``FunctionToolset`` map from toolset definitions.

    Args:
        custom_defs: ``{toolset_id: definition}``. Each definition's
            ``"tools"`` entry lists tool names; names that are not registered
            third-party tools are ignored.

    ``self._custom_toolsets`` is replaced wholesale. Definitions that resolve
    to no functions are left out, a build failure is logged and skipped, and
    the map ends up empty when ``FunctionToolset`` cannot be imported.
    """
    toolset_cls = self._import_function_toolset()
    rebuilt: Dict[str, Any] = {}

    if toolset_cls is not None:
        for toolset_id, defn in custom_defs.items():
            selected = []
            for tool_name in defn.get("tools") or []:
                if tool_name in self._third_party_funcs:
                    selected.append(self._third_party_funcs[tool_name])
            if not selected:
                continue
            try:
                rebuilt[toolset_id] = toolset_cls(
                    tools=selected, id=f"custom::{toolset_id}"
                )
            except Exception as e:
                logger.error(f"Failed to build custom toolset {toolset_id}: {e}")

    self._custom_toolsets = rebuilt
|
||||
|
||||
@staticmethod
def _import_function_toolset():
    """Return the ``FunctionToolset`` class, or ``None`` if it cannot be imported.

    The import happens at call time; an ``ImportError`` is logged as a
    warning and reported to the caller as ``None``.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        logger.warning("pydantic_ai.toolsets unavailable")
        return None
    else:
        return FunctionToolset
|
||||
|
||||
@staticmethod
def _find_tool_data_class(module) -> Type[BaseToolData] | None:
    """Return the first ``BaseToolData`` subclass found in ``module``.

    Classes are visited in the order produced by ``inspect.getmembers``, and
    ``BaseToolData`` itself is never returned. Returns ``None`` when the
    module exposes no such subclass.
    """
    subclasses = (
        cls
        for _, cls in inspect.getmembers(module, inspect.isclass)
        if cls is not BaseToolData and issubclass(cls, BaseToolData)
    )
    return next(subclasses, None)
|
||||
|
||||
# ─── Toolset accessors ───
|
||||
|
||||
def get_system_toolset(self, scope: str) -> Any | None:
|
||||
return self._system_toolsets.get(scope)
|
||||
|
||||
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the toolsets that apply to ``scope``.

    The result holds the ``default`` system toolset (if built), then the
    system toolset for ``scope`` (if built), then every custom toolset.

    Args:
        scope: Scope whose system toolset is added on top of ``default``.

    Returns:
        A new list; each system toolset appears at most once.
    """
    result: List[Any] = []
    # ``dict.fromkeys`` removes duplicates while keeping order. Without it a
    # call with scope == "default" returned the default toolset twice, which
    # hands the agent the same tools under the same toolset id two times.
    for bucket in dict.fromkeys(("default", scope)):
        toolset = self._system_toolsets.get(bucket)
        if toolset is not None:
            result.append(toolset)
    result.extend(self._custom_toolsets.values())
    return result
|
||||
|
||||
# ─── Metadata accessors ───
|
||||
|
||||
def is_third_party_tool(self, tool_name: str) -> bool:
    """Tell whether ``tool_name`` is registered as a third-party tool function."""
    registered = self._third_party_funcs
    return tool_name in registered
|
||||
|
||||
def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``category`` equals ``category``."""
    matches: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("category") == category:
            matches.append(meta)
    return matches
|
||||
|
||||
def get_system_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``is_system`` flag is exactly ``True``."""
    all_meta = self.tool_metadata.values()
    return list(filter(lambda meta: meta.get("is_system") is True, all_meta))
|
||||
|
||||
def get_third_party_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of tools that are neither system tools nor MCP tools.

    A tool is excluded when its ``is_system`` flag is exactly ``True`` or its
    ``category`` is ``"mcp"``.
    """
    selected: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        excluded = meta.get("is_system") is True or meta.get("category") == "mcp"
        if not excluded:
            selected.append(meta)
    return selected
|
||||
|
||||
def get_all_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every scanned tool as a new list."""
    return [*self.tool_metadata.values()]
|
||||
|
||||
# 兼容旧接口
|
||||
def get_non_system_tools(self) -> List[Dict[str, Any]]:
    """Backward-compatible alias for ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
|
||||
|
||||
def get_personal_tools(self) -> List[Dict[str, Any]]:
    """Backward-compatible alias for ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
|
||||
|
||||
@@ -29,6 +29,7 @@ from kilostar.core.global_state_machine.global_state_machine import GlobalStateM
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -45,22 +46,19 @@ class ConsciousnessNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
|
||||
"你是系统的'高级规划师'和'架构师',负责处理监控节点分配过来的复杂任务。\n"
|
||||
"你的主要工作场景包括:\n"
|
||||
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
|
||||
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
|
||||
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
|
||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("consciousness_node", locale=locale)
|
||||
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -72,6 +70,7 @@ class ConsciousnessNode:
|
||||
deps_type=ConsciousnessNodeDeps,
|
||||
agent_name="consciousness_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
@@ -95,6 +94,13 @@ class ConsciousnessNode:
|
||||
开始进行工作流设计的交互过程(与用户通过 SSE 进行确认或直接生成)
|
||||
目前简化为:直接根据 command 拆解并构建工作流,然后提交执行。
|
||||
"""
|
||||
from kilostar.utils.request_context import trace_id_scope
|
||||
|
||||
# 进入工作流域:把 trace_id 绑到 contextvars,本协程所有日志自动带上它
|
||||
with trace_id_scope(trace_id):
|
||||
await self._do_start_workflow_design(trace_id, command)
|
||||
|
||||
async def _do_start_workflow_design(self, trace_id: str, command: str):
|
||||
self.logger.info(
|
||||
f"ConsciousnessNode: 开始为 trace_id {trace_id} 设计工作流。原始命令:{command}"
|
||||
)
|
||||
@@ -116,11 +122,11 @@ class ConsciousnessNode:
|
||||
original_command=command, available_skills=available_skills
|
||||
)
|
||||
|
||||
# 通知 SSE 正在生成图结构
|
||||
# 通知 SSE 正在生成图结构(pending 队列:节点端写入 → API SSE 读取,单向下行)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "正在为您构建并规划工作流任务节点,请稍候..."
|
||||
)
|
||||
|
||||
@@ -131,17 +137,17 @@ class ConsciousnessNode:
|
||||
workflow = result.workflow
|
||||
workflow.trace_id = trace_id
|
||||
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "工作流构建完成,即将开始执行!"
|
||||
)
|
||||
|
||||
# 将生成的完整工作流提交执行
|
||||
workflow_engine = ray_actor_hook(
|
||||
"workflow_running_engine"
|
||||
).workflow_running_engine
|
||||
await workflow_engine.execute_workflow.remote(workflow)
|
||||
# 直接以 ray task 形式 fire workflow,不再经过 WorkflowRunningEngine 这层中转:
|
||||
# workflow 是一次性、有头有尾的执行,task 语义比常驻 actor 更贴。
|
||||
from kilostar.core.work.workflow.workflow_engine import run_workflow_task
|
||||
|
||||
run_workflow_task.remote(workflow.model_dump(), trace_id)
|
||||
else:
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "很抱歉,工作流生成失败。"
|
||||
)
|
||||
await postgres_database.update_workflow_status.remote(trace_id, "failed")
|
||||
|
||||
@@ -22,6 +22,7 @@ from kilostar.core.individual.control_node.template import (
|
||||
ForWorkflowInput,
|
||||
ControlNodeDeps,
|
||||
)
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -44,6 +45,8 @@ class ControlNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到Control的属性内
|
||||
@@ -54,25 +57,21 @@ class ControlNode:
|
||||
global_state_machine: 全局状态机
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
locale: 语言代码(zh/en),控制system prompt语言
|
||||
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
|
||||
"你是系统的'执行者'和'车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
|
||||
"你的工作职责是:\n"
|
||||
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
|
||||
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
|
||||
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
|
||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("control_node", locale=locale)
|
||||
output_type = ForWorkflow
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -84,6 +83,7 @@ class ControlNode:
|
||||
deps_type=ControlNodeDeps,
|
||||
agent_name="control_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
|
||||
@@ -24,6 +24,7 @@ from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageResponse
|
||||
)
|
||||
from pydantic_ai import RunContext, Agent
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -46,6 +47,8 @@ class RegulatoryNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到regulatoryNode的属性内
|
||||
@@ -56,24 +59,21 @@ class RegulatoryNode:
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
tools_list: 工具列表
|
||||
locale: 语言代码(zh/en),控制system prompt语言
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
|
||||
"你是系统的'前台接待'和'大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
|
||||
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
|
||||
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
|
||||
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
|
||||
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
|
||||
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
|
||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("regulatory_node", locale=locale)
|
||||
output_type = Union[MessageResponse]
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
# 走 Object Store 快照而不是 actor RPC:高频读路径不再受单 actor 串行限制
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -85,6 +85,7 @@ class RegulatoryNode:
|
||||
deps_type=RegulatoryNodeDeps,
|
||||
agent_name="regulatory_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
@@ -112,16 +113,15 @@ class RegulatoryNode:
|
||||
)
|
||||
return prompt
|
||||
|
||||
async def working(self, payload: MessageRequest) -> str:
|
||||
"""working方法,是节点唯一的调用方法,对于_run函数的结果进行判断并实现最终回复
|
||||
async def working(self, payload: MessageRequest) -> Union[MessageResponse, None]:
|
||||
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并返回最终回复
|
||||
Args:
|
||||
payload: 消息载荷,包含所有信息
|
||||
|
||||
Returns:
|
||||
str,监控节点对于用户的回复
|
||||
MessageResponse 或 None,监控节点对用户的结构化回复
|
||||
"""
|
||||
await self._run(payload)
|
||||
return ""
|
||||
return await self._run(payload)
|
||||
|
||||
async def _run(
|
||||
self, payload: MessageRequest
|
||||
@@ -140,7 +140,8 @@ class RegulatoryNode:
|
||||
deps=deps,)
|
||||
response: MessageResponse = agent_response.output
|
||||
response.platform = platform
|
||||
response.platform_id = MessageRequest.platform_id
|
||||
response.platform_id = payload.platform_id
|
||||
return response
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.logger.exception(f"RegulatoryNode._run failed: {e}")
|
||||
return None
|
||||
@@ -49,7 +49,7 @@ class MessageRequest(RequestModel):
|
||||
MessageRequest类
|
||||
任何消息渠道向regulatory_node发送消息请求的模型
|
||||
"""
|
||||
platform: Literal["client"]
|
||||
platform: Literal["client", "onebot"]
|
||||
user_name: str
|
||||
platform_id: Optional[str]
|
||||
message: str
|
||||
@@ -59,6 +59,6 @@ class MessageResponse(RegulatoryNodeResponse):
|
||||
MessageResponse类
|
||||
regulatory_node回复的模型
|
||||
"""
|
||||
platform: Optional[Literal["client"]] = Field(description="系统自动填入的platform")
|
||||
platform: Optional[Literal["client", "onebot"]] = Field(description="系统自动填入的platform")
|
||||
platform_id: Optional[str] = Field(description="系统自动填入的platform_id")
|
||||
reply_message: str = Field(...,description="模型回复的消息")
|
||||
|
||||
@@ -23,12 +23,16 @@ from kilostar.core.postgres_database.model.individual import (
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryRegister,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
# 兼容旧代码的别名
|
||||
Provider = ProviderModel
|
||||
@@ -49,9 +53,13 @@ __all__ = [
|
||||
"SpecialIndividualModel",
|
||||
"Workflow",
|
||||
"WorkflowContextModel",
|
||||
"WorkflowGraphStateModel",
|
||||
"ChatHistoryRegister",
|
||||
"ChatHistoryMessage",
|
||||
"SystemNodeConfigModel",
|
||||
"SystemNodeConfig",
|
||||
"MCPServerModel",
|
||||
"ToolConfigModel",
|
||||
"CustomToolsetModel",
|
||||
"AgentType",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class CustomToolsetModel(BaseDataModel):
    """User-defined toolset: several non-system, non-MCP tool plugins bundled together.

    ``tools`` stores a list of tool names (per the original note, the
    directory names under ``plugin/tool_plugin/``); at GSM start-up the
    matching tool functions are loaded into a single ``FunctionToolset``.
    """

    __tablename__ = "custom_toolset"

    # Identity and descriptive fields.
    toolset_id: Mapped[str] = mapped_column(
        String(64),
        primary_key=True,
    )
    name: Mapped[str] = mapped_column(
        String(100),
        nullable=False,
    )
    description: Mapped[Optional[str]] = mapped_column(Text)
    owner_id: Mapped[Optional[str]] = mapped_column(
        String(64),
        index=True,
    )

    # Tool names stored as a JSONB array; a new row defaults to an empty list.
    tools: Mapped[List[str]] = mapped_column(
        JSONB,
        default=list,
        comment="工具名列表,仅允许非 system/非 mcp 的工具",
    )

    # Timestamps are filled by the database (``now()`` on insert and update).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class MCPServerModel(BaseDataModel):
    """MCP server registry row holding one server's connection settings.

    The original note lists three transports (stdio / sse / http); the
    ``transport`` column itself is a free-form string of up to 16 characters.
    """

    __tablename__ = "mcp_server"

    # Identity.
    server_id: Mapped[str] = mapped_column(
        String(64),
        primary_key=True,
    )
    name: Mapped[str] = mapped_column(
        String(100),
        nullable=False,
    )

    # Connection settings; ``command``, ``url`` and ``tool_prefix`` are nullable.
    transport: Mapped[str] = mapped_column(
        String(16),
        nullable=False,
    )
    command: Mapped[Optional[str]] = mapped_column(String(255))
    args: Mapped[list] = mapped_column(
        JSONB,
        default=list,
    )
    url: Mapped[Optional[str]] = mapped_column(String(500))
    tool_prefix: Mapped[Optional[str]] = mapped_column(String(64))

    # Environment mapping; the column comment states that sensitive entries
    # are stored Fernet-encrypted.
    env: Mapped[dict] = mapped_column(
        JSONB,
        default=dict,
        comment="敏感字段已 Fernet 加密",
    )

    # Timestamps are filled by the database (``now()`` on insert and update).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class ToolConfigModel(BaseDataModel):
    """Runtime configuration for one tool (for example a Tavily API key).

    Per the original note, sensitive entries inside ``config`` are stored
    Fernet-encrypted.
    """

    __tablename__ = "tool_config"

    # One row per tool, keyed by the tool's name.
    tool_name: Mapped[str] = mapped_column(
        String(100),
        primary_key=True,
    )
    config: Mapped[dict] = mapped_column(
        JSONB,
        default=dict,
    )

    # Timestamps are filled by the database (``now()`` on insert and update).
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -79,3 +79,28 @@ class WorkflowContextModel(BaseDataModel):
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class WorkflowGraphStateModel(BaseDataModel):
    """Storage table for the persisted ``pydantic_graph`` state blob.

    Kept separate from ``workflow_context``: that table serves business
    display and user-readable data, while this one serves the graph engine's
    own state recovery. One row per ``trace_id``; the full history is stored
    directly as JSONB.
    """

    __tablename__ = "workflow_graph_state"

    trace_id: Mapped[str] = mapped_column(
        String(64),
        primary_key=True,
        comment="对应的工作流 Trace ID",
    )
    history: Mapped[list] = mapped_column(
        JSONB,
        default=list,
        comment="pydantic_graph FullStatePersistence.history 的 JSON 序列化",
    )

    # NOTE(review): both columns are DateTime but annotated ``Mapped[str]``,
    # matching ``updated_at`` on WorkflowContextModel in this file. Confirm
    # whether ``Mapped[datetime]`` was intended.
    created_at: Mapped[str] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[str] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
|
||||
class CustomToolsetDatabase:
    """DAO for user-defined toolsets.

    ``tools`` is a list of tool names. Per the original note, the business
    layer must make sure it holds only non-system, non-MCP tools; this class
    does not validate the names.
    """

    def __init__(self, async_session_maker):
        """Keep the factory used to open one async session per operation."""
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict; ``tools`` is copied into a new list."""
        payload: Dict[str, Any] = {
            column: getattr(row, column)
            for column in ("toolset_id", "name", "description", "owner_id")
        }
        payload["tools"] = list(row.tools or [])
        return payload

    @database_exception
    async def upsert(
        self,
        toolset_id: str,
        name: str,
        tools: List[str],
        description: Optional[str] = None,
        owner_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create the toolset, or overwrite every field of an existing one.

        Args:
            toolset_id: Primary key of the toolset.
            name: Display name.
            tools: Tool names; stored as a copy of this list.
            description: Optional description (``None`` clears it on update).
            owner_id: Optional owner id (``None`` clears it on update).

        Returns:
            The stored row as a dict, read back after commit.
        """
        lookup = select(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                record = CustomToolsetModel(
                    toolset_id=toolset_id,
                    name=name,
                    description=description,
                    owner_id=owner_id,
                    tools=list(tools),
                )
                session.add(record)
            else:
                record.name = name
                record.description = description
                record.owner_id = owner_id
                record.tools = list(tools)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
        """Return the toolset with ``toolset_id`` as a dict, or ``None`` if absent."""
        lookup = select(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored toolset as a list of dicts."""
        async with self.async_session_maker() as session:
            fetched = await session.execute(select(CustomToolsetModel))
            return [self._to_dict(record) for record in fetched.scalars().all()]

    @database_exception
    async def delete(self, toolset_id: str) -> bool:
        """Delete the toolset; return ``True`` when at least one row was removed."""
        removal = delete(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            return outcome.rowcount > 0
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class MCPServerDatabase:
    """DAO for MCP server configuration.

    ``env`` is passed through ``encrypt_dict_secrets`` before it is written
    and through ``decrypt_dict_secrets`` when a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        """Keep the factory used to open one async session per operation."""
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``env`` decrypted.

        A falsy ``args`` becomes ``[]`` and a falsy ``env`` becomes ``{}``
        before decryption.
        """
        payload: Dict[str, Any] = {
            "server_id": row.server_id,
            "name": row.name,
            "transport": row.transport,
            "command": row.command,
        }
        payload["args"] = row.args or []
        payload["url"] = row.url
        payload["tool_prefix"] = row.tool_prefix
        payload["env"] = decrypt_dict_secrets(row.env or {})
        return payload

    @database_exception
    async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Create or update the server identified by ``server_id``.

        ``name`` and ``transport`` keep their stored values when missing from
        ``config`` on update, and default to ``server_id`` / ``"stdio"`` on
        create. ``command``, ``args``, ``url``, ``tool_prefix`` and ``env``
        are always overwritten from ``config`` (missing keys become ``None``,
        ``[]`` or an encrypted empty dict).

        Returns:
            The stored row as a dict (``env`` decrypted), read back after commit.
        """
        encrypted_env = encrypt_dict_secrets(config.get("env") or {})
        # Columns written identically on both the create and update paths.
        overwritten = {
            "command": config.get("command"),
            "args": config.get("args") or [],
            "url": config.get("url"),
            "tool_prefix": config.get("tool_prefix"),
            "env": encrypted_env,
        }
        lookup = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                record = MCPServerModel(
                    server_id=server_id,
                    name=config.get("name", server_id),
                    transport=config.get("transport", "stdio"),
                    **overwritten,
                )
                session.add(record)
            else:
                record.name = config.get("name", record.name)
                record.transport = config.get("transport", record.transport)
                for column, value in overwritten.items():
                    setattr(record, column, value)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
        """Return the server with ``server_id`` as a dict, or ``None`` if absent."""
        lookup = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every registered server as a list of dicts."""
        async with self.async_session_maker() as session:
            fetched = await session.execute(select(MCPServerModel))
            return [self._to_dict(record) for record in fetched.scalars().all()]

    @database_exception
    async def delete(self, server_id: str) -> bool:
        """Delete the server; return ``True`` when at least one row was removed."""
        removal = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            return outcome.rowcount > 0
|
||||
@@ -17,10 +17,37 @@ from typing import List
|
||||
from kilostar.core.postgres_database.model.provider import ProviderModel
|
||||
from sqlalchemy import select
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.utils.crypto import (
|
||||
CryptoError,
|
||||
decrypt_secret,
|
||||
encrypt_secret,
|
||||
is_encrypted,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("provider_dao")
|
||||
|
||||
|
||||
def _decrypt_apikey(value):
    """Return the plaintext form of a stored provider API key.

    Falsy values, and values for which ``is_encrypted`` is false, are
    returned unchanged. When ``decrypt_secret`` raises ``CryptoError`` the
    failure is logged and the original (still encrypted) value is returned.
    """
    if value and is_encrypted(value):
        try:
            return decrypt_secret(value)
        except CryptoError as e:
            logger.error(f"Provider apikey 解密失败: {e}")
    return value
|
||||
|
||||
|
||||
def _encrypt_apikey(value):
    """Encrypt a provider API key unless it is falsy or already encrypted.

    Falsy values and values for which ``is_encrypted`` is true are returned
    unchanged; everything else goes through ``encrypt_secret``.
    """
    needs_encryption = bool(value) and not is_encrypted(value)
    return encrypt_secret(value) if needs_encryption else value
|
||||
|
||||
|
||||
class ProviderDatabase:
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改;``provider_apikey`` 透明 Fernet 加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
@@ -37,11 +64,10 @@ class ProviderDatabase:
|
||||
provider_id=provider.provider_id,
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_apikey=_decrypt_apikey(provider.provider_apikey),
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner,
|
||||
provider_status=provider.provider_status,
|
||||
is_active=provider.is_active,
|
||||
)
|
||||
for provider in results
|
||||
@@ -50,7 +76,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel。"""
|
||||
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = ProviderModel(**kwargs)
|
||||
session.add(provider)
|
||||
@@ -67,7 +95,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def update_provider(self, provider_id: str, **kwargs) -> None:
|
||||
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象。"""
|
||||
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(ProviderModel, provider_id)
|
||||
if provider is not None:
|
||||
@@ -76,5 +106,7 @@ class ProviderDatabase:
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
await session.refresh(provider)
|
||||
# 解密返回,方便上游使用
|
||||
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
|
||||
return provider
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class ToolConfigDatabase:
    """DAO for per-tool runtime configuration.

    The ``config`` mapping is passed through ``encrypt_dict_secrets`` before
    it is stored and through ``decrypt_dict_secrets`` when a row is returned.
    """

    def __init__(self, async_session_maker):
        # Factory that yields an async session usable as ``async with``.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
        """Convert an ORM row into ``{"tool_name", "config"}`` with config decrypted."""
        return dict(
            tool_name=row.tool_name,
            config=decrypt_dict_secrets(row.config or {}),
        )

    @database_exception
    async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Insert or replace the configuration stored for ``tool_name``.

        A falsy ``config`` is treated as an empty mapping.

        Returns:
            The stored row as a dict (config decrypted).
        """
        sealed_config = encrypt_dict_secrets(config or {})
        async with self.async_session_maker() as session:
            lookup = select(ToolConfigModel).where(
                ToolConfigModel.tool_name == tool_name
            )
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                record = ToolConfigModel(tool_name=tool_name, config=sealed_config)
                session.add(record)
            else:
                record.config = sealed_config
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Return the configuration for ``tool_name`` or ``None`` if absent."""
        lookup = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                return None
            return self._to_dict(record)

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return the configuration of every tool as a list of dicts."""
        query = select(ToolConfigModel)
        async with self.async_session_maker() as session:
            result = await session.execute(query)
            configs = result.scalars().all()
            return [self._to_dict(entry) for entry in configs]

    @database_exception
    async def delete(self, tool_name: str) -> bool:
        """Delete the row for ``tool_name``; ``True`` if a row was affected."""
        removal = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            return outcome.rowcount > 0
|
||||
@@ -17,6 +17,7 @@ from typing import List, Optional
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
@@ -101,3 +102,58 @@ class WorkflowDatabase:
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
|
||||
|
||||
@database_exception
async def upsert_workflow_graph_state(
    self, trace_id: str, history: list
) -> WorkflowGraphStateModel:
    """Store the JSON view of a graph's persisted history for ``trace_id``.

    The write is an overwrite: an existing row gets its ``history`` replaced,
    otherwise a new row is created. Only the latest version is kept here.

    Returns:
        The refreshed ``WorkflowGraphStateModel`` row.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()
        if not record:
            record = WorkflowGraphStateModel(trace_id=trace_id, history=history)
            session.add(record)
        else:
            record.history = history
        await session.commit()
        await session.refresh(record)
        return record
|
||||
|
||||
@database_exception
async def get_workflow_graph_state(
    self, trace_id: str
) -> Optional[WorkflowGraphStateModel]:
    """Fetch the persisted graph history row for ``trace_id``.

    Returns:
        The matching row, or ``None`` when no row exists.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        found = await session.execute(lookup)
        return found.scalar_one_or_none()
|
||||
|
||||
@database_exception
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
    """Remove the persisted graph history row for ``trace_id``.

    Returns:
        ``True`` when a row existed and was deleted, ``False`` when there was
        nothing to delete.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        record = (await session.execute(lookup)).scalar_one_or_none()
        if record is not None:
            await session.delete(record)
            await session.commit()
            return True
        return False
|
||||
|
||||
@@ -38,6 +38,9 @@ from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
from .module.individual import IndividualDatabase
|
||||
from .module.user import AuthDatabase
|
||||
@@ -45,6 +48,9 @@ from .module.provider import ProviderDatabase
|
||||
from .module.system_node import SystemNodeDatabase
|
||||
from .module.workflow import WorkflowDatabase
|
||||
from .module.chat_history import ChatHistoryDatabase
|
||||
from .module.mcp_server import MCPServerDatabase
|
||||
from .module.tool_config import ToolConfigDatabase
|
||||
from .module.custom_toolset import CustomToolsetDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -76,6 +82,9 @@ class PostgresDatabase:
|
||||
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
|
||||
self._workflow_database = WorkflowDatabase(self.async_session_maker)
|
||||
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
|
||||
self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
|
||||
self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
|
||||
self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
|
||||
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
@@ -91,6 +100,15 @@ class PostgresDatabase:
|
||||
finally:
|
||||
self.ready_event.set()
|
||||
|
||||
async def ping(self) -> bool:
    """Lightweight liveness probe.

    Waits for ``ready_event`` and then runs ``SELECT 1`` on a fresh
    connection from ``async_engine``. Returns ``True`` when the statement
    executes; any error from the engine propagates to the caller.
    """
    from sqlalchemy import text

    await self.ready_event.wait()
    probe = text("SELECT 1")
    async with self.async_engine.connect() as connection:
        await connection.execute(probe)
    return True
|
||||
|
||||
# Auth Database Methods
|
||||
async def add_user(self, user_name: str, hashed_password: str):
|
||||
"""新建一名用户。"""
|
||||
@@ -242,6 +260,24 @@ class PostgresDatabase:
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow_context(trace_id)
|
||||
|
||||
# Workflow Graph State (pydantic_graph 持久化)
|
||||
async def upsert_workflow_graph_state(self, trace_id: str, history: list):
    """Overwrite the persisted graph history for ``trace_id``.

    Waits for ``ready_event`` and then delegates to the workflow DAO,
    returning whatever it returns.
    """
    await self.ready_event.wait()
    workflow_db = self._workflow_database
    return await workflow_db.upsert_workflow_graph_state(trace_id, history)
|
||||
|
||||
async def get_workflow_graph_state(self, trace_id: str):
    """Read the persisted graph history record for ``trace_id``.

    Waits for ``ready_event`` and then delegates to the workflow DAO.
    """
    await self.ready_event.wait()
    graph_state = await self._workflow_database.get_workflow_graph_state(trace_id)
    return graph_state
|
||||
|
||||
async def delete_workflow_graph_state(self, trace_id: str):
    """Explicitly remove the persisted graph history for ``trace_id``.

    Waits for ``ready_event`` and then delegates to the workflow DAO.
    """
    await self.ready_event.wait()
    was_deleted = await self._workflow_database.delete_workflow_graph_state(trace_id)
    return was_deleted
|
||||
|
||||
# Chat History Database Methods
|
||||
async def create_chat_session(self, user_id: str, title: str = "新对话"):
|
||||
"""新建一个聊天会话。"""
|
||||
@@ -264,3 +300,79 @@ class PostgresDatabase:
|
||||
"""返回某个聊天会话的全部消息。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.list_chat_messages(chat_id)
|
||||
|
||||
# MCP Server Database Methods
|
||||
async def upsert_mcp_server(self, server_id: str, config: dict):
    """Insert or update one MCP server configuration.

    Waits for ``ready_event`` and then delegates to the MCP server DAO's
    ``upsert``, returning its result.
    """
    await self.ready_event.wait()
    stored = await self._mcp_server_database.upsert(server_id, config)
    return stored
|
||||
|
||||
async def get_mcp_server_db(self, server_id: str):
    """Read a single MCP server configuration by id.

    Waits for ``ready_event`` and then delegates to the MCP server DAO's
    ``get``.
    """
    await self.ready_event.wait()
    server_config = await self._mcp_server_database.get(server_id)
    return server_config
|
||||
|
||||
async def list_mcp_servers_db(self):
    """Read every MCP server configuration.

    Waits for ``ready_event`` and then delegates to the MCP server DAO's
    ``list_all``.
    """
    await self.ready_event.wait()
    all_servers = await self._mcp_server_database.list_all()
    return all_servers
|
||||
|
||||
async def delete_mcp_server_db(self, server_id: str):
    """Delete one MCP server configuration by id.

    Waits for ``ready_event`` and then delegates to the MCP server DAO's
    ``delete``.
    """
    await self.ready_event.wait()
    was_deleted = await self._mcp_server_database.delete(server_id)
    return was_deleted
|
||||
|
||||
# Tool Config Database Methods
|
||||
async def upsert_tool_config(self, tool_name: str, config: dict):
    """Insert or update the runtime configuration of one tool.

    Waits for ``ready_event`` and then delegates to the tool-config DAO's
    ``upsert``, returning its result.
    """
    await self.ready_event.wait()
    stored = await self._tool_config_database.upsert(tool_name, config)
    return stored
|
||||
|
||||
async def get_tool_config_db(self, tool_name: str):
    """Read the runtime configuration of one tool.

    Waits for ``ready_event`` and then delegates to the tool-config DAO's
    ``get``.
    """
    await self.ready_event.wait()
    tool_config = await self._tool_config_database.get(tool_name)
    return tool_config
|
||||
|
||||
async def list_tool_configs_db(self):
    """Read the runtime configuration of every tool.

    Waits for ``ready_event`` and then delegates to the tool-config DAO's
    ``list_all``.
    """
    await self.ready_event.wait()
    all_configs = await self._tool_config_database.list_all()
    return all_configs
|
||||
|
||||
async def delete_tool_config_db(self, tool_name: str):
    """Delete the runtime configuration of one tool.

    Waits for ``ready_event`` and then delegates to the tool-config DAO's
    ``delete``.
    """
    await self.ready_event.wait()
    was_deleted = await self._tool_config_database.delete(tool_name)
    return was_deleted
|
||||
|
||||
# Custom Toolset Database Methods
|
||||
async def upsert_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: list,
    description: str = None,
    owner_id: str = None,
):
    """Insert or update a user-defined toolset.

    Waits for ``ready_event`` and then forwards every argument by keyword to
    the custom-toolset DAO's ``upsert``, returning its result.
    """
    await self.ready_event.wait()
    upsert_kwargs = dict(
        toolset_id=toolset_id,
        name=name,
        tools=tools,
        description=description,
        owner_id=owner_id,
    )
    return await self._custom_toolset_database.upsert(**upsert_kwargs)
|
||||
|
||||
async def get_custom_toolset(self, toolset_id: str):
    """Read one user-defined toolset by id.

    Waits for ``ready_event`` and then delegates to the custom-toolset DAO's
    ``get``.
    """
    await self.ready_event.wait()
    toolset = await self._custom_toolset_database.get(toolset_id)
    return toolset
|
||||
|
||||
async def list_custom_toolsets(self):
    """Read every user-defined toolset.

    Waits for ``ready_event`` and then delegates to the custom-toolset DAO's
    ``list_all``.
    """
    await self.ready_event.wait()
    all_toolsets = await self._custom_toolset_database.list_all()
    return all_toolsets
|
||||
|
||||
async def delete_custom_toolset(self, toolset_id: str):
    """Delete one user-defined toolset by id.

    Waits for ``ready_event`` and then delegates to the custom-toolset DAO's
    ``delete``.
    """
    await self.ready_event.wait()
    was_deleted = await self._custom_toolset_database.delete(toolset_id)
    return was_deleted
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Postgres 后端的 ``BaseStatePersistence`` 实现,让 graph 跨进程 resume。
|
||||
|
||||
设计思路:
|
||||
|
||||
- **复用 ``FullStatePersistence`` 的内存语义**:snapshot/record_run/load_next 等
|
||||
的实现已经做得很完备(NodeSnapshot 状态机、deep_copy、type adapter),不重写
|
||||
这些细节,只是在每次"史会发生变更"的钩子触发后,把内存 history 序列化为 JSON
|
||||
写到 ``workflow_graph_state`` 表里。
|
||||
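- **Best-effort writes**: every hook awaits the DB write, but a failed write is
only logged as a warning so the graph can keep advancing. ``snapshot_end`` is the
exception: it re-raises a failed write so the caller knows the final history was not persisted.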
|
||||
- **resume 入口**:从 DB 读出 history JSON → ``load_json`` 还原 → ``Graph.iter_from_persistence``
|
||||
跑剩余节点。
|
||||
|
||||
这一层不直接持有 SQLAlchemy session;通过两个 awaitable 注入 IO,便于测试。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Optional
|
||||
|
||||
from pydantic_graph import BaseNode, End
|
||||
from pydantic_graph.persistence import (
|
||||
BaseStatePersistence,
|
||||
NodeSnapshot,
|
||||
Snapshot,
|
||||
)
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("graph_persistence")
|
||||
|
||||
|
||||
# IO 注入签名:写一份 history JSON / 读一份 history JSON(None=没有)
|
||||
WriteHistory = Callable[[str, Any], Awaitable[None]]
|
||||
ReadHistory = Callable[[str], Awaitable[Optional[Any]]]
|
||||
|
||||
|
||||
@dataclass
class PostgresStatePersistence(BaseStatePersistence):
    """State persistence that mirrors an in-memory history to external storage.

    All snapshot bookkeeping is delegated to an inner ``FullStatePersistence``.
    After every hook that can change the history, the whole history is
    serialized to JSON and handed to ``write_history``. Write failures are
    logged and swallowed so the graph keeps running, except in
    ``snapshot_end`` where they are re-raised.
    """

    # Workflow identifier; passed to ``read_history`` / ``write_history`` as the key.
    trace_id: str
    # Awaitable callback ``(trace_id, history)`` that stores the JSON-compatible history.
    write_history: WriteHistory
    # Awaitable callback ``(trace_id)`` returning the stored history, or ``None``.
    read_history: ReadHistory
    # In-memory persistence that owns the actual snapshot list.
    _inner: FullStatePersistence = field(default_factory=FullStatePersistence)

    # --- BaseStatePersistence interface -------------------------------------

    async def snapshot_node(self, state, next_node):
        """Record a node snapshot in memory, then flush the history."""
        await self._inner.snapshot_node(state, next_node)
        await self._flush()

    async def snapshot_node_if_new(self, snapshot_id, state, next_node):
        """Record a node snapshot unless it already exists, then flush."""
        await self._inner.snapshot_node_if_new(snapshot_id, state, next_node)
        await self._flush()

    async def snapshot_end(self, state, end):
        """Record the end snapshot and require the final flush to succeed.

        Raises:
            Exception: whatever serialization or ``write_history`` raised
                while persisting the final history.
        """
        await self._inner.snapshot_end(state, end)
        # The graph is finished: the final history must be stored before returning.
        await self._flush(must_succeed=True)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Wrap a node run and flush the history once the run has ended.

        The inner persistence updates the snapshot status when its context
        exits. The flush sits in ``finally`` so that the status is written
        both when the node succeeds and when it raises; the node's exception
        still propagates afterwards.
        """
        try:
            async with self._inner.record_run(snapshot_id):
                yield
        finally:
            # Best-effort flush: ``_flush`` swallows ordinary write errors
            # here, so it does not mask an exception raised by the node.
            await self._flush()

    async def load_next(self) -> Optional[NodeSnapshot]:
        """Return the next snapshot to run, as decided by the inner persistence."""
        return await self._inner.load_next()

    async def load_all(self) -> list[Snapshot]:
        """Return every snapshot held by the inner persistence."""
        return await self._inner.load_all()

    def should_set_types(self) -> bool:
        """Report whether the inner persistence still needs ``set_types``."""
        return self._inner.should_set_types()

    def set_types(self, state_type, run_end_type):
        """Forward the graph's state / end types to the inner persistence."""
        self._inner.set_types(state_type, run_end_type)

    # --- serialization / deserialization ------------------------------------

    async def hydrate(self) -> bool:
        """Load the stored history into memory.

        Returns:
            ``True`` when a non-empty history was read and loaded; ``False``
            when nothing was stored or when reading / parsing failed (the
            failure is logged as a warning).

        Raises:
            AssertionError: re-raised from ``load_json``. The original author
                notes this happens when the types have not been registered
                through ``set_types`` yet.
        """
        try:
            raw = await self.read_history(self.trace_id)
        except Exception as e:  # pragma: no cover - defensive
            _logger.warning(f"hydrate read failed: {e}")
            return False
        if not raw:
            return False
        try:
            self._inner.load_json(_to_json_bytes(raw))
        except AssertionError:
            # Surfaced on purpose: the caller has to register types first.
            raise
        except Exception as e:
            _logger.warning(f"hydrate load_json failed: {e}")
            return False
        return True

    async def _flush(self, *, must_succeed: bool = False) -> None:
        """Serialize the in-memory history and write it to storage.

        Nothing is written while the inner persistence has no type adapter,
        because the history cannot be dumped without one.

        Args:
            must_succeed: When ``False`` every failure is only logged. When
                ``True`` a serialization or write failure is logged and then
                re-raised so the caller can react.
        """
        if self._inner._snapshots_type_adapter is None:
            # Types are not registered yet, so the history cannot be dumped.
            return
        try:
            blob = self._inner.dump_json()
        except Exception as e:  # pragma: no cover
            _logger.warning(f"dump history failed: {e}")
            if must_succeed:
                # A history that cannot be serialized was never stored either.
                raise
            return
        try:
            await self.write_history(self.trace_id, _from_json_bytes(blob))
        except Exception as e:
            _logger.warning(f"persist history failed: {e}")
            if must_succeed:
                raise
|
||||
|
||||
|
||||
def _to_json_bytes(value: Any) -> bytes:
|
||||
"""把 DB 读出的 ``list[dict]`` / ``str`` / ``bytes`` 都规范成 bytes 喂 load_json。"""
|
||||
if isinstance(value, (bytes, bytearray)):
|
||||
return bytes(value)
|
||||
if isinstance(value, str):
|
||||
return value.encode("utf-8")
|
||||
# 假定是 list/dict(来自 JSONB),转回 JSON 字符串
|
||||
import json as _json
|
||||
|
||||
return _json.dumps(value).encode("utf-8")
|
||||
|
||||
|
||||
def _from_json_bytes(blob: bytes) -> Any:
|
||||
"""把 ``dump_json`` 出来的 bytes 转成 list/dict 以便 JSONB 友好存储。"""
|
||||
import json as _json
|
||||
|
||||
return _json.loads(blob.decode("utf-8"))
|
||||
|
||||
|
||||
def build_postgres_persistence(trace_id: str) -> PostgresStatePersistence:
    """Build a ``PostgresStatePersistence`` wired to the postgres actor.

    The actor handle is obtained through ``ray_actor_hook`` and wrapped in
    two small coroutines that serve as the persistence's read / write
    callbacks.

    Args:
        trace_id: Workflow identifier used as the storage key.

    Returns:
        A persistence instance bound to ``trace_id``.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    async def _write(tid: str, history: Any) -> None:
        """Store ``history`` for ``tid`` through the postgres actor."""
        await postgres_database.upsert_workflow_graph_state.remote(tid, history)

    async def _read(tid: str) -> Optional[Any]:
        """Return the stored history for ``tid``, or ``None`` when absent."""
        record = await postgres_database.get_workflow_graph_state.remote(tid)
        if record is None:
            return None
        if hasattr(record, "history"):
            # Return the column value even when it is empty or None. Falling
            # back to the record itself would hand a non-JSON object to the
            # loader and produce a spurious failure warning.
            return record.history
        # No ``history`` attribute: the value is returned unchanged, as before.
        return record

    return PostgresStatePersistence(
        trace_id=trace_id,
        write_history=_write,
        read_history=_read,
    )
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from typing import Literal, Optional, Union, List, Dict, Any
|
||||
from .model import LogicGate, WorkflowMetadata, WorkStepStatus, WorkflowStatus
|
||||
from ulid import ULID
|
||||
from datetime import datetime
|
||||
@@ -61,10 +61,22 @@ class WorkflowStep(BaseModel):
|
||||
default=None, description="前置依赖输出"
|
||||
)
|
||||
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
|
||||
node: Literal["skill_individual", "consciousness_node"] = Field(
|
||||
default="skill_individual",
|
||||
description=(
|
||||
"执行此步的节点类别:\n"
|
||||
"- skill_individual:task 内现起一个专家子个体执行(一次性)\n"
|
||||
"- consciousness_node:远程调用全局 ConsciousnessNode actor"
|
||||
),
|
||||
)
|
||||
agent_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
|
||||
)
|
||||
require_approval: bool = Field(
|
||||
default=False,
|
||||
description="该步执行前是否需要人工审批;启用时会暂停工作流并通过 SSE 等待用户回执",
|
||||
)
|
||||
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
|
||||
|
||||
|
||||
|
||||
@@ -12,166 +12,521 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Workflow 引擎:基于 ``pydantic_graph`` 的状态机驱动。
|
||||
|
||||
调度路径只剩两类节点:
|
||||
|
||||
- ``skill_individual``:在当前 ray task 进程内现起一个 ``SkillIndividual``
|
||||
执行;用后即焚,不消耗 actor。
|
||||
- ``consciousness_node``:远程调用全局 ConsciousnessNode actor 的
|
||||
``working`` 方法(中途指导 / 审查类工作)。
|
||||
|
||||
每一步执行前可设置 ``require_approval=True`` 触发 ``HumanApproval`` 节点:
|
||||
推送 SSE → ``await gwm.get_received`` 阻塞等用户回执 → 决策 continue/abort。
|
||||
|
||||
Graph 还接了 pydantic_graph 的 ``FullStatePersistence``,目前主要用于
|
||||
节点边界自动 snapshot(postgres 持久化保留旧 ``upsert_workflow_context``
|
||||
路径,跨进程 resume 留到后续)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import ray
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
|
||||
from typing import Dict, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
|
||||
from pydantic_graph.persistence import BaseStatePersistence
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
|
||||
|
||||
@ray.remote
|
||||
def run_workflow_task(workflow_data: dict, trace_id: str):
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
# ─── State / Deps ─────────────────────────────────────────────────────────
|
||||
|
||||
# State passed through graph nodes
|
||||
class WorkflowGraphState(BaseModel):
|
||||
trace_id: str
|
||||
blackboard: Dict[str, Any]
|
||||
work_link: List[Dict[str, Any]]
|
||||
current_step_index: int = 0
|
||||
status: str = "running"
|
||||
logs: List[Dict[str, Any]] = []
|
||||
|
||||
async def save_context(state: WorkflowGraphState):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.upsert_workflow_context.remote(
|
||||
state.trace_id,
|
||||
workflow_pointer=state.current_step_index,
|
||||
blackboard=state.blackboard,
|
||||
work_link=state.work_link,
|
||||
workflow_status={str(datetime.datetime.now()): state.status},
|
||||
workflow_log=state.logs,
|
||||
)
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
state.trace_id, state.status
|
||||
)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.put_received.remote(
|
||||
state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
|
||||
class WorkflowGraphState(BaseModel):
    """State shared by all nodes while the workflow graph is running."""

    # Identifier of the workflow run this state belongs to.
    trace_id: str
    # Free-form key/value store for the run.
    blackboard: Dict[str, Any] = Field(default_factory=dict)
    # Ordered step definitions; nodes index into it with ``current_step_index``.
    work_link: List[Dict[str, Any]] = Field(default_factory=list)
    # Index into ``work_link`` of the step currently being handled.
    current_step_index: int = 0
    final_status: str = WorkflowStatus.RUNNING.value
    # Log entries appended by nodes as the run progresses.
    logs: List[Dict[str, Any]] = Field(default_factory=list)
    original_command: str = ""
    # Step indexes for which the approval prompt was already pushed, so that
    # re-entering the approval node after a resume does not push it again.
    # Kept as a list rather than a set so the state stays JSON-friendly when
    # the graph history is serialized.
    approvals_notified: List[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
# 业务侧执行入口:把 step + state 喂进去,拿到 (output_text, success_bool)
|
||||
StepExecutor = Callable[
|
||||
[Dict[str, Any], "WorkflowGraphState"], Awaitable[tuple[str, bool]]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class WorkflowDeps:
    """Runtime dependencies injected into graph nodes.

    Every field is an awaitable callable, so nodes perform their external IO
    through this object and tests can substitute their own callables.
    ``run_skill`` and ``run_consciousness`` are the step executors; keeping
    them here lets the nodes stay free of execution details.
    """

    # Persists the workflow context; called with keyword arguments.
    upsert_workflow_context: Callable[..., Awaitable[Any]]
    # Called as ``(trace_id, status)`` to update the workflow status.
    update_workflow_status: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id, message)`` to push a pending message.
    put_pending: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id)``; resolves to the received reply text.
    get_received: Callable[[str], Awaitable[str]]
    # Executor used for ``skill_individual`` steps.
    run_skill: StepExecutor
    # Executor used for ``consciousness_node`` steps.
    run_consciousness: StepExecutor
|
||||
|
||||
|
||||
# ─── 节点 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class Initialize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Entry node of the graph: marks the workflow as RUNNING.

    It updates the workflow status through the deps, calls
    ``_persist_context`` with the same status, and hands over to ``Dispatch``.
    """

    async def run(
        self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
    ) -> "Dispatch":
        """Set the RUNNING status, persist the context and go to ``Dispatch``."""
        running = WorkflowStatus.RUNNING.value
        await ctx.deps.update_workflow_status(ctx.state.trace_id, running)
        await _persist_context(ctx, status=running)
        return Dispatch()
|
||||
|
||||
async def execute_step(state: WorkflowGraphState):
|
||||
"""执行单一工作流节点逻辑"""
|
||||
|
||||
@dataclass
|
||||
class Dispatch(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""读取当前 step,按 ``node`` / ``require_approval`` 字段选择下一节点。"""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "HumanApproval | SkillStep | ConsciousnessStep | Finalize":
|
||||
state = ctx.state
|
||||
if state.current_step_index >= len(state.work_link):
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
return state
|
||||
return Finalize(status=WorkflowStatus.COMPLETED.value)
|
||||
|
||||
step = state.work_link[state.current_step_index]
|
||||
step.get("node", "")
|
||||
action = step.get("action", "")
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
if step_data.get("require_approval"):
|
||||
return HumanApproval()
|
||||
|
||||
# 记录开始状态
|
||||
node_type = step_data.get("node") or "skill_individual"
|
||||
if node_type == "consciousness_node":
|
||||
return ConsciousnessStep()
|
||||
if node_type == "skill_individual":
|
||||
return SkillStep()
|
||||
|
||||
# 未识别的 node 类型按失败处理(保守:不静默吞)
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"working",
|
||||
f"开始执行: {step.get('name', '未命名步骤')}",
|
||||
"failed",
|
||||
f"未知节点类型: {node_type}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await save_context(state)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
try:
|
||||
# TODO: 实际对接不同节点执行逻辑 (例如: control_node, agent 技能)
|
||||
# 这里是简化版,向控制节点或指定 skill 发送指令
|
||||
|
||||
# ... 模拟执行逻辑 ...
|
||||
await asyncio.sleep(2)
|
||||
@dataclass
|
||||
class HumanApproval(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""人工审批节点:暂停 graph,等用户通过 SSE 回执决策。
|
||||
|
||||
# 记录结果
|
||||
state.blackboard[
|
||||
step.get("outputs", f"step_{state.current_step_index}_result")
|
||||
] = "Success execution of " + action
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"completed",
|
||||
f"成功: {action}",
|
||||
]
|
||||
回执约定(轻量协议,沿用现有 SSE 通道):
|
||||
|
||||
# 判断逻辑跳转
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_pass") == "exit":
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
else:
|
||||
state.current_step_index += 1
|
||||
- 含 ``approve`` / ``yes`` / ``ok`` 视为通过,回到 Dispatch 继续执行
|
||||
- 其它(包括 ``reject``/``no``/``abort``)视为拒绝,工作流终止为 FAILED
|
||||
"""
|
||||
|
||||
except Exception as e:
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"failed",
|
||||
str(e),
|
||||
]
|
||||
state.status = WorkflowStatus.FAILED
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_fail"):
|
||||
fail_target = logic_gate.get("if_fail")
|
||||
if "jump_to_step_" in fail_target:
|
||||
target_step = int(fail_target.split("_")[-1]) - 1
|
||||
state.current_step_index = target_step
|
||||
state.status = WorkflowStatus.RUNNING
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
state = ctx.state
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
# idempotent 推送:仅当本 step 还没通知过时才发 put_pending。
|
||||
# 这样 resume 场景(HumanApproval 节点被重新进入)不会给前端发重复消息。
|
||||
if state.current_step_index not in state.approvals_notified:
|
||||
await ctx.deps.put_pending(
|
||||
state.trace_id,
|
||||
f"步骤 {state.current_step_index + 1} ({step_data.get('name', '')}) "
|
||||
f"需要人工审批,请回复 approve / reject。",
|
||||
)
|
||||
state.approvals_notified.append(state.current_step_index)
|
||||
await _persist_context(ctx, status=WorkflowStatus.HANGUP.value)
|
||||
|
||||
await save_context(state)
|
||||
return state
|
||||
reply = (await ctx.deps.get_received(state.trace_id) or "").strip().lower()
|
||||
if any(token in reply for token in ("approve", "yes", "ok")):
|
||||
# 把 require_approval 置否避免无限循环重新进 HumanApproval
|
||||
step_data["require_approval"] = False
|
||||
await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
|
||||
return Dispatch()
|
||||
|
||||
async def _run():
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
trace_id, WorkflowStatus.RUNNING
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"rejected",
|
||||
f"用户拒绝执行该步骤: {reply!r}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
state = WorkflowGraphState(
|
||||
trace_id=trace_id,
|
||||
blackboard={},
|
||||
work_link=workflow_data.get("work_link", []),
|
||||
)
|
||||
await save_context(state)
|
||||
|
||||
# 简单的图执行驱动 (模拟 pydantic-ai.graph.run 行为,直至 Graph 库正式稳定)
|
||||
while state.status == WorkflowStatus.RUNNING and state.current_step_index < len(
|
||||
state.work_link
|
||||
):
|
||||
state = await execute_step(state)
|
||||
@dataclass
|
||||
class SkillStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""skill_individual 路径:当前进程内拉一个专家子个体执行该步。"""
|
||||
|
||||
await postgres_database.update_workflow_status.remote(trace_id, state.status)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
return await _execute_step(ctx, executor=ctx.deps.run_skill)
|
||||
|
||||
|
||||
@dataclass
class ConsciousnessStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Graph node for the ``consciousness_node`` path.

    The step itself is run by the ``run_consciousness`` executor taken from the
    graph deps; logging, persistence and routing are handled by
    ``_execute_step``.
    """

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        consciousness_runner = ctx.deps.run_consciousness
        next_node = await _execute_step(ctx, executor=consciousness_runner)
        return next_node
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finalize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""收尾节点:写最终状态、推送最终 SSE,把 workflow 终态作为 graph 输出。"""
|
||||
|
||||
status: str
|
||||
|
||||
async def run(
|
||||
self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
|
||||
) -> End[str]:
|
||||
ctx.state.final_status = self.status
|
||||
await ctx.deps.update_workflow_status(ctx.state.trace_id, self.status)
|
||||
msg = (
|
||||
"工作流执行完成!"
|
||||
if state.status == WorkflowStatus.COMPLETED
|
||||
if self.status == WorkflowStatus.COMPLETED.value
|
||||
else "工作流执行失败。"
|
||||
)
|
||||
await global_workflow_manager.put_received.remote(trace_id, msg)
|
||||
await ctx.deps.put_pending(ctx.state.trace_id, msg)
|
||||
return End(self.status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
# ─── 内部 helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _persist_context(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps], *, status: str
) -> None:
    """Store the current graph state via ``ctx.deps.upsert_workflow_context``.

    Args:
        ctx: Graph run context; ``ctx.state`` supplies the values to store and
            ``ctx.deps`` supplies the upsert callable.
        status: Workflow status string, recorded under the current timestamp.
    """
    current = ctx.state
    # The status is sent as a one-entry mapping keyed by the local timestamp.
    stamped_status = {str(datetime.datetime.now()): status}
    await ctx.deps.upsert_workflow_context(
        current.trace_id,
        workflow_pointer=current.current_step_index,
        blackboard=current.blackboard,
        work_link=current.work_link,
        workflow_status=stamped_status,
        workflow_log=current.logs,
    )
|
||||
|
||||
|
||||
def _resolve_fail_jump(fail_target: Any, step_count: int) -> Optional[int]:
    """Translate an ``if_fail`` value such as ``"jump_to_step_3"`` into a 0-based index.

    Returns ``None`` when the value is not a jump instruction, when its suffix
    is not an integer, or when the target lies outside ``[0, step_count)``.
    """
    if not isinstance(fail_target, str) or "jump_to_step_" not in fail_target:
        return None
    try:
        target_index = int(fail_target.split("_")[-1]) - 1
    except ValueError:
        return None
    if 0 <= target_index < step_count:
        return target_index
    return None


async def _execute_step(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    *,
    executor: StepExecutor,
) -> "Dispatch | Finalize":
    """Shared execution skeleton used by ``SkillStep`` and ``ConsciousnessStep``.

    Handles the concerns common to both nodes (log entries, pending-message
    push, blackboard update, ``logic_gate`` routing) and leaves the actual work
    to ``executor``.

    Args:
        ctx: Graph run context carrying the workflow state and deps.
        executor: Awaitable callable invoked as ``executor(step_data, state)``
            and expected to return ``(output_text, success)``. Any exception it
            raises is treated as a failed step whose output is the exception
            text.

    Returns:
        ``Dispatch`` to continue with another step, or ``Finalize`` carrying
        the terminal workflow status.
    """
    state = ctx.state
    step_data = state.work_link[state.current_step_index]

    state.logs.append(
        {
            str(state.current_step_index): [
                str(datetime.datetime.now()),
                "working",
                f"开始执行: {step_data.get('name', '未命名步骤')}",
            ]
        }
    )
    await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
    await ctx.deps.put_pending(
        state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
    )

    try:
        output_text, success = await executor(step_data, state)
    except Exception as e:  # an executor exception is routed to the failure branch
        output_text, success = str(e), False

    logic_gate = step_data.get("logic_gate") or {}

    if success:
        output_key = step_data.get(
            "outputs", f"step_{state.current_step_index}_result"
        )
        state.blackboard[output_key] = output_text
        # Replace the "working" entry appended above with the final outcome.
        state.logs[-1][str(state.current_step_index)] = [
            str(datetime.datetime.now()),
            "completed",
            f"成功: {step_data.get('action', '')}",
        ]
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)

        if logic_gate.get("if_pass") == "exit":
            return Finalize(status=WorkflowStatus.COMPLETED.value)

        state.current_step_index += 1
        if state.current_step_index >= len(state.work_link):
            return Finalize(status=WorkflowStatus.COMPLETED.value)
        return Dispatch()

    # Failure: a valid ``if_fail`` jump takes priority over finishing as FAILED.
    state.logs[-1][str(state.current_step_index)] = [
        str(datetime.datetime.now()),
        "failed",
        output_text,
    ]
    # A malformed or out-of-range jump target must not crash the node, wrap
    # around through a negative index, or let Dispatch finalize a failed run
    # as COMPLETED; such targets fall through to the FAILED path below.
    # NOTE(review): a jump back to a step that keeps failing is not bounded here.
    target_step = _resolve_fail_jump(
        logic_gate.get("if_fail"), len(state.work_link)
    )
    if target_step is not None:
        state.current_step_index = target_step
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()

    await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
    return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
|
||||
# ─── 图定义 ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
# Module-level graph object; registers every node class a workflow run can visit.
workflow_graph: Graph[WorkflowGraphState, WorkflowDeps, str] = Graph(
    nodes=(
        Initialize,
        Dispatch,
        HumanApproval,
        SkillStep,
        ConsciousnessStep,
        Finalize,
    ),
    state_type=WorkflowGraphState,
    run_end_type=str,
)
|
||||
|
||||
|
||||
# ─── 默认执行器:把 step 派发到 SkillIndividual / ConsciousnessNode ────
|
||||
|
||||
|
||||
async def _default_skill_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Default ``skill_individual`` executor: build a ``SkillIndividual`` and run the step.

    A fresh ``SkillIndividual`` is created for every call from the configuration
    found under ``step_data["agent_id"]`` in the GSM snapshot.

    Args:
        step_data: The workflow step; must contain a non-empty ``agent_id``.
        state: Workflow state; its ``blackboard`` and ``original_command`` are
            passed along to the individual.

    Returns:
        ``(output, True)`` on success, with ``"(empty)"`` substituted for an
        empty output, or ``(error_message, False)`` when ``agent_id`` is missing
        or unknown.
    """
    from kilostar.worker_individual.skill_individual import SkillIndividual
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    agent_id = step_data.get("agent_id")
    if not agent_id:
        return "skill_individual 步骤缺少 agent_id", False

    gsm_view = await fetch_snapshot()
    config = gsm_view.individuals.get(agent_id)
    if not config:
        return f"未找到 agent_id={agent_id} 的专家子个体", False

    worker = SkillIndividual(dict(config))
    outcome = await worker.run(
        {
            "step": step_data,
            "blackboard": state.blackboard,
            "original_command": state.original_command,
        }
    )
    # Dict results carry their text under "output"; anything else is stringified.
    if isinstance(outcome, dict):
        text = outcome.get("output", "")
    else:
        text = str(outcome)
    return text or "(empty)", True
|
||||
|
||||
|
||||
async def _default_consciousness_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Default consciousness executor: call ``working`` on the ``consciousness_node`` actor.

    Args:
        step_data: The workflow step, validated into a ``WorkflowStep``.
        state: Workflow state; only ``original_command`` is forwarded.

    Returns:
        ``(result.output, True)`` when the actor returns a ``ForWorkflow``;
        otherwise ``(error_message, False)``.
    """
    from kilostar.utils.ray_hook import ray_actor_hook
    from kilostar.core.individual.consciousness_node.template import (
        ForWorkflow,
        ForWorkflowInput,
    )
    from kilostar.core.work.workflow.workflow import WorkflowStep

    node_handle = ray_actor_hook("consciousness_node").consciousness_node
    request = ForWorkflowInput(
        workflow_step=WorkflowStep.model_validate(step_data),
        original_command=state.original_command,
    )
    reply = await node_handle.working.remote(request)

    if reply is None:
        return "ConsciousnessNode 返回 None", False
    if isinstance(reply, ForWorkflow):
        return reply.output, True
    return f"ConsciousnessNode 返回未知类型: {type(reply).__name__}", False
|
||||
|
||||
|
||||
def build_default_deps() -> WorkflowDeps:
    """Build the default ``WorkflowDeps`` from ray actor handles.

    Each remote actor method is wrapped in a small coroutine function so graph
    nodes can simply ``await`` it. The two executors are the module's default
    skill and consciousness executors.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    db_actor = ray_actor_hook("postgres_database").postgres_database
    manager_actor = ray_actor_hook(
        "global_workflow_manager"
    ).global_workflow_manager

    async def _upsert_workflow_context(trace_id: str, **kwargs: Any) -> Any:
        pending = db_actor.upsert_workflow_context.remote(trace_id, **kwargs)
        return await pending

    async def _update_workflow_status(trace_id: str, status: str) -> Any:
        pending = db_actor.update_workflow_status.remote(trace_id, status)
        return await pending

    async def _put_pending(trace_id: str, message: str) -> Any:
        pending = manager_actor.put_pending.remote(trace_id, message)
        return await pending

    async def _get_received(trace_id: str) -> str:
        pending = manager_actor.get_received.remote(trace_id)
        return await pending

    return WorkflowDeps(
        upsert_workflow_context=_upsert_workflow_context,
        update_workflow_status=_update_workflow_status,
        put_pending=_put_pending,
        get_received=_get_received,
        run_skill=_default_skill_executor,
        run_consciousness=_default_consciousness_executor,
    )
|
||||
|
||||
|
||||
async def run_workflow_graph(
    workflow_data: Dict[str, Any],
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Run the workflow graph from ``Initialize`` and return its output string.

    Args:
        workflow_data: Workflow dict; ``work_link`` and
            ``workflow_metadata["command"]`` are read from it.
        trace_id: Workflow trace id stored on the graph state.
        deps: Graph dependencies; built with ``build_default_deps`` when omitted.
        persistence: State persistence; a new ``FullStatePersistence`` is used
            when omitted.

    Returns:
        The ``output`` of the finished graph run.
    """
    deps = build_default_deps() if deps is None else deps
    persistence = FullStatePersistence() if persistence is None else persistence

    # Missing or null entries degrade to an empty step list / empty command.
    steps = workflow_data.get("work_link", []) or []
    metadata = workflow_data.get("workflow_metadata") or {}
    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=list(steps),
        original_command=metadata.get("command", "") or "",
    )
    run_result = await workflow_graph.run(
        Initialize(),
        state=initial_state,
        deps=deps,
        persistence=persistence,
    )
    return run_result.output
|
||||
|
||||
|
||||
async def resume_workflow_graph(
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Resume a workflow from ``persistence`` and run it until an ``End`` node.

    The caller is expected to have loaded ``persistence`` beforehand; this
    function only drives ``workflow_graph.iter_from_persistence``.

    Args:
        trace_id: Workflow trace id (not read by this function body).
        deps: Graph dependencies; built with ``build_default_deps`` when omitted.
        persistence: Persistence to resume from; required.

    Returns:
        The ``data`` of the ``End`` node, or the RUNNING status value if the
        iteration finishes without yielding one.

    Raises:
        ValueError: If ``persistence`` is ``None``.
    """
    if persistence is None:
        raise ValueError("resume 必须显式传入 persistence")
    active_deps = deps if deps is not None else build_default_deps()

    outcome: str = WorkflowStatus.RUNNING.value
    async with workflow_graph.iter_from_persistence(
        persistence, deps=active_deps
    ) as graph_run:
        async for visited in graph_run:
            if not isinstance(visited, End):
                continue
            outcome = visited.data
            break
    return outcome
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WorkflowRunningEngine:
|
||||
def __init__(
|
||||
self, consciousness_node=None, control_node=None, regulatory_node=None
|
||||
):
|
||||
self.consciousness_node = consciousness_node
|
||||
self.control_node = control_node
|
||||
self.regulatory_node = regulatory_node
|
||||
self.events_queue = asyncio.Queue()
|
||||
def run_workflow_task(workflow_data: dict, trace_id: str):
|
||||
"""workflow 的 ray task 入口:一次性执行,跑完即销毁。
|
||||
|
||||
async def put_event(self, event):
|
||||
await self.events_queue.put(event)
|
||||
生产路径下持久化交给 ``PostgresStatePersistence`` —— 即便进程崩溃,再 fire
|
||||
一次相同 ``trace_id`` 的任务(或调 ``/workflow/{trace_id}/resume``)即可
|
||||
续跑。同时为了支持 fresh start,先尝试 ``hydrate``:
|
||||
- hydrate 拿到内容 → 走 resume 路径
|
||||
- hydrate 没拿到 → 走全新路径
|
||||
|
||||
async def run(self):
|
||||
"""引擎循环提取事件"""
|
||||
while True:
|
||||
await self.events_queue.get()
|
||||
await asyncio.sleep(1)
|
||||
ray task 是新进程,contextvars 不会从 caller 传过来,所以入口先 bind 一次
|
||||
``trace_id``,让节点内的日志自动带上它。
|
||||
"""
|
||||
from kilostar.utils.request_context import trace_id_scope
|
||||
from kilostar.core.work.workflow.graph_persistence import (
|
||||
build_postgres_persistence,
|
||||
)
|
||||
|
||||
async def execute_workflow(self, workflow: KiloStarWorkflow):
|
||||
# 这个方法可以由意识节点调用来提交一个完整的运行任务
|
||||
workflow_dict = workflow.model_dump()
|
||||
trace_id = workflow.trace_id
|
||||
run_workflow_task.remote(workflow_dict, trace_id)
|
||||
async def _entry() -> None:
|
||||
with trace_id_scope(trace_id):
|
||||
persistence = build_postgres_persistence(trace_id)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
recovered = False
|
||||
try:
|
||||
recovered = await persistence.hydrate()
|
||||
except Exception: # pragma: no cover - 防御
|
||||
recovered = False
|
||||
|
||||
if recovered:
|
||||
await resume_workflow_graph(trace_id, persistence=persistence)
|
||||
else:
|
||||
await run_workflow_graph(
|
||||
workflow_data, trace_id, persistence=persistence
|
||||
)
|
||||
|
||||
asyncio.run(_entry())
|
||||
|
||||
@@ -28,10 +28,10 @@ class ApprovalToolData(BaseToolData):
|
||||
"regulatory_node",
|
||||
"growth_node",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
] = ["control_node", "consciousness_node"]
|
||||
config_args: Dict[str, str] = {}
|
||||
category: str = "system"
|
||||
|
||||
|
||||
async def approval(message: str, trace_id: str) -> str:
|
||||
|
||||
@@ -29,7 +29,8 @@ class BaseToolData(BaseModel):
|
||||
"regulatory_node",
|
||||
"growth_node",
|
||||
"",
|
||||
"",
|
||||
]
|
||||
] = []
|
||||
config_args: Dict[str, str] = {}
|
||||
category: str = "other"
|
||||
"""工具分类:system(系统内置)、search(搜索)、mcp(MCP 服务器)、other(其他)"""
|
||||
|
||||
@@ -12,6 +12,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .file_reader import FileReaderData, file_reader
|
||||
from .file_reader import FileReaderToolData, file_reader
|
||||
|
||||
__all__ = ["FileReaderData", "file_reader"]
|
||||
__all__ = ["FileReaderToolData", "file_reader"]
|
||||
|
||||
@@ -12,36 +12,45 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic_ai import RunContext
|
||||
"""File Reader Tool Plugin for KiloStar.
|
||||
|
||||
Reads the contents of a file from the local filesystem.
|
||||
"""
|
||||
|
||||
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
|
||||
import os
|
||||
from typing import List, Literal, Dict
|
||||
|
||||
|
||||
class FileReaderData(BaseToolData):
|
||||
"""``file_reader`` 工具的元数据:声明工具的名称、描述与是否系统级别。"""
|
||||
class FileReaderToolData(BaseToolData):
|
||||
"""``file_reader`` 工具的元数据。"""
|
||||
|
||||
is_system: bool = True
|
||||
name: str = "file_reader"
|
||||
description: str = "读取本地文件的内容"
|
||||
action_scope: List[
|
||||
Literal[
|
||||
"control_node",
|
||||
"consciousness_node",
|
||||
"regulatory_node",
|
||||
"growth_node",
|
||||
"",
|
||||
]
|
||||
] = ["control_node"]
|
||||
config_args: Dict[str, str] = {}
|
||||
category: str = "system"
|
||||
|
||||
|
||||
def file_reader(ctx: RunContext, filepath: str) -> str:
|
||||
"""读取本地文件内容的工具。
|
||||
async def file_reader(file_path: str) -> str:
|
||||
"""读取本地文件的内容。
|
||||
|
||||
Args:
|
||||
filepath: 目标文件的绝对路径或相对路径。
|
||||
file_path: 文件的绝对路径或相对路径
|
||||
|
||||
Returns:
|
||||
如果文件存在并可读,返回文件内容;否则返回错误信息。
|
||||
文件内容文本,若文件不存在则返回错误信息
|
||||
"""
|
||||
if not os.path.exists(filepath):
|
||||
return f"Error: 文件 {filepath} 不存在。"
|
||||
if not os.path.isfile(filepath):
|
||||
return f"Error: {filepath} 不是一个文件。"
|
||||
|
||||
try:
|
||||
with open(filepath, "r", encoding="utf-8") as f:
|
||||
content = f.read()
|
||||
return content
|
||||
with open(file_path, "r", encoding="utf-8") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return f"[Error] File not found: {file_path}"
|
||||
except Exception as e:
|
||||
return f"Error: 读取文件失败,原因:{str(e)}"
|
||||
return f"[Error] Failed to read file: {str(e)}"
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tavily Web Search Tool Plugin for KiloStar.
|
||||
|
||||
Provides intelligent web search capabilities via Tavily API.
|
||||
API key 取值优先级:调用参数 > GlobalStateMachine 中 ``tavily_search`` 工具配置 >
|
||||
环境变量 ``TAVILY_API_KEY``。
|
||||
"""
|
||||
|
||||
import os
|
||||
from typing import List, Literal, Dict, Optional
|
||||
|
||||
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from tavily import AsyncTavilyClient
|
||||
|
||||
|
||||
class TavilySearchToolData(BaseToolData):
    """Metadata for the Tavily search tool.

    By default the tool is exposed to the control, consciousness and regulatory
    nodes. ``growth_node`` is an accepted scope value but is not enabled by
    default.
    """

    is_system: bool = False
    action_scope: List[
        Literal[
            "control_node",
            "consciousness_node",
            "regulatory_node",
            "growth_node",
        ]
    ] = ["control_node", "consciousness_node", "regulatory_node"]
    # Default configuration values, all stored as strings.
    config_args: Dict[str, str] = {
        "api_key": "",
        "max_results": "5",
        "search_depth": "basic",
        "include_answer": "true",
    }
    category: str = "search"
|
||||
|
||||
|
||||
async def _resolve_api_key(explicit: Optional[str]) -> Optional[str]:
    """Resolve the Tavily API key: explicit argument, then GSM config, then environment.

    Args:
        explicit: Key passed by the caller; returned as-is when non-empty.

    Returns:
        The first non-empty key found, or ``None`` when the
        ``TAVILY_API_KEY`` environment variable is also unset.
    """
    if explicit:
        return explicit

    configured: Optional[str] = None
    try:
        from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

        # Read the key from the GSM snapshot's ``tavily_search`` tool config.
        snapshot = await fetch_snapshot()
        tool_cfg = snapshot.tool_configs.get("tavily_search") or {}
        if isinstance(tool_cfg, dict):
            configured = tool_cfg.get("api_key")
    except Exception:
        # Best effort: any snapshot problem falls back to the environment.
        configured = None

    if configured:
        return configured
    return os.environ.get("TAVILY_API_KEY")
|
||||
|
||||
|
||||
async def tavily_search(
    query: str,
    max_results: int = 5,
    search_depth: str = "basic",
    include_answer: bool = True,
    api_key: Optional[str] = None,
) -> str:
    """Run a Tavily web search and format the results as text.

    Args:
        query: Search query.
        max_results: Maximum number of results; clamped to the range 1-10.
        search_depth: ``"basic"`` or ``"advanced"``.
        include_answer: Whether to request and include the AI answer summary.
        api_key: Optional key; when omitted it is resolved by
            ``_resolve_api_key``.

    Returns:
        Formatted text with the optional AI summary followed by numbered
        results (title, URL, content truncated to 300 characters). An
        ``[Error] ...`` string is returned instead of raising when the key is
        missing or the search fails.
    """
    resolved_key = await _resolve_api_key(api_key)
    if not resolved_key:
        return (
            "[Error] Tavily API key 未配置。"
            "请在 ``/api/v1/resource/tool/config`` 写入或设置环境变量 ``TAVILY_API_KEY``。"
        )

    try:
        # Enforce the documented 1-10 range on both ends. Kept inside the
        # ``try`` so a non-numeric value still yields an error string.
        bounded_results = max(1, min(max_results, 10))

        client = AsyncTavilyClient(api_key=resolved_key)
        result = await client.search(
            query=query,
            max_results=bounded_results,
            search_depth=search_depth,
            include_answer=include_answer,
        )

        lines = []
        if include_answer and result.get("answer"):
            lines.append(f"【AI 摘要】{result['answer']}\n")

        results = result.get("results", [])
        if not results:
            # Keep the AI summary (if any) instead of discarding it.
            lines.append("No results found for the query.")
            return "\n".join(lines)

        lines.append("【搜索结果】")
        for i, item in enumerate(results, 1):
            title = item.get("title", "Untitled")
            url = item.get("url", "")
            # ``content`` may be present but null; treat that as empty text.
            content = (item.get("content") or "").strip()
            lines.append(f"\n{i}. {title}")
            lines.append(f"   URL: {url}")
            if content:
                lines.append(f"   {content[:300]}{'...' if len(content) > 300 else ''}")

        return "\n".join(lines)
    except Exception as e:
        # Tool boundary: report the failure as text to the calling agent.
        return f"[Error] Tavily search failed: {str(e)}"
|
||||
@@ -12,10 +12,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Annotated
|
||||
from fastapi import Depends, HTTPException
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from kilostar.utils.access import Accessor, TokenData
|
||||
from kilostar.core.postgres_database.model import UserAuthority
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.utils.i18n import t
|
||||
|
||||
|
||||
def _user_not_found_detail(request: Request | None = None) -> str:
    """Return the translated ``user_not_found`` message for ``request``.

    The request's ``accept-language`` header is passed to ``t``; without a
    request, ``None`` is passed instead.
    """
    accept_language = None
    if request:
        accept_language = request.headers.get("accept-language")
    return t("user_not_found", accept_language=accept_language)
|
||||
|
||||
|
||||
async def get_authority(user_id: str) -> UserAuthority:
|
||||
@@ -29,12 +35,12 @@ async def get_authority(user_id: str) -> UserAuthority:
|
||||
)
|
||||
return user_authority
|
||||
except UserNotExistError:
|
||||
raise HTTPException(status_code=401, detail="用户不存在或已被删除,请重新登录")
|
||||
raise HTTPException(status_code=401, detail=t("user_not_found"))
|
||||
except Exception as e:
|
||||
# Check if it's a RayTaskError wrapping UserNotExistError
|
||||
if "UserNotExistError" in str(e):
|
||||
raise HTTPException(
|
||||
status_code=401, detail="用户不存在或已被删除,请重新登录"
|
||||
status_code=401, detail=t("user_not_found")
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
@@ -0,0 +1,87 @@
|
||||
import os
|
||||
from functools import lru_cache
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("crypto")
|
||||
|
||||
_VERSION_PREFIX = "v1:"
|
||||
_SENSITIVE_KEYS = {"key", "token", "secret", "password", "apikey", "api_key"}
|
||||
|
||||
|
||||
class CryptoError(Exception):
    """Raised when the secret key is missing or invalid, or decryption fails."""
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
def _get_fernet() -> Fernet:
    """Build the ``Fernet`` instance from the ``KILOSTAR_SECRET_KEY`` environment variable.

    The result is memoized with ``lru_cache``.

    Raises:
        CryptoError: If the variable is unset/empty or ``Fernet`` rejects the key.
    """
    secret = os.environ.get("KILOSTAR_SECRET_KEY", "")
    if secret:
        try:
            return Fernet(secret.encode())
        except Exception as exc:
            raise CryptoError(f"KILOSTAR_SECRET_KEY 格式无效: {exc}") from exc
    raise CryptoError(
        "环境变量 KILOSTAR_SECRET_KEY 未设置,无法进行加解密。"
        "请生成一个密钥:python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
    )
|
||||
|
||||
|
||||
def encrypt_secret(plaintext: str) -> str:
    """Encrypt ``plaintext`` and return it with the version prefix prepended.

    Empty (falsy) input is returned unchanged without touching the key.
    """
    if plaintext:
        cipher_bytes = _get_fernet().encrypt(plaintext.encode("utf-8"))
        return f"{_VERSION_PREFIX}{cipher_bytes.decode('utf-8')}"
    return plaintext
|
||||
|
||||
|
||||
def decrypt_secret(ciphertext: str) -> str:
    """Decrypt a value produced by ``encrypt_secret``.

    Empty values and values without the version prefix are returned unchanged.

    Raises:
        CryptoError: If the token cannot be decrypted with the configured key.
    """
    if not (ciphertext and ciphertext.startswith(_VERSION_PREFIX)):
        return ciphertext

    payload = ciphertext[len(_VERSION_PREFIX):]
    engine = _get_fernet()
    try:
        plain_bytes = engine.decrypt(payload.encode("utf-8"))
    except InvalidToken as exc:
        raise CryptoError("解密失败:密钥不匹配或密文已损坏") from exc
    return plain_bytes.decode("utf-8")
|
||||
|
||||
|
||||
def is_encrypted(value: str) -> bool:
    """Return ``True`` when ``value`` is a string starting with the version prefix."""
    if not isinstance(value, str):
        return False
    return value.startswith(_VERSION_PREFIX)
|
||||
|
||||
|
||||
def _is_sensitive_key(key: str) -> bool:
    """Return ``True`` when ``key`` contains any marker from ``_SENSITIVE_KEYS``.

    Matching is case-insensitive and by substring. Non-string keys are never
    considered sensitive, so dicts with e.g. integer keys do not raise.
    """
    if not isinstance(key, str):
        return False
    lowered = key.lower()
    return any(marker in lowered for marker in _SENSITIVE_KEYS)
|
||||
|
||||
|
||||
def encrypt_dict_secrets(data: dict) -> dict:
    """Return a copy of ``data`` with sensitive string values encrypted.

    A value is encrypted only when its key is sensitive and the value is a
    non-empty string that is not already encrypted. Non-dict input is returned
    unchanged.
    """
    if not isinstance(data, dict):
        return data

    def _needs_encryption(field, value) -> bool:
        return (
            _is_sensitive_key(field)
            and isinstance(value, str)
            and bool(value)
            and not is_encrypted(value)
        )

    return {
        field: encrypt_secret(value) if _needs_encryption(field, value) else value
        for field, value in data.items()
    }
|
||||
|
||||
|
||||
def decrypt_dict_secrets(data: dict) -> dict:
    """Return a copy of ``data`` with encrypted sensitive values decrypted.

    A value whose decryption fails is logged and left as its ciphertext.
    Non-dict input is returned unchanged.
    """
    if not isinstance(data, dict):
        return data

    # Start from a shallow copy, then overwrite only what decrypts successfully.
    result: dict = dict(data)
    for field, value in data.items():
        if not (
            _is_sensitive_key(field)
            and isinstance(value, str)
            and is_encrypted(value)
        ):
            continue
        try:
            result[field] = decrypt_secret(value)
        except CryptoError as exc:
            logger.error(f"字段 {field} 解密失败: {exc}")
    return result
|
||||
+92
-27
@@ -12,68 +12,133 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""KiloStar 统一异常体系。
|
||||
|
||||
class RetryableError(Exception):
|
||||
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)。"""
|
||||
设计原则:所有自定义异常归到两条主轴下。
|
||||
|
||||
pass
|
||||
- ``BusinessError``:业务可预期错误,HTTP 层映射 4xx;前端可读、可展示给用户。
|
||||
- ``InfraError``:系统/基础设施失败错误,HTTP 层映射 5xx;通常需要落日志告警。
|
||||
其下再细分为 ``RetryableError``(瞬时故障,可由 ``retry_on_retryable_error`` 自动重试)
|
||||
与 ``NonRetryableError``(确定性失败,重试无意义)。
|
||||
|
||||
注意:用 ``InfraError`` 而非 ``SystemError`` 是为了避免与 Python 内置的
|
||||
``SystemError`` 冲突。
|
||||
|
||||
每个异常类都带 ``http_status`` 与 ``code`` 类属性,``api/__init__.py`` 的统一
|
||||
handler 根据它们直接生成结构化响应,避免业务代码里硬编码状态码。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
class NonRetryableError(Exception):
|
||||
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)。"""
|
||||
class KiloStarError(Exception):
|
||||
"""KiloStar 所有自定义异常的总根。"""
|
||||
|
||||
pass
|
||||
http_status: int = 500
|
||||
code: str = "kilostar_error"
|
||||
|
||||
|
||||
class DemandError(NonRetryableError):
|
||||
# ─── 主轴 1:业务可预期错误(4xx) ───────────────────────────────────────────
|
||||
|
||||
|
||||
class BusinessError(KiloStarError):
    """Base class for expected business-level errors; HTTP status defaults to 400."""

    # Class-level response metadata.
    http_status: int = 400
    code: str = "business_error"
|
||||
|
||||
|
||||
class DemandError(BusinessError):
|
||||
"""需求/任务参数不合法或不满足前置条件时抛出。"""
|
||||
|
||||
pass
|
||||
http_status = 400
|
||||
code = "demand_error"
|
||||
|
||||
|
||||
class ModelNotExistError(Exception):
|
||||
"""请求了一个未在 Provider 中注册的模型 ID 时抛出。"""
|
||||
|
||||
pass
|
||||
# 用户域 ─────────────────────────────────────────
|
||||
|
||||
|
||||
class UserError(Exception):
|
||||
"""用户相关错误的基类,HTTP 层会被统一映射为 4xx。"""
|
||||
class UserError(BusinessError):
|
||||
"""用户域错误的基类。"""
|
||||
|
||||
pass
|
||||
http_status = 400
|
||||
code = "user_error"
|
||||
|
||||
|
||||
class UserNotExistError(UserError):
|
||||
"""按用户名/ID 查询时用户不存在。"""
|
||||
|
||||
pass
|
||||
http_status = 404
|
||||
code = "user_not_exist"
|
||||
|
||||
|
||||
class UserPasswordError(UserError):
|
||||
"""口令校验失败(旧密码错误、登录密码错误等)。"""
|
||||
|
||||
pass
|
||||
http_status = 401
|
||||
code = "user_password_error"
|
||||
|
||||
|
||||
class ProviderError(Exception):
|
||||
"""模型 Provider 相关错误的基类。"""
|
||||
# Provider 域 ─────────────────────────────────────
|
||||
|
||||
pass
|
||||
|
||||
class ProviderError(BusinessError):
|
||||
"""模型 Provider 域错误的基类。"""
|
||||
|
||||
http_status = 400
|
||||
code = "provider_error"
|
||||
|
||||
|
||||
class ProviderNotExistError(ProviderError):
|
||||
"""请求了一个未注册的 Provider 时抛出。"""
|
||||
|
||||
pass
|
||||
http_status = 404
|
||||
code = "provider_not_exist"
|
||||
|
||||
|
||||
class WorkflowError(Exception):
|
||||
"""工作流执行期错误的基类,HTTP 层会被统一映射为 5xx。"""
|
||||
class ModelNotExistError(BusinessError):
|
||||
"""请求了一个未在 Provider 中注册的模型 ID 时抛出。"""
|
||||
|
||||
pass
|
||||
http_status = 404
|
||||
code = "model_not_exist"
|
||||
|
||||
|
||||
class WorkflowExit(WorkflowError):
|
||||
"""工作流被显式终止(用户取消、上游决策跳出等)时抛出,是预期内的退出信号。"""
|
||||
# Workflow 域 ─────────────────────────────────────
|
||||
|
||||
pass
|
||||
|
||||
class WorkflowExit(BusinessError):
|
||||
"""工作流被显式终止(用户取消、上游决策跳出等),是预期内的退出信号。"""
|
||||
|
||||
http_status = 400
|
||||
code = "workflow_exit"
|
||||
|
||||
|
||||
# ─── 主轴 2:系统/基础设施失败错误(5xx) ────────────────────────────────────
|
||||
|
||||
|
||||
class InfraError(KiloStarError):
    """Base class for system/infrastructure failures; HTTP status defaults to 500."""

    # Class-level response metadata.
    http_status: int = 500
    code: str = "infra_error"
|
||||
|
||||
|
||||
class RetryableError(InfraError):
    """Transient failure (e.g. network jitter) that may be retried; HTTP 503."""

    # Class-level response metadata.
    http_status: int = 503
    code: str = "retryable_error"
|
||||
|
||||
|
||||
class NonRetryableError(InfraError):
    """Deterministic system failure for which retrying is pointless; HTTP 500."""

    # Class-level response metadata.
    http_status: int = 500
    code: str = "non_retryable_error"
|
||||
|
||||
|
||||
class WorkflowError(InfraError):
    """Base class for errors raised while a workflow is executing; HTTP 500."""

    # Class-level response metadata.
    http_status: int = 500
    code: str = "workflow_error"
|
||||
|
||||
@@ -33,9 +33,11 @@ def _get_tool_func(tool_name: str) -> Callable | None:
|
||||
if func:
|
||||
return func
|
||||
|
||||
app_root = "/app"
|
||||
tool_plugin_dir = os.path.join(
|
||||
app_root, "kilostar", "plugin", "tool_plugin", tool_name
|
||||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
|
||||
"plugin",
|
||||
"tool_plugin",
|
||||
tool_name,
|
||||
)
|
||||
|
||||
if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir):
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""KiloStar 轻量级国际化工具。
|
||||
|
||||
设计原则:
|
||||
- 纯内存字典,无文件 IO,Ray 远程序列化零成本。
|
||||
- 支持环境变量 ``KILOSTAR_LANG`` 作为全局默认语言。
|
||||
- Agent system prompt 按 ``{locale}`` 分桶,调用方显式传入 locale。
|
||||
- API 层通过请求头 ``Accept-Language`` 解析首选语言。
|
||||
|
||||
当前支持:``zh`` (简体中文), ``en`` (English)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
from typing import Dict
|
||||
|
||||
_DEFAULT_LOCALE: str = os.getenv("KILOSTAR_LANG", "zh")
|
||||
|
||||
# ─── Agent System Prompts ──────────────────────────────────────────────────
|
||||
|
||||
_PROMPTS: Dict[str, Dict[str, str]] = {
|
||||
"regulatory_node": {
|
||||
"zh": (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
|
||||
"你是系统的'前台接待'和'大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
|
||||
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
|
||||
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
|
||||
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
|
||||
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
|
||||
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
|
||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||
),
|
||||
"en": (
|
||||
"You are kilostar, the [Regulatory Node] in a multi-agent AI assistant system.\n"
|
||||
"You are the system's 'front desk' and 'cerebral cortex', responsible for receiving user requests and final workflow reports.\n"
|
||||
"Your core duty is [intent recognition and routing]. Please read the user's request carefully:\n"
|
||||
"1. If the user is simply greeting, chatting, or asking very basic questions, generate a friendly reply directly in the ForUser format.\n"
|
||||
"2. If the user presents a complex task (e.g., writing code, multi-step planning, data processing), you must classify it as a workflow-requiring task "
|
||||
" and hand it over to the Consciousness Node using the ForConsciousnessNode format.\n"
|
||||
"3. If you receive a TerminationMessage (indicating the workflow is complete and a report has been generated), convert the report into a user-friendly reply in the ForUser format.\n"
|
||||
"Please remain calm, professional, and strictly follow the routing rules above."
|
||||
),
|
||||
},
|
||||
"consciousness_node": {
|
||||
"zh": (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
|
||||
"你是系统的'高级规划师'和'架构师',负责处理监控节点分配过来的复杂任务。\n"
|
||||
"你的主要工作场景包括:\n"
|
||||
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
|
||||
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
|
||||
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
|
||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||
),
|
||||
"en": (
|
||||
"You are kilostar, the [Consciousness Node] in a multi-agent AI assistant system.\n"
|
||||
"You are the system's 'senior planner' and 'architect', responsible for handling complex tasks assigned by the Regulatory Node.\n"
|
||||
"Your main scenarios include:\n"
|
||||
"1. Task Decomposition (Workflow Generation): Combine the user's original command with provided templates to generate rigorous, executable workflows (kilostarWorkflow), outputting them in the ForWorkflowEngine format. Steps should be clear and coherent.\n"
|
||||
"2. Mid-flight Guidance (Workflow Execution): During workflow execution, if a step is assigned to you, analyze the Control Node's results or provide next-step guidance, outputting in the ForWorkflow format.\n"
|
||||
"3. Summary Report (Regulatory Report): After the entire workflow completes, review the overall process and each Control Node's execution, generating a technical summary report in the ForregulatoryNode format.\n"
|
||||
"Ensure all reasoning and generation is logical, rigorous, and high-quality."
|
||||
),
|
||||
},
|
||||
"control_node": {
|
||||
"zh": (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
|
||||
"你是系统的'执行者'和'车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
|
||||
"你的工作职责是:\n"
|
||||
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
|
||||
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
|
||||
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
|
||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||
),
|
||||
"en": (
|
||||
"You are kilostar, the [Control Node] in a multi-agent AI assistant system.\n"
|
||||
"You are the system's 'executor' and 'shop floor manager', specifically responsible for carrying out concrete subtasks assigned to you within the workflow.\n"
|
||||
"Your duties are:\n"
|
||||
"1. Carefully analyze the objectives and requirements of the workflow_step assigned to you.\n"
|
||||
"2. Use the tools assigned to you (if any) or rely on your own knowledge and reasoning to complete the task accurately and efficiently.\n"
|
||||
"3. Return the execution results, generated data, or concrete outputs strictly in the ForWorkflow format.\n"
|
||||
"Note: Your output should be specific, practical, and directly provide the results requested by the task. Avoid excessive irrelevant pleasantries."
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
# ─── API / 通用消息 ────────────────────────────────────────────────────────
|
||||
|
||||
_MESSAGES: Dict[str, Dict[str, str]] = {
|
||||
"internal_error": {
|
||||
"zh": "服务内部错误,请稍后重试",
|
||||
"en": "Internal server error, please try again later.",
|
||||
},
|
||||
"user_not_found": {
|
||||
"zh": "用户不存在或已被删除,请重新登录",
|
||||
"en": "User does not exist or has been deleted. Please log in again.",
|
||||
},
|
||||
"provider_not_registered": {
|
||||
"zh": "Provider {provider_title} 未注册",
|
||||
"en": "Provider {provider_title} is not registered.",
|
||||
},
|
||||
"model_not_exist": {
|
||||
"zh": "模型不存在",
|
||||
"en": "Model does not exist.",
|
||||
},
|
||||
"api_not_found": {
|
||||
"zh": "API endpoint not found",
|
||||
"en": "API endpoint not found",
|
||||
},
|
||||
"frontend_not_found": {
|
||||
"zh": "Frontend build not found",
|
||||
"en": "Frontend build not found",
|
||||
},
|
||||
}
|
||||
|
||||
# ─── 工具函数 ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _resolve_locale(locale: str | None = None, accept_language: str | None = None) -> str:
    """Decide which locale a lookup should use.

    Priority: explicit ``locale`` > ``Accept-Language`` header >
    ``_DEFAULT_LOCALE`` (taken from ``KILOSTAR_LANG``, ``zh`` when unset).

    Args:
        locale: Explicit language code. Only ``zh`` and ``en`` are accepted;
            any other non-empty value resolves to ``_DEFAULT_LOCALE``.
        accept_language: Raw ``Accept-Language`` header value, e.g.
            ``"fr-FR, en;q=0.8"``.

    Returns:
        ``"zh"`` or ``"en"`` when one can be determined from the arguments,
        otherwise ``_DEFAULT_LOCALE``.
    """
    if locale:
        return locale if locale in ("zh", "en") else _DEFAULT_LOCALE
    if accept_language:
        # Examine every language range in the order the client listed them
        # (q-values are ignored; clients normally list by preference), so a
        # header such as "fr-FR, en;q=0.8" still resolves to English.
        for language_range in accept_language.split(","):
            tag = language_range.split(";")[0].strip().lower()
            # Compare the primary subtag exactly: a substring test would
            # treat unrelated tags such as "ben" (Bengali) as English.
            primary = tag.replace("_", "-").split("-")[0]
            if primary in ("zh", "en"):
                return primary
    return _DEFAULT_LOCALE
|
||||
|
||||
|
||||
def t(key: str, locale: str | None = None, accept_language: str | None = None, **kwargs) -> str:
    """Translate a generic message key.

    Args:
        key: Message key, e.g. ``internal_error``.
        locale: Explicit language code (``zh`` / ``en``).
        accept_language: Content of the ``Accept-Language`` request header.
        **kwargs: Values interpolated into the message template via
            ``str.format``.

    Returns:
        The translated string. Falls back to the default-locale text when the
        resolved locale has no entry, and to ``key`` itself when the key is
        unknown.
    """
    loc = _resolve_locale(locale, accept_language)
    # _MESSAGES is keyed by message key first and locale second (the same
    # layout as _PROMPTS). Looking it up locale-first never matches, which
    # made every call return the raw key.
    translations = _MESSAGES.get(key, {})
    text = translations.get(loc) or translations.get(_DEFAULT_LOCALE) or key
    return text.format(**kwargs) if kwargs else text
|
||||
|
||||
|
||||
def agent_prompt(agent_name: str, locale: str | None = None, accept_language: str | None = None) -> str:
    """Return the system prompt for an agent with a language directive appended.

    Args:
        agent_name: Key into ``_PROMPTS``: ``regulatory_node`` /
            ``consciousness_node`` / ``control_node``.
        locale: Explicit language code.
        accept_language: Content of the ``Accept-Language`` header.

    Returns:
        The prompt text for the resolved locale (falling back to the
        default-locale text, then to an empty string) followed by an
        instruction to answer in that language.
    """
    loc = _resolve_locale(locale, accept_language)
    per_locale = _PROMPTS.get(agent_name, {})
    base_prompt = per_locale.get(loc) or per_locale.get(_DEFAULT_LOCALE, "")

    # Directive appended so the model answers in the resolved language.
    if loc == "zh":
        directive = "\n\n【重要】请始终使用简体中文进行思考和回复。"
    elif loc == "en":
        directive = "\n\n[Important] Please always think and reply in English."
    else:
        directive = ""
    return base_prompt + directive
|
||||
@@ -12,24 +12,83 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
from loguru import logger
|
||||
from rich.logging import RichHandler
|
||||
from loguru._logger import Logger
|
||||
|
||||
from kilostar.utils.request_context import get_request_id, get_trace_id
|
||||
|
||||
|
||||
def _is_json_mode() -> bool:
    """Tell whether structured JSON logging is enabled through the environment.

    Either ``KILOSTAR_LOG_FORMAT=json`` or a truthy ``KILOSTAR_LOG_JSON``
    (``1`` / ``true`` / ``yes`` / ``on``, case-insensitive) turns it on.
    """
    if os.environ.get("KILOSTAR_LOG_FORMAT", "").lower() == "json":
        return True
    truthy_flags = {"1", "true", "yes", "on"}
    return os.environ.get("KILOSTAR_LOG_JSON", "").lower() in truthy_flags
|
||||
|
||||
|
||||
def _ctx_patcher(record):
    """Loguru patcher that fills in ``trace_id`` / ``request_id`` on a record.

    A value already present in ``record["extra"]`` (for example from an
    explicit ``bind(trace_id=...)``) is kept. A missing or empty value is
    replaced by whatever ``get_trace_id()`` / ``get_request_id()`` return.
    """
    bound = record["extra"]
    # ``or`` short-circuits, so the context getters only run for empty fields.
    bound["trace_id"] = bound.get("trace_id") or get_trace_id()
    bound["request_id"] = bound.get("request_id") or get_request_id()
|
||||
|
||||
|
||||
def setup_logger() -> Logger:
|
||||
"""初始化全局 loguru logger,输出格式为 ``actor:(...) | trace_id:(...) : message``。"""
|
||||
"""初始化全局 loguru logger。
|
||||
|
||||
- 默认(开发模式):``RichHandler`` 彩色输出,格式 ``actor:(...) | request_id:(...) | trace_id:(...) : message``
|
||||
- JSON 模式(``KILOSTAR_LOG_FORMAT=json``):写到 stdout,每行一条 JSON,便于 ELK/Loki 采集
|
||||
|
||||
request_id / trace_id 来自 ``kilostar.utils.request_context``,由 FastAPI middleware
|
||||
或工作流入口绑定到 contextvars,本模块通过 ``patcher`` 透明注入。
|
||||
"""
|
||||
logger.remove()
|
||||
|
||||
log_level = os.environ.get("KILOSTAR_LOG_LEVEL", "DEBUG").upper()
|
||||
|
||||
if _is_json_mode():
|
||||
logger.configure(
|
||||
extra={"actor_name": "System", "trace_id": "", "request_id": ""},
|
||||
patcher=_ctx_patcher,
|
||||
)
|
||||
logger.add(
|
||||
sys.stdout,
|
||||
serialize=True,
|
||||
level=log_level,
|
||||
enqueue=True,
|
||||
)
|
||||
return logger
|
||||
|
||||
def format_record(record):
|
||||
# Format string for rich handler
|
||||
actor = record["extra"].get("actor_name", "System")
|
||||
trace_id = record["extra"].get("trace_id", "")
|
||||
request_id = record["extra"].get("request_id", "")
|
||||
ids = []
|
||||
if request_id:
|
||||
ids.append(f"request_id:({request_id})")
|
||||
if trace_id:
|
||||
ids.append(f"trace_id:({trace_id})")
|
||||
ids_str = " | " + " | ".join(ids) if ids else ""
|
||||
return f"actor:({actor}){ids_str} : {record['message']}"
|
||||
|
||||
trace_str = f" | trace_id:({trace_id})" if trace_id else ""
|
||||
return f"actor:({actor}){trace_str} : {record['message']}"
|
||||
|
||||
logger.configure(extra={"actor_name": "System", "trace_id": ""})
|
||||
logger.configure(
|
||||
extra={"actor_name": "System", "trace_id": "", "request_id": ""},
|
||||
patcher=_ctx_patcher,
|
||||
)
|
||||
|
||||
logger.add(
|
||||
RichHandler(
|
||||
@@ -40,8 +99,8 @@ def setup_logger() -> Logger:
|
||||
show_path=False,
|
||||
),
|
||||
format=format_record,
|
||||
level="DEBUG",
|
||||
enqueue=True, # 异步记录
|
||||
level=log_level,
|
||||
enqueue=True,
|
||||
)
|
||||
|
||||
return logger
|
||||
@@ -51,5 +110,9 @@ global_logger = setup_logger()
|
||||
|
||||
|
||||
def get_logger(actor_name: str, trace_id: str = "") -> Logger:
    """Return a logger bound to ``actor_name`` and an optional ``trace_id``.

    Args:
        actor_name: Name recorded in the ``actor_name`` extra field, used to
            group log lines by actor.
        trace_id: Trace id to bind. The default empty string is bound as-is,
            leaving the field empty on the returned logger.

    Returns:
        ``global_logger`` with both fields bound.
    """
    bound_logger = global_logger.bind(actor_name=actor_name, trace_id=trace_id)
    return bound_logger
|
||||
|
||||
@@ -0,0 +1,180 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""MCP 辅助模块:根据全局状态机中的配置创建 pydantic-ai MCPServer 实例。"""
|
||||
|
||||
from typing import Dict, List, Any, Sequence
|
||||
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("mcp_helper")
|
||||
|
||||
# 延迟导入 pydantic_ai.mcp,避免在 MCP 包未安装时崩溃
|
||||
try:
|
||||
from pydantic_ai.mcp import (
|
||||
MCPServerStdio,
|
||||
MCPServerSSE,
|
||||
MCPServerHTTP,
|
||||
)
|
||||
_MCP_AVAILABLE = True
|
||||
except ImportError:
|
||||
_MCP_AVAILABLE = False
|
||||
logger.warning("MCP package not installed. MCP servers will not be available.")
|
||||
|
||||
|
||||
def build_mcp_toolsets(configs: Dict[str, Dict[str, Any]]) -> List[Any]:
    """Create MCP server instances from a configuration mapping.

    Args:
        configs: ``{server_id: {"name": ..., "transport": ..., ...}}``.
            ``transport`` defaults to ``"stdio"`` and may also be ``"sse"``
            or ``"http"``; ``name`` defaults to the server id.

    Returns:
        The servers that were constructed successfully, in config order.
        Entries with an unsupported transport are skipped with a warning and
        entries that raise during construction are logged and skipped. The
        list is empty when the MCP package is not importable.
    """
    built: List[Any] = []
    if not _MCP_AVAILABLE:
        return built

    for server_id, cfg in configs.items():
        # Everything stays inside the try so one malformed entry cannot
        # prevent the remaining servers from being built.
        try:
            kind = cfg.get("transport", "stdio")
            prefix = cfg.get("tool_prefix")
            display_name = cfg.get("name", server_id)

            if kind == "stdio":
                server = MCPServerStdio(
                    command=cfg.get("command", ""),
                    args=cfg.get("args", []),
                    env=cfg.get("env"),
                    tool_prefix=prefix,
                    id=server_id,
                )
            elif kind in ("sse", "http"):
                # Both URL-based transports take the same keyword arguments.
                remote_cls = MCPServerSSE if kind == "sse" else MCPServerHTTP
                server = remote_cls(
                    url=cfg.get("url", ""),
                    tool_prefix=prefix,
                    id=server_id,
                )
            else:
                logger.warning(f"Unsupported MCP transport: {kind} for server {display_name}")
                continue

            built.append(server)
            logger.info(f"MCP server '{display_name}' ({kind}) registered as toolset")
        except Exception as e:
            logger.error(f"Failed to build MCP server '{server_id}': {e}")

    return built
|
||||
|
||||
|
||||
async def get_mcp_toolsets_from_gsm() -> List[Any]:
    """Build MCP toolsets from the configuration held in the GSM snapshot.

    Returns:
        The result of ``build_mcp_toolsets`` for ``snapshot.mcp_servers``.
        An empty list is returned when the MCP package is unavailable or when
        fetching the snapshot fails (the failure is logged).
    """
    if not _MCP_AVAILABLE:
        return []

    try:
        # Imported here rather than at module import time.
        from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

        gsm_snapshot = await fetch_snapshot()
        return build_mcp_toolsets(gsm_snapshot.mcp_servers)
    except Exception as e:
        logger.error(f"Failed to load MCP configs from GSM: {e}")
    return []
|
||||
|
||||
|
||||
async def get_all_toolsets_for_scope(scope: str) -> List[Any]:
    """Collect every toolset that applies to ``scope``.

    The returned order is stable: toolsets produced by
    ``build_toolsets_for_scope`` come first, followed by the MCP toolsets.
    A failure while loading the local toolsets is logged and does not stop
    the MCP toolsets from being added.

    Args:
        scope: Scope name passed through to ``build_toolsets_for_scope``.

    Returns:
        The combined list of toolsets.
    """
    collected: List[Any] = []
    try:
        from kilostar.core.global_state_machine.gsm_snapshot import (
            build_toolsets_for_scope,
            fetch_snapshot,
        )

        gsm_snapshot = await fetch_snapshot()
        scoped = build_toolsets_for_scope(gsm_snapshot, scope)
        if scoped:
            collected.extend(scoped)
    except Exception as e:
        logger.error(f"Failed to load local toolsets from GSM ({scope}): {e}")

    mcp_toolsets = await get_mcp_toolsets_from_gsm()
    collected.extend(mcp_toolsets)
    return collected
|
||||
|
||||
|
||||
async def list_mcp_tools_for_configs(
    configs: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Connect to each configured MCP server and list the tools it exposes.

    Each server is entered with ``async with server:`` and queried once. A
    failing server does not affect the others: its entry keeps ``tools=[]``
    and gains an ``error`` field.

    Args:
        configs: ``{server_id: {...}}`` in the shape accepted by
            ``build_mcp_toolsets``.

    Returns:
        One dict per constructed server with the keys ``server_id``,
        ``name``, ``transport``, ``tool_prefix`` and ``tools`` (tool names,
        prefixed with ``<tool_prefix>_`` when a prefix is configured), plus
        ``error`` when listing failed. Empty when the MCP package is missing.
    """
    result: List[Dict[str, Any]] = []
    if not _MCP_AVAILABLE:
        return result

    for server in build_mcp_toolsets(configs):
        server_id = getattr(server, "id", None)
        cfg = configs.get(server_id, {}) if server_id else {}
        name = cfg.get("name", server_id or "unknown")
        tool_prefix = cfg.get("tool_prefix")
        item: Dict[str, Any] = {
            "server_id": server_id,
            "name": name,
            "transport": cfg.get("transport", "stdio"),
            "tool_prefix": tool_prefix,
            "tools": [],
        }
        try:
            async with server:
                # ``list_tools()`` is the argument-free listing call on an MCP
                # server. ``get_tools`` is the toolset hook that expects a run
                # context, so calling it bare would raise and mark every
                # server as failed.
                tools = await server.list_tools()
                names = [
                    getattr(tool, "name", None) or getattr(tool, "tool_name", str(tool))
                    for tool in tools
                ]
                # The server reports raw names; apply the configured prefix so
                # the listing matches the names the agent will see.
                item["tools"] = [
                    f"{tool_prefix}_{tool_name}" if tool_prefix else tool_name
                    for tool_name in names
                ]
        except Exception as e:
            item["error"] = str(e)
            logger.warning(f"MCP server '{name}' list_tools failed: {e}")
        result.append(item)
    return result
|
||||
|
||||
|
||||
async def list_mcp_tools_from_gsm() -> List[Dict[str, Any]]:
    """List MCP tools for the servers configured in the GSM snapshot.

    Returns:
        The result of :func:`list_mcp_tools_for_configs` for
        ``snapshot.mcp_servers``. An empty list is returned when the MCP
        package is unavailable or when the lookup fails (the failure is
        logged).
    """
    if not _MCP_AVAILABLE:
        return []

    try:
        from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

        gsm_snapshot = await fetch_snapshot()
        listing = await list_mcp_tools_for_configs(gsm_snapshot.mcp_servers)
        return listing
    except Exception as e:
        logger.error(f"Failed to list MCP tools from GSM: {e}")
    return []
|
||||
@@ -0,0 +1,130 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""请求/工作流上下文:基于 ``contextvars`` 的双层 ID 传播。
|
||||
|
||||
设计上把"一次用户请求"和"一次重型工作流"区分开:
|
||||
|
||||
- ``request_id``:会话域。所有进 API 的请求都要带,由 middleware 在入口生成或
|
||||
从 ``X-Request-Id`` 头继承。chat 这条同步链路靠它走完一生。
|
||||
- ``trace_id``:工作流域。只有 ``ConsciousnessNode`` 决定启动重型任务时才生成,
|
||||
挂到 ``KiloStarWorkflow`` 上。trace_id 应能追溯回触发它的 request_id(前者
|
||||
通过显式参数传入,后者从 contextvars 读取)。
|
||||
|
||||
为什么用 ``contextvars`` 而不是参数透传:
|
||||
|
||||
1. ``contextvars`` 在 ``asyncio`` 协程间天然继承,不会跨协程串味;
|
||||
2. ``loguru`` 的 ``patcher`` 钩子可以把它变成日志切面,业务代码不需要在每条
|
||||
``logger.info`` 上手动 ``.bind(trace_id=...)``;
|
||||
3. Ray 跨进程调用时 contextvars 不会自动传播 —— 这是有意为之,避免不同 actor
|
||||
间的上下文意外串联。跨 actor 边界要走显式参数,由接收方再 ``bind_*`` 一次。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from contextlib import contextmanager
|
||||
from contextvars import ContextVar, Token
|
||||
from typing import Iterator, Optional
|
||||
|
||||
|
||||
_request_id_var: ContextVar[str] = ContextVar("kilostar_request_id", default="")
|
||||
_trace_id_var: ContextVar[str] = ContextVar("kilostar_trace_id", default="")
|
||||
|
||||
|
||||
def get_request_id() -> str:
    """Return the ``request_id`` bound in the current context ("" if unbound)."""
    current = _request_id_var.get()
    return current
|
||||
|
||||
|
||||
def get_trace_id() -> str:
    """Return the ``trace_id`` bound in the current context ("" if unbound)."""
    current = _trace_id_var.get()
    return current
|
||||
|
||||
|
||||
def bind_request_id(request_id: str) -> Token:
    """Bind ``request_id`` in the current context.

    Returns:
        The ``Token`` produced by ``ContextVar.set``. Pass it to
        ``reset_request_id`` from the same context to restore the previous
        value; resetting from a different context raises ``ValueError``.
        Prefer the ``request_id_scope`` context manager where possible.
    """
    token = _request_id_var.set(request_id)
    return token
|
||||
|
||||
|
||||
def bind_trace_id(trace_id: str) -> Token:
    """Bind ``trace_id`` in the current context and return the reset token."""
    token = _trace_id_var.set(trace_id)
    return token
|
||||
|
||||
|
||||
def reset_request_id(token: Token) -> None:
    """Restore ``request_id`` to the value it had before ``token`` was created."""
    _request_id_var.reset(token)
    return None
|
||||
|
||||
|
||||
def reset_trace_id(token: Token) -> None:
    """Restore ``trace_id`` to the value it had before ``token`` was created."""
    _trace_id_var.reset(token)
    return None
|
||||
|
||||
|
||||
@contextmanager
def request_id_scope(request_id: str) -> Iterator[str]:
    """Bind ``request_id`` for the duration of a ``with`` block.

    Yields:
        The ``request_id`` that was bound. The previous value is restored on
        exit, including when the block raises.
    """
    restore_token = _request_id_var.set(request_id)
    try:
        yield request_id
    finally:
        _request_id_var.reset(restore_token)
|
||||
|
||||
|
||||
@contextmanager
def trace_id_scope(trace_id: str) -> Iterator[str]:
    """Bind ``trace_id`` for the duration of a ``with`` block.

    Yields:
        The ``trace_id`` that was bound. The previous value is restored on
        exit, including when the block raises.
    """
    restore_token = _trace_id_var.set(trace_id)
    try:
        yield trace_id
    finally:
        _trace_id_var.reset(restore_token)
|
||||
|
||||
|
||||
def new_request_id(prefix: str = "req") -> str:
    """Generate a fresh request id of the form ``<prefix>-<uuid4 hex>``."""
    random_part = uuid.uuid4().hex
    return "{}-{}".format(prefix, random_part)
|
||||
|
||||
|
||||
def snapshot() -> dict[str, str]:
    """Capture the current context ids so they can be passed on explicitly.

    Returns:
        A dict with the keys ``request_id`` and ``trace_id`` holding the
        values currently bound in this context.
    """
    request_id = _request_id_var.get()
    trace_id = _trace_id_var.get()
    return {"request_id": request_id, "trace_id": trace_id}
|
||||
|
||||
|
||||
@contextmanager
def apply_snapshot(snap: Optional[dict[str, str]]) -> Iterator[None]:
    """Apply an externally supplied id snapshot for the duration of a ``with`` block.

    Intended for the receiving side of a cross-actor call.

    Args:
        snap: Mapping as produced by ``snapshot()``. ``None`` or an empty
            mapping makes this a no-op; empty id values are not bound.
    """
    pending: list[Token] = []
    if snap:
        request_id = snap.get("request_id")
        if request_id:
            pending.append(_request_id_var.set(request_id))
        trace_id = snap.get("trace_id")
        if trace_id:
            pending.append(_trace_id_var.set(trace_id))
    try:
        yield
    finally:
        # Undo the bindings in reverse order of creation.
        while pending:
            tok = pending.pop()
            try:
                tok.var.reset(tok)
            except (ValueError, LookupError):
                # The token may no longer be valid in this context after a
                # coroutine switch; restoring is best-effort.
                pass
|
||||
@@ -59,10 +59,14 @@ class WorkerCluster:
|
||||
self._active_workers.move_to_end(agent_id)
|
||||
return self._active_workers[agent_id]
|
||||
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
global_state_machine = ray_actor_hook(
|
||||
"global_state_machine"
|
||||
).global_state_machine
|
||||
agent_config = await global_state_machine.get_individual.remote(agent_id)
|
||||
# 走快照读,避开 GSM actor RPC:高频唤醒路径不再是单 actor 瓶颈
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
agent_config = snapshot.individuals.get(agent_id)
|
||||
|
||||
if not agent_config:
|
||||
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
|
||||
|
||||
@@ -52,17 +52,21 @@ class BaseIndividual:
|
||||
self.agent_id = agent_config.get("agent_id")
|
||||
self.agent: Agent | None = None
|
||||
|
||||
async def _init_agent(self, agent_name: str, system_prompt: str):
|
||||
async def _init_agent(self, agent_name: str, system_prompt: str, toolsets=None):
|
||||
"""根据 agent_config 拉起一个 pydantic-ai Agent 实例。
|
||||
|
||||
从 GlobalStateMachine 取出 Provider,按 agent_config 中的 provider_title
|
||||
和 model_id 选择模型,加载工具列表,并把 system_prompt 注册为动态提示词。
|
||||
若调用方未显式提供 ``toolsets``,会自动从全局状态机拉取 MCP toolsets 注入。
|
||||
|
||||
Args:
|
||||
agent_name: Agent 的人类可读名称,用于日志与展示。
|
||||
system_prompt: 该 Agent 的基础系统提示词,会和 task_event 拼接成动态提示词。
|
||||
toolsets: 显式传入的外部工具集;为 ``None`` 时会自动拉取 MCP toolsets。
|
||||
"""
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
from kilostar.utils.mcp_helper import get_all_toolsets_for_scope
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
global_state_machine = ray_actor_hook(
|
||||
"global_state_machine"
|
||||
@@ -73,13 +77,18 @@ class BaseIndividual:
|
||||
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
|
||||
tools_list = self.agent_config.get("tools", None)
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
# 直读快照,避开 actor RPC 单线程串行
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
raise ValueError(f"Provider {provider_title!r} 未注册")
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
|
||||
if toolsets is None:
|
||||
toolsets = await get_all_toolsets_for_scope(agent_name)
|
||||
|
||||
self.agent = agent_factory.create_agent(
|
||||
provider=provider,
|
||||
model_id=model_id,
|
||||
@@ -88,6 +97,7 @@ class BaseIndividual:
|
||||
deps_type=WorkerIndividualDeps,
|
||||
agent_name=agent_name,
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
|
||||
Reference in New Issue
Block a user