feat(system):优化后端

1.新增后端测试
2.增加了后端的加密
3.增加了i18n(国际化)
This commit is contained in:
2026-05-31 15:39:34 +00:00
parent affe460180
commit 99520c69d7
118 changed files with 8174 additions and 1491 deletions
@@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Sequence, Any
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIChatModel
from pydantic_ai.models.anthropic import AnthropicModel
@@ -20,6 +22,8 @@ from pydantic_ai.providers.openai import OpenAIProvider
from pydantic_ai.providers.anthropic import AnthropicProvider
from pydantic_ai.providers.deepseek import DeepSeekProvider
from pydantic_ai.providers.google import GoogleProvider
from pydantic_ai.toolsets import AbstractToolset
from kilostar.core.global_state_machine.model_provider import Provider
from kilostar.utils.agent_model import ResponseModel, DepsModel
from kilostar.utils.error import ModelNotExistError
@@ -30,6 +34,7 @@ class AgentFactory:
支持 openai / claude / deepseek / gemini 四类后端,差异通过
``_models_mapping`` 中的 ``model_class`` + ``provider_class`` 键值对屏蔽。
同时支持传入本地工具(tools)和外部工具集(toolsets),包括 MCP 服务器。
"""
def __init__(self):
@@ -65,21 +70,22 @@ class AgentFactory:
deps_type: DepsModel,
agent_name: str,
tools: list = None,
toolsets: Sequence[AbstractToolset[Any]] = None,
) -> Agent:
"""
create_agent方法,将输入的provider对象实例化为一个pydantic-ai的agent对象
"""将输入的 provider 对象实例化为一个 pydantic-ai 的 agent 对象。
Args:
provider: Provider对象,从global_state_machine中获取
provider: Provider 对象,从 global_state_machine 中获取
model_id: 模型名
output_type: 输出格式
system_prompt: 系统提示词
deps_type: 依赖类型,在agent运行时动态输入的格式化消息
agent_name: agent的名字
tools: 工具列表
deps_type: 依赖类型,在 agent 运行时动态输入的格式化消息
agent_name: agent 的名字
tools: 本地工具函数列表
toolsets: 外部工具集列表(包括 MCP 服务器等 AbstractToolset 实例)
Returns:
返回被实例化的pydantic-aiAgent对象
被实例化的 pydantic-aiAgent 对象
"""
if model_id not in provider.provider_models:
raise ModelNotExistError("模型不存在")
@@ -109,13 +115,14 @@ class AgentFactory:
else:
model = model_class(model_id, provider=model_provider)
# 创建 Agent
# 创建 Agent,同时传入 tools 和 toolsets
agent = Agent(
model=model,
name=agent_name,
system_prompt=system_prompt,
output_type=output_type,
deps_type=deps_type,
tools=tools,
tools=tools or [],
toolsets=toolsets or [],
)
return agent
+82 -49
View File
@@ -16,6 +16,7 @@ import os
from typing import Dict
from fastapi import FastAPI, WebSocket, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from ray import serve
@@ -23,26 +24,68 @@ from ray import serve
from .agent import agent_router
from .auth import auth_router
from .cluster import cluster_router
from .health import health_router
from .platform.frontend import client_router
from .platform.onebot import onebot_router
from .provider import provider_router
from .resource import resource_router
from .workflow import workflow_router
from .chat import chat_router
from kilostar.utils.error import (
DemandError,
ModelNotExistError,
UserError,
UserNotExistError,
UserPasswordError,
ProviderError,
ProviderNotExistError,
WorkflowError,
WorkflowExit,
KiloStarError,
BusinessError,
InfraError,
)
from kilostar.utils.logger import get_logger
from kilostar.utils.request_context import (
bind_request_id,
new_request_id,
reset_request_id,
)
from kilostar.utils.i18n import t
_api_logger = get_logger("api")
def _get_locale(request: Request) -> str | None:
    """Return the raw ``Accept-Language`` header for the exception handlers.

    A missing or empty header yields ``None``.
    """
    header_value = request.headers.get("accept-language")
    if not header_value:
        return None
    return header_value
app = FastAPI()
_cors_origins_env = os.environ.get("KILOSTAR_CORS_ORIGINS", "*")
_cors_origins = [o.strip() for o in _cors_origins_env.split(",") if o.strip()]
_allow_credentials = "*" not in _cors_origins
app.add_middleware(
CORSMiddleware,
allow_origins=_cors_origins,
allow_credentials=_allow_credentials,
allow_methods=["*"],
allow_headers=["*"],
)
@app.middleware("http")
async def request_id_middleware(request: Request, call_next):
    """Attach a request ID to the request context and to the response.

    A non-blank incoming ``X-Request-Id`` header is reused; otherwise a new
    ID is obtained from ``new_request_id()``. The ID is bound with
    ``bind_request_id`` while the downstream handlers run and is unbound in
    ``finally``. The same ID is then written to the response's
    ``X-Request-Id`` header.
    """
    request_id = request.headers.get("X-Request-Id", "").strip()
    if not request_id:
        request_id = new_request_id()

    context_token = bind_request_id(request_id)
    try:
        response = await call_next(request)
    finally:
        # Unbind even when a downstream handler raises.
        reset_request_id(context_token)

    response.headers["X-Request-Id"] = request_id
    return response
app.include_router(health_router) # 健康检查
app.include_router(client_router) # 客户端路径
app.include_router(onebot_router) # OneBot v11 路径
app.include_router(auth_router) # 用户路径
app.include_router(provider_router) # 供应商路径
app.include_router(resource_router) # 资源路径
@@ -52,49 +95,39 @@ app.include_router(workflow_router) # workflow路径
app.include_router(chat_router) # chat路径
@app.exception_handler(UserNotExistError)
async def user_not_exist_handler(request: Request, exc: UserNotExistError):
return JSONResponse(status_code=404, content={"message": "用户不存在"})
@app.exception_handler(BusinessError)
async def business_error_handler(request: Request, exc: BusinessError):
    """Turn a ``BusinessError`` into a JSON response using its ``http_status``.

    The body carries the error ``code`` and the exception text; when the
    exception has no text, the code is used as the message.
    """
    message = str(exc)
    if not message:
        message = exc.code
    body = {"code": exc.code, "message": message}
    return JSONResponse(status_code=exc.http_status, content=body)
@app.exception_handler(UserPasswordError)
async def user_password_handler(request: Request, exc: UserPasswordError):
return JSONResponse(status_code=401, content={"message": "密码错误"})
@app.exception_handler(InfraError)
async def infra_error_handler(request: Request, exc: InfraError):
    """Log an ``InfraError`` and answer with a generic message.

    The exception text is written only to the API log. The client receives
    the error ``code`` and the ``internal_error`` message from ``t`` (looked
    up with the request's ``Accept-Language`` value), under the exception's
    own ``http_status``.
    """
    _api_logger.exception(
        f"InfraError on {request.method} {request.url.path}: {exc}"
    )
    generic_message = t("internal_error", accept_language=_get_locale(request))
    body = {"code": exc.code, "message": generic_message}
    return JSONResponse(status_code=exc.http_status, content=body)
@app.exception_handler(UserError)
async def user_error_handler(request: Request, exc: UserError):
return JSONResponse(status_code=400, content={"message": "用户相关错误"})
@app.exception_handler(ProviderNotExistError)
async def provider_not_exist_handler(request: Request, exc: ProviderNotExistError):
return JSONResponse(status_code=404, content={"message": "服务提供商不存在"})
@app.exception_handler(ProviderError)
async def provider_error_handler(request: Request, exc: ProviderError):
return JSONResponse(status_code=400, content={"message": "服务提供商错误"})
@app.exception_handler(ModelNotExistError)
async def model_not_exist_handler(request: Request, exc: ModelNotExistError):
return JSONResponse(status_code=404, content={"message": "模型不存在"})
@app.exception_handler(DemandError)
async def demand_error_handler(request: Request, exc: DemandError):
return JSONResponse(status_code=400, content={"message": "需求格式错误或不满足"})
@app.exception_handler(WorkflowExit)
async def workflow_exit_handler(request: Request, exc: WorkflowExit):
return JSONResponse(status_code=400, content={"message": "工作流已退出"})
@app.exception_handler(WorkflowError)
async def workflow_error_handler(request: Request, exc: WorkflowError):
return JSONResponse(status_code=500, content={"message": "工作流执行错误"})
@app.exception_handler(Exception)
async def unhandled_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the exception and return a generic 500.

    The exception text is written only to the API log; the response body
    contains the fixed code ``internal_error`` and the matching message from
    ``t``, so no traceback details reach the client.
    """
    _api_logger.exception(
        f"Unhandled exception on {request.method} {request.url.path}: {exc}"
    )
    generic_message = t("internal_error", accept_language=_get_locale(request))
    body = {"code": "internal_error", "message": generic_message}
    return JSONResponse(status_code=500, content=body)
base_dir = os.path.dirname(
@@ -129,7 +162,7 @@ if os.path.exists(frontend_dir):
if os.path.exists(index_path):
return FileResponse(index_path)
return JSONResponse(
status_code=404, content={"detail": "Frontend build not found"}
status_code=404, content={"detail": t("frontend_not_found")}
)
else:
import logging
+15 -2
View File
@@ -15,7 +15,7 @@
from typing import Union
from kilostar.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, Request
from pydantic import BaseModel
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.postgres_database.model import AgentType
@@ -23,6 +23,8 @@ from fastapi import HTTPException
from typing import Optional, List, Dict
from kilostar.utils.check_user.role_check import RoleChecker
from kilostar.core.postgres_database.model import UserAuthority
from kilostar.utils.mcp_helper import get_all_toolsets_for_scope
from kilostar.utils.i18n import t
agent_router = APIRouter(prefix="/api/v1/agent", tags=["agent"])
@@ -57,11 +59,13 @@ async def get_system_nodes(
@agent_router.post("")
async def load_agent(
agent_register: Union[AgentRegister, AgentLocalRegister],
request: Request,
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""加载/重载某个系统节点的 Agent:先持久化配置,再调用对应节点 Actor 的 ``create_agent``。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
postgres_database = ray_actor_hook("postgres_database").postgres_database
accept_lang = request.headers.get("accept-language", "")
if isinstance(agent_register, AgentLocalRegister):
pass
@@ -75,7 +79,10 @@ async def load_agent(
agent_register.tools,
)
match agent_register.individual_name:
scope = agent_register.individual_name
toolsets = await get_all_toolsets_for_scope(scope)
match scope:
case "regulatory_node":
node = ray_actor_hook("regulatory_node").regulatory_node
await node.create_agent.remote(
@@ -83,6 +90,8 @@ async def load_agent(
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
toolsets,
accept_lang,
)
case "consciousness_node":
node = ray_actor_hook("consciousness_node").consciousness_node
@@ -91,6 +100,8 @@ async def load_agent(
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
toolsets,
accept_lang,
)
case "control_node":
node = ray_actor_hook("control_node").control_node
@@ -99,6 +110,8 @@ async def load_agent(
agent_register.provider_title,
agent_register.model_id,
agent_register.tools,
toolsets,
accept_lang,
)
case _:
pass
+35 -6
View File
@@ -16,10 +16,40 @@ from fastapi import APIRouter, Depends
from pydantic import BaseModel
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.individual.regulatory_node.template import (
MessageRequest,
MessageResponse,
)
chat_router = APIRouter(prefix="/api/v1/chat", tags=["chat"])
def _extract_reply(resp: MessageResponse | None) -> str | None:
    """Return the reply text carried by a regulatory-node response.

    ``None`` in means ``None`` out: callers treat that as "no reply".
    """
    return None if resp is None else resp.reply_message
async def _ask_regulatory(
    *, user_id: str, chat_id: str, message: str
) -> str | None:
    """Send one chat message to the regulatory node and return its reply text.

    The message is wrapped in a ``MessageRequest`` with ``platform="client"``,
    ``user_id`` as ``user_name`` and ``chat_id`` as ``platform_id``, then
    passed to the node's ``working`` method.

    Returns:
        The reply text, or ``None`` when the node returned no response.
    """
    node = ray_actor_hook("regulatory_node").regulatory_node
    chat_request = MessageRequest(
        platform="client",
        user_name=user_id,
        platform_id=chat_id,
        message=message,
    )
    node_response: MessageResponse | None = await node.working.remote(chat_request)
    return _extract_reply(node_response)
class CreateChatRequest(BaseModel):
title: str = "新对话"
initial_message: str
@@ -45,9 +75,7 @@ async def create_chat_session(
)
# 调用监管节点处理简单任务/交流
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
# 在此发起任务并等待或异步返回结果
response_msg = await regulatory_node.handle_chat_message.remote(
response_msg = await _ask_regulatory(
user_id=token_data.user_id,
chat_id=chat.chat_id,
message=request.initial_message,
@@ -95,9 +123,10 @@ async def send_chat_message(
)
# 调用监管节点
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
response_msg = await regulatory_node.handle_chat_message.remote(
user_id=token_data.user_id, chat_id=chat_id, message=request.message
response_msg = await _ask_regulatory(
user_id=token_data.user_id,
chat_id=chat_id,
message=request.message,
)
# 存回复
+54
View File
@@ -0,0 +1,54 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""健康检查端点:用于容器存活/就绪探针。"""
from fastapi import APIRouter
from fastapi.responses import JSONResponse
from kilostar.utils.ray_hook import ray_actor_hook
health_router = APIRouter(tags=["health"])
@health_router.get("/health/live", include_in_schema=True)
async def liveness():
    """Liveness probe: answering at all means the process is alive."""
    status_body = {"status": "alive"}
    return status_body
@health_router.get("/health/ready", include_in_schema=True)
async def readiness():
    """Readiness probe: check that the key dependencies respond.

    Two probes run in sequence: ``ping`` on the Postgres actor and
    ``get_skill_list`` on the global state machine actor. Any exception
    while resolving or calling an actor marks that dependency as not ready.

    Returns:
        200 with ``status="ready"`` when both probes succeed, otherwise 503
        with ``status="not_ready"``. The per-dependency results are always
        included under ``checks``.
    """
    postgres_ok = False
    try:
        pg_actor = ray_actor_hook("postgres_database").postgres_database
        await pg_actor.ping.remote()
    except Exception:
        # Best-effort probe: any failure just means "not ready".
        pass
    else:
        postgres_ok = True

    gsm_ok = False
    try:
        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine
        await gsm_actor.get_skill_list.remote()
    except Exception:
        pass
    else:
        gsm_ok = True

    checks = {"postgres": postgres_ok, "global_state_machine": gsm_ok}
    if all(checks.values()):
        status_code, status_text = 200, "ready"
    else:
        status_code, status_text = 503, "not_ready"
    return JSONResponse(
        status_code=status_code,
        content={"status": status_text, "checks": checks},
    )
+2 -1
View File
@@ -13,5 +13,6 @@
# limitations under the License.
from .frontend import client_router
from .onebot import onebot_router
__all__ = ["client_router"]
__all__ = ["client_router", "onebot_router"]
+13 -3
View File
@@ -16,6 +16,10 @@ from fastapi import APIRouter, Depends, HTTPException, UploadFile, File
from pydantic import BaseModel
from kilostar.utils.access import Accessor, TokenData
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.core.individual.regulatory_node.template import (
MessageRequest,
MessageResponse,
)
import os
import anyio
from kilostar.utils.logger import get_logger
@@ -39,12 +43,18 @@ async def create_message(
logger.info("收到消息,来源:客户端")
logger.debug(f"消息内容:{message.message}")
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
reply = await regulatory_node.handle_client_message.remote(
user_id=token_data.user_id,
msg_request = MessageRequest(
platform="client",
user_name=token_data.username,
platform_id=token_data.user_id,
message=message.message,
)
return {"message": reply}
result = await regulatory_node.working.remote(msg_request)
if isinstance(result, MessageResponse):
return {"message": result.reply_message}
if isinstance(result, str):
return {"message": result}
return {"message": ""}
@client_router.post("/upload")
+279
View File
@@ -0,0 +1,279 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""OneBot v11 协议适配器。
接收来自 OneBot 实现端(NapCat / go-cqhttp / Lagrange.OneBot 等)的事件上报,
把消息事件翻译成 ``MessageRequest`` 投递给 RegulatoryNode。同时支持两种连接方式:
- HTTP 上报(POST ``/api/v1/adapter/onebot/event``):实现端把事件 POST 过来,
通过返回体里的 ``reply`` 走 v11 "快速操作" 自动回包。
- 反向 WebSocket(WS ``/api/v1/adapter/onebot/ws``):实现端主动建立长连接,
服务端按 OneBot v11 反向 WS 规范返回 ``send_msg`` 等 action 主动回包。
模块还提供 ``send_message`` 工具函数,用 OneBot v11 HTTP API 主动给指定会话发消息。
"""
import asyncio
import json
import os
import uuid
from typing import Any, Dict, Optional
import httpx
from fastapi import APIRouter, Header, HTTPException, Request, WebSocket, WebSocketDisconnect
from kilostar.core.individual.regulatory_node.template import (
MessageRequest,
MessageResponse,
)
from kilostar.utils.logger import get_logger
from kilostar.utils.ray_hook import ray_actor_hook
logger = get_logger("onebot")
onebot_router = APIRouter(prefix="/api/v1/adapter/onebot", tags=["onebot"])
def _verify_token(token_from_header: Optional[str]) -> None:
    """Validate the access token sent in the ``Authorization`` header.

    Verification is skipped entirely when the ``ONEBOT_ACCESS_TOKEN``
    environment variable is unset or empty. A leading ``Bearer `` or
    ``Token `` prefix is stripped, so a bare token string is also accepted.

    Args:
        token_from_header: Raw ``Authorization`` header value, or ``None``.

    Raises:
        HTTPException: 401 when a token is required but missing or wrong.
    """
    import hmac  # function-scope import: not imported at file level

    expected = os.environ.get("ONEBOT_ACCESS_TOKEN")
    if not expected:
        return
    if not token_from_header:
        raise HTTPException(status_code=401, detail="Missing access_token")

    raw = token_from_header.removeprefix("Bearer ").removeprefix("Token ").strip()
    # Constant-time comparison so response timing does not reveal how much of
    # the token matched. Compare bytes: compare_digest rejects non-ASCII str.
    supplied_bytes = raw.encode("utf-8", "surrogateescape")
    expected_bytes = expected.encode("utf-8", "surrogateescape")
    if not hmac.compare_digest(supplied_bytes, expected_bytes):
        raise HTTPException(status_code=401, detail="Invalid access_token")
def _extract_plain_text(message: Any) -> str:
    """Flatten a OneBot ``message`` field into plain text.

    A string is returned unchanged. For a list of segments, only the
    ``text`` segments contribute; every other segment type is dropped.
    Malformed segments (``data`` that is not a dict, or a ``text`` value
    that is not a string) are skipped instead of raising, because the
    payload comes from an external OneBot implementation.

    Args:
        message: The ``message`` value of an event: a string, a list of
            segment dicts, or anything else.

    Returns:
        The concatenated text, or ``""`` for unsupported input.
    """
    if isinstance(message, str):
        return message
    if not isinstance(message, list):
        return ""

    parts = []
    for seg in message:
        if not isinstance(seg, dict) or seg.get("type") != "text":
            continue
        data = seg.get("data")
        # ``data`` may be present but null, so a ``.get`` default is not enough.
        text = data.get("text") if isinstance(data, dict) else None
        if isinstance(text, str):
            parts.append(text)
    return "".join(parts)
async def _dispatch_event(payload: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Hand one OneBot event to the regulatory node and build a quick reply.

    Only events with ``post_type == "message"`` are processed; everything
    else returns ``None``. A message with no text content, a failed node
    call, or an empty reply also returns ``None``.

    Args:
        payload: The decoded OneBot event.

    Returns:
        ``None`` when nothing should be sent back, otherwise a dict with:

        - ``reply``: the reply text;
        - ``_target``: ``message_type`` / ``user_id`` / ``group_id`` of the
          originating conversation, for the WebSocket path (the HTTP path
          removes it before responding);
        - ``at_sender``: ``False``, present only for group messages.
    """
    if payload.get("post_type") != "message":
        return None

    message_type = payload.get("message_type")  # private | group
    user_id = str(payload.get("user_id", ""))
    group_id = payload.get("group_id")
    raw_text = _extract_plain_text(payload.get("message", ""))

    sender = payload.get("sender") or {}
    user_name = (
        sender.get("card") or sender.get("nickname") or user_id or "onebot_user"
    )
    if message_type == "group":
        platform_id = f"group:{group_id}"
    else:
        platform_id = f"private:{user_id}"

    if not raw_text.strip():
        return None

    logger.info(
        f"[OneBot] {message_type} 消息 from {user_name}({user_id}) -> {raw_text!r}"
    )

    msg_request = MessageRequest(
        platform="onebot",
        user_name=user_name,
        platform_id=platform_id,
        message=raw_text,
    )

    try:
        regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
        result = await regulatory_node.working.remote(msg_request)
    except Exception as e:
        logger.exception(f"[OneBot] RegulatoryNode 调用失败: {e}")
        return None

    if isinstance(result, MessageResponse):
        reply_text = result.reply_message or ""
    elif isinstance(result, str):
        reply_text = result
    else:
        reply_text = ""
    if not reply_text:
        return None

    # ``isdecimal`` rather than ``isdigit``: ``isdigit`` is also true for
    # characters such as "²", which ``int()`` rejects with ValueError.
    target_user_id = int(user_id) if user_id.isdecimal() else user_id
    quick: Dict[str, Any] = {
        "reply": reply_text,
        "_target": {
            "message_type": message_type,
            "user_id": target_user_id,
            "group_id": group_id,
        },
    }
    if message_type == "group":
        quick["at_sender"] = False
    return quick
@onebot_router.post("/event")
async def receive_event(
    request: Request,
    authorization: Optional[str] = Header(None),
):
    """HTTP report endpoint: accept a OneBot v11 event and dispatch it.

    When the regulatory node produces a reply, it is returned in the
    response body in OneBot v11 "quick operation" form, with the internal
    ``_target`` key removed. Otherwise ``{"status": "ok"}`` is returned.

    Raises:
        HTTPException: 401 when token verification fails; 400 when the body
            is not valid JSON or is not a JSON object.
    """
    _verify_token(authorization)

    try:
        payload = await request.json()
    except ValueError as exc:
        # json.JSONDecodeError and UnicodeDecodeError both derive from
        # ValueError; report a client error instead of an unhandled 500.
        raise HTTPException(status_code=400, detail="Invalid JSON body") from exc
    if not isinstance(payload, dict):
        # _dispatch_event calls ``payload.get``; lists and scalars would crash.
        raise HTTPException(
            status_code=400, detail="Event payload must be a JSON object"
        )

    quick = await _dispatch_event(payload)
    if not quick:
        return {"status": "ok"}
    quick.pop("_target", None)
    return quick
async def _ws_call_action(
    ws: WebSocket, action: str, params: Dict[str, Any]
) -> None:
    """Send one action frame over the reverse WebSocket without awaiting a reply.

    The frame is a JSON object holding ``action``, ``params`` and a freshly
    generated ``echo`` value (a UUID4 hex string). Non-ASCII characters are
    written unescaped.
    """
    frame_text = json.dumps(
        {"action": action, "params": params, "echo": uuid.uuid4().hex},
        ensure_ascii=False,
    )
    await ws.send_text(frame_text)
@onebot_router.websocket("/ws")
async def reverse_websocket(
    websocket: WebSocket,
    authorization: Optional[str] = Header(None),
    x_self_id: Optional[str] = Header(None),
):
    """Reverse WebSocket endpoint for connections opened by a OneBot client.

    The ``Authorization`` header is verified before the connection is
    accepted; on failure the socket is closed with code 4401. After that,
    JSON frames are read in a loop:

    - frames that are not valid JSON, or not JSON objects, are logged and
      skipped;
    - frames with ``echo`` but no ``post_type`` are action responses and are
      currently ignored;
    - all other frames go to ``_dispatch_event``. A resulting reply is sent
      back with a ``send_group_msg`` or ``send_private_msg`` action.
    """
    try:
        _verify_token(authorization)
    except HTTPException:
        await websocket.close(code=4401)
        return

    await websocket.accept()
    logger.info(f"[OneBot] reverse WS connected (self_id={x_self_id})")

    # The event loop keeps only weak references to tasks, so hold strong
    # references here until each send finishes; otherwise a fire-and-forget
    # task can be garbage-collected before it runs to completion.
    pending_sends: set[asyncio.Task] = set()

    try:
        while True:
            text = await websocket.receive_text()
            try:
                payload = json.loads(text)
            except json.JSONDecodeError:
                logger.warning(f"[OneBot] invalid JSON frame: {text[:200]}")
                continue

            if not isinstance(payload, dict):
                # Valid JSON that is not an object (list, number, ...) would
                # raise on the membership / ``.get`` calls below and tear
                # down the whole connection.
                logger.warning(f"[OneBot] non-object JSON frame: {text[:200]}")
                continue

            # Action response frame (has echo, no post_type): ignored for now.
            if "post_type" not in payload and "echo" in payload:
                continue

            quick = await _dispatch_event(payload)
            if not quick:
                continue

            target = quick.get("_target", {})
            params: Dict[str, Any] = {"message": quick["reply"]}
            if target.get("message_type") == "group" and target.get("group_id"):
                params["group_id"] = target["group_id"]
                action = "send_group_msg"
            else:
                params["user_id"] = target.get("user_id")
                action = "send_private_msg"

            send_task = asyncio.create_task(
                _ws_call_action(websocket, action, params)
            )
            pending_sends.add(send_task)
            send_task.add_done_callback(pending_sends.discard)
    except WebSocketDisconnect:
        logger.info("[OneBot] reverse WS disconnected")
    except Exception as e:
        logger.exception(f"[OneBot] reverse WS error: {e}")
        try:
            await websocket.close()
        except Exception:
            # The socket may already be closed; nothing more to do.
            pass
async def send_message(
    user_id: Optional[int] = None,
    group_id: Optional[int] = None,
    message: str = "",
    *,
    base_url: Optional[str] = None,
    access_token: Optional[str] = None,
) -> Dict[str, Any]:
    """Send a message to a private chat or a group via the OneBot v11 HTTP API.

    When both ``user_id`` and ``group_id`` are given, the group takes
    precedence.

    Args:
        user_id: Target user ID; used when ``group_id`` is not given.
        group_id: Target group ID.
        message: Text of the message to send.
        base_url: Base URL of the OneBot HTTP API; falls back to the
            ``ONEBOT_HTTP_URL`` environment variable, then to
            ``http://127.0.0.1:5700``.
        access_token: Bearer token; falls back to ``ONEBOT_ACCESS_TOKEN``.

    Returns:
        The decoded JSON body of the HTTP API response.

    Raises:
        ValueError: If neither ``user_id`` nor ``group_id`` is given.
    """
    if not (user_id or group_id):
        raise ValueError("必须指定 user_id 或 group_id 之一")

    api_base = base_url or os.environ.get("ONEBOT_HTTP_URL", "http://127.0.0.1:5700")
    bearer = access_token or os.environ.get("ONEBOT_ACCESS_TOKEN")

    if group_id:
        endpoint = "send_group_msg"
        request_body = {"group_id": int(group_id), "message": message}
    else:
        endpoint = "send_private_msg"
        request_body = {"user_id": int(user_id), "message": message}

    request_headers: Dict[str, str] = (
        {"Authorization": f"Bearer {bearer}"} if bearer else {}
    )
    target_url = f"{api_base.rstrip('/')}/{endpoint}"

    async with httpx.AsyncClient(timeout=15.0) as client:
        response = await client.post(
            target_url, json=request_body, headers=request_headers
        )
        response.raise_for_status()
        return response.json()
+15 -5
View File
@@ -14,11 +14,10 @@
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from typing import Literal
from typing import Any, Dict, Literal
from kilostar.utils.access import TokenData, Accessor
from kilostar.utils.check_user.role_check import RoleChecker
from kilostar.core.postgres_database.model import UserAuthority
from typing import Dict
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.ray_hook import ray_actor_hook
@@ -50,16 +49,27 @@ async def create_provider(
)
def _mask_apikey(key: str) -> str:
    """Mask an API key for display.

    Keys that are empty or at most 8 characters long are replaced by
    ``"***"``. Longer keys keep their first and last four characters around
    ``"***"``.
    """
    if key and len(key) > 8:
        head, tail = key[:4], key[-4:]
        return f"{head}***{tail}"
    return "***"
@provider_router.get("/list")
async def get_provider_list(
_: TokenData = Depends(Accessor.get_current_user),
) -> Dict[str, Dict[str, Provider]]:
"""返回当前所有已注册的 Provider,前端用以展示模型清单。"""
) -> Dict[str, Any]:
"""返回当前所有已注册的 Provider,前端用以展示模型清单。apikey 脱敏。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
provider_list: Dict[
str, Provider
] = await global_state_machine.get_provider_list.remote()
return {"provider_list": provider_list}
masked = {}
for title, p in provider_list.items():
d = p.model_dump() if hasattr(p, "model_dump") else dict(p)
d["provider_apikey"] = _mask_apikey(d.get("provider_apikey", ""))
masked[title] = d
return {"provider_list": masked}
@provider_router.delete("/{provider_title}")
+274 -5
View File
@@ -12,13 +12,15 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from pydantic import BaseModel
import viceroy
from kilostar.utils.ray_hook import ray_actor_hook
from fastapi import APIRouter, Depends
from fastapi import APIRouter, Depends, HTTPException
from kilostar.utils.access import TokenData
from kilostar.utils.check_user.role_check import RoleChecker
from kilostar.core.postgres_database.model import UserAuthority
from kilostar.utils.mcp_helper import list_mcp_tools_from_gsm
resource_router = APIRouter(prefix="/api/v1/resource")
@@ -30,13 +32,24 @@ class Skill(BaseModel):
path: str | None
class MCPServerConfig(BaseModel):
    """Request body for ``POST /mcp``: the configuration of one MCP server."""

    # Server name supplied by the caller (the only required field).
    name: str
    transport: str = "stdio"  # stdio | sse | http
    # presumably the executable and its arguments for a stdio server — TODO confirm
    command: str | None = None
    args: list[str] | None = None
    # presumably the endpoint used by the sse/http transports — TODO confirm
    url: str | None = None
    tool_prefix: str | None = None
    # Environment variables for the server; values under sensitive-looking
    # keys are masked when servers are listed via ``GET /mcp``.
    env: Dict[str, str] | None = None
@resource_router.post("/skill")
async def install_skill(
skill: Skill, _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER))
):
"""通过 viceroy 把 skill 仓库克隆到 ``plugin/skill``,并在状态机中登记。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# noinspection PyUnresolvedReferences
import os
skill_output_dir = os.path.abspath(
@@ -73,19 +86,275 @@ async def delete_skill(
):
"""从状态机中移除 skill 注册项;不会删除磁盘上的代码文件。"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
# Note: this only removes it from the state machine manager.
await global_state_machine.remove_skill.remote(skill_name)
return {"message": "success"}
# ─── MCP Server Management ───
@resource_router.post("/mcp")
async def add_mcp_server(
    config: MCPServerConfig,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Register an MCP server configuration with the global state machine.

    The server ID is the first 8 characters of a random UUID4. Fields left
    as ``None`` are omitted from the stored configuration.

    Returns:
        A dict with the new ``server_id`` and a confirmation message.
    """
    import uuid

    gsm = ray_actor_hook("global_state_machine").global_state_machine
    new_server_id = str(uuid.uuid4())[:8]
    await gsm.add_mcp_server.remote(
        new_server_id, config.model_dump(exclude_none=True)
    )
    return {"server_id": new_server_id, "message": "MCP server registered"}
@resource_router.get("/mcp")
async def list_mcp_servers(
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Return every registered MCP server configuration.

    When a server's ``env`` is a dict, it is replaced by a copy whose
    sensitive-looking values are masked.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    servers = await gsm.list_mcp_servers.remote()
    for server in servers:
        if "env" not in server:
            continue
        env = server["env"]
        if isinstance(env, dict):
            server["env"] = _mask_config(env)
    return {"servers": servers}
@resource_router.delete("/mcp/{server_id}")
async def delete_mcp_server(
    server_id: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Remove one MCP server configuration from the global state machine.

    Raises:
        HTTPException: 404 when the state machine reports nothing was removed.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    removed = await gsm.delete_mcp_server.remote(server_id)
    if removed:
        return {"message": "success"}
    raise HTTPException(status_code=404, detail="MCP server not found")
# ─── Tool Management ───
@resource_router.get("/tool")
async def get_tools(
_: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
"""汇总各作用域 tool_mapper,返回去重后的工具名称列表。"""
"""返回按分类聚合的工具信息(包含系统工具、搜索工具、MCP 工具等)。
其中 ``mcp_servers`` 会现场尝试连接每个已注册的 MCP 服务器并列出它们暴露的
工具名,便于前端展示;任意一台 MCP server 不可达不影响其他工具的返回。
"""
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
tool_mapper = await global_state_machine.get_tool_mapper.remote()
categories = await global_state_machine.get_tool_categories.remote()
all_tool_names = set()
for scope_tools in tool_mapper.values():
all_tool_names.update(scope_tools.keys())
return {"tools": list(all_tool_names)}
mcp_servers = await list_mcp_tools_from_gsm()
return {
"tools": list(all_tool_names),
"categories": categories,
"mcp_servers": mcp_servers,
}
# ─── Tool Config Management (runtime configuration such as the Tavily API key) ───
def _mask_secret(value: Any) -> Any:
    """Mask a secret string such as an API key, token or password.

    Non-string and empty values are returned unchanged. Strings of at most
    8 characters become ``"***"``. Longer strings keep their first and last
    four characters around ``"***"``.
    """
    if not (isinstance(value, str) and value):
        return value
    if len(value) > 8:
        return f"{value[:4]}***{value[-4:]}"
    return "***"
def _mask_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Return a copy of ``config`` with sensitive-looking values masked.

    A key is treated as sensitive when its lower-cased name contains
    ``key``, ``token``, ``secret`` or ``password``. Those values go through
    ``_mask_secret``; all other values are copied as they are.
    """
    sensitive_markers = ("key", "token", "secret", "password")

    def _is_sensitive(field_name: str) -> bool:
        lowered = field_name.lower()
        return any(marker in lowered for marker in sensitive_markers)

    return {
        field: (_mask_secret(value) if _is_sensitive(field) else value)
        for field, value in config.items()
    }
class ToolConfigUpdate(BaseModel):
    """Request body for ``PUT /tool/config/{tool_name}``: the config to store."""

    # Key/value pairs written as the tool's runtime configuration.
    config: Dict[str, Any]
@resource_router.get("/tool/config")
async def list_tool_configs(
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """List the runtime configuration of every tool, with secrets masked."""
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    stored_configs = await gsm.list_tool_configs.remote()
    masked_configs = {}
    for tool_name, tool_config in stored_configs.items():
        masked_configs[tool_name] = _mask_config(tool_config)
    return {"configs": masked_configs}
@resource_router.get("/tool/config/{tool_name}")
async def get_tool_config(
    tool_name: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Return the masked runtime configuration of one tool.

    Raises:
        HTTPException: 404 when the state machine returns ``None`` for the
            tool, using the same detail text as ``DELETE /tool/config/{tool_name}``.
    """
    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    raw = await global_state_machine.get_tool_config.remote(tool_name)
    if raw is None:
        # NOTE(review): whether the state machine returns None or {} for an
        # unknown tool is not visible here. None would crash in _mask_config
        # (AttributeError -> 500), so report it as "not found". An empty
        # dict still passes through unchanged.
        raise HTTPException(status_code=404, detail="Tool config not found")
    return {"tool_name": tool_name, "config": _mask_config(raw)}
@resource_router.put("/tool/config/{tool_name}")
async def set_tool_config(
    tool_name: str,
    body: ToolConfigUpdate,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Write or overwrite a tool's runtime configuration.

    An example is the ``api_key`` of ``tavily_search``.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    new_config = body.config
    await gsm.set_tool_config.remote(tool_name, new_config)
    return {"message": "success"}
@resource_router.delete("/tool/config/{tool_name}")
async def delete_tool_config(
    tool_name: str,
    _: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.SUPER_ADMINISTRATOR)),
):
    """Delete a tool's runtime configuration.

    Raises:
        HTTPException: 404 when the state machine reports nothing was removed.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    removed = await gsm.delete_tool_config.remote(tool_name)
    if removed:
        return {"message": "success"}
    raise HTTPException(status_code=404, detail="Tool config not found")
# ─── Custom Toolset Management ───
class CustomToolsetCreate(BaseModel):
    # Request body for ``POST /custom-toolset``.
    # Display name of the new toolset.
    name: str
    # Names of the tools grouped into this toolset.
    tools: List[str]
    # Optional free-text description.
    description: Optional[str] = None
class CustomToolsetUpdate(BaseModel):
    # Request body for ``PUT /custom-toolset/{toolset_id}``.
    # Every field is optional: a field left as ``None`` keeps the stored
    # value when the update endpoint merges it with the existing toolset.
    name: Optional[str] = None
    tools: Optional[List[str]] = None
    description: Optional[str] = None
async def _assert_toolset_owner_or_admin(
    toolset: Dict[str, Any], token_data: TokenData
) -> None:
    """Allow access to a toolset only for its owner or an administrator.

    The authority lookup is skipped when the caller owns the toolset.

    Raises:
        HTTPException: 403 when the caller is neither the owner nor at least
            ``UserAuthority.ADMINISTRATOR``.
    """
    from kilostar.utils.check_user.role_check import get_authority

    caller_id = token_data.user_id
    is_owner = toolset.get("owner_id") == caller_id
    if not is_owner:
        caller_authority = await get_authority(caller_id)
        if not (caller_authority >= UserAuthority.ADMINISTRATOR):
            raise HTTPException(status_code=403, detail="无权访问此自定义工具组")
@resource_router.post("/custom-toolset")
async def create_custom_toolset(
    body: CustomToolsetCreate,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Create a custom toolset owned by the calling user.

    The toolset ID is the first 8 characters of a random UUID4.

    Returns:
        The new ``toolset_id`` and the toolset as saved by the state machine.

    Raises:
        HTTPException: 400 when the state machine rejects the toolset with a
            ``ValueError``.
    """
    import uuid

    gsm = ray_actor_hook("global_state_machine").global_state_machine
    new_id = str(uuid.uuid4())[:8]
    creation_kwargs = {
        "toolset_id": new_id,
        "name": body.name,
        "tools": body.tools,
        "description": body.description,
        "owner_id": token_data.user_id,
    }
    try:
        stored = await gsm.add_custom_toolset.remote(**creation_kwargs)
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"toolset_id": new_id, "toolset": stored}
@resource_router.get("/custom-toolset")
async def list_custom_toolsets(
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """List custom toolsets visible to the caller.

    Callers below ``UserAuthority.ADMINISTRATOR`` see only the toolsets they
    own; administrators and above see all of them.
    """
    from kilostar.utils.check_user.role_check import get_authority

    gsm = ray_actor_hook("global_state_machine").global_state_machine
    visible = await gsm.list_custom_toolsets.remote()
    caller_authority = await get_authority(token_data.user_id)
    if caller_authority < UserAuthority.ADMINISTRATOR:
        owned = []
        for toolset in visible:
            if toolset.get("owner_id") == token_data.user_id:
                owned.append(toolset)
        visible = owned
    return {"toolsets": visible}
@resource_router.get("/custom-toolset/{toolset_id}")
async def get_custom_toolset(
    toolset_id: str,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Return one custom toolset after checking the caller may access it.

    Raises:
        HTTPException: 404 when the id is unknown; 403 (from the ownership
            check) when the caller is neither owner nor administrator.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    record = await gsm.get_custom_toolset.remote(toolset_id)
    if record:
        await _assert_toolset_owner_or_admin(record, token_data)
        return record
    raise HTTPException(status_code=404, detail="Custom toolset not found")
@resource_router.put("/custom-toolset/{toolset_id}")
async def update_custom_toolset(
    toolset_id: str,
    body: CustomToolsetUpdate,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Partially update a custom toolset.

    Fields left as ``None`` in ``body`` keep their stored value. The stored
    ``owner_id`` is kept; the caller's id is only used when the record has no
    ``owner_id`` key.

    Returns:
        ``{"toolset": ...}`` with the record returned by the state machine.

    Raises:
        HTTPException: 404 when the id is unknown; 403 from the ownership
            check; 400 when the state machine raises ``ValueError``.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    current = await gsm.get_custom_toolset.remote(toolset_id)
    if not current:
        raise HTTPException(status_code=404, detail="Custom toolset not found")
    await _assert_toolset_owner_or_admin(current, token_data)

    # Merge the request over the stored record, field by field.
    merged = {
        "name": current["name"] if body.name is None else body.name,
        "tools": current["tools"] if body.tools is None else body.tools,
        "description": (
            current.get("description")
            if body.description is None
            else body.description
        ),
    }
    try:
        saved = await gsm.add_custom_toolset.remote(
            toolset_id=toolset_id,
            owner_id=current.get("owner_id", token_data.user_id),
            **merged,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    return {"toolset": saved}
@resource_router.delete("/custom-toolset/{toolset_id}")
async def delete_custom_toolset(
    toolset_id: str,
    token_data: TokenData = Depends(RoleChecker(allowed_roles=UserAuthority.USER)),
):
    """Delete a custom toolset the caller owns (administrators: any toolset).

    Raises:
        HTTPException: 404 when the id is unknown or the state machine
            reports that nothing was deleted; 403 from the ownership check.
    """
    gsm = ray_actor_hook("global_state_machine").global_state_machine
    not_found = HTTPException(status_code=404, detail="Custom toolset not found")

    target = await gsm.get_custom_toolset.remote(toolset_id)
    if not target:
        raise not_found
    await _assert_toolset_owner_or_admin(target, token_data)

    deleted = await gsm.delete_custom_toolset.remote(toolset_id)
    if not deleted:
        raise not_found
    return {"message": "success"}
+86
View File
@@ -119,3 +119,89 @@ async def get_workflow_detail(
"steps": steps,
"context_blackboard": context.blackboard if context else {},
}
@workflow_router.post("/{trace_id}/resume")
async def resume_workflow(
    trace_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Resume an interrupted or suspended workflow from its persisted graph state.

    A new Ray task is fired for the trace; per the original author's note the
    task's ``hydrate`` check is expected to take the resume path and run the
    remaining nodes.

    Returns:
        ``{"trace_id": trace_id, "status": "resuming"}``.

    Raises:
        HTTPException: 404 when the workflow does not exist, 403 when it
            belongs to another user, 409 when no graph state was persisted.
    """
    db = ray_actor_hook("postgres_database").postgres_database

    workflow = await db.get_workflow.remote(trace_id)
    if not workflow:
        raise HTTPException(status_code=404, detail="Workflow not found")
    owner_id = getattr(workflow, "user_id", None)
    if owner_id != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    graph_state = await db.get_workflow_graph_state.remote(trace_id)
    if graph_state is None:
        raise HTTPException(
            status_code=409, detail="该工作流没有可恢复的图持久化记录"
        )

    manager = ray_actor_hook("global_workflow_manager").global_workflow_manager
    await manager.create_trace.remote(trace_id)

    from kilostar.core.work.workflow.workflow_engine import run_workflow_task

    # The workflow payload is not used on the resume path (per the original
    # note), so an empty dict is passed as a placeholder.
    placeholder_payload = {}
    run_workflow_task.remote(placeholder_payload, trace_id)
    return {"trace_id": trace_id, "status": "resuming"}
def _successful_node_classes(history) -> list[str]:
    """Return the de-duplicated node class names that finished successfully.

    Only dict entries with ``kind == "node"`` and ``status == "success"`` are
    considered. The class name is the part of the entry ``id`` before the
    first ``":"``; first-seen order is preserved. Entries whose ``id`` is
    missing, empty or not a string are ignored.
    """
    ordered: dict[str, None] = {}
    for entry in history or []:
        if not isinstance(entry, dict):
            continue
        if entry.get("kind") != "node" or entry.get("status") != "success":
            continue
        snapshot_id = entry.get("id")
        # A malformed (non-string) id must not crash the whole endpoint.
        if not isinstance(snapshot_id, str) or not snapshot_id:
            continue
        cls_name = snapshot_id.split(":", 1)[0]
        if cls_name:
            ordered.setdefault(cls_name, None)
    return list(ordered)


@workflow_router.get("/{trace_id}/graph")
async def get_workflow_graph_mermaid(
    trace_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Return the mermaid source of the workflow graph for one trace.

    When a persisted graph state exists for the trace, the nodes recorded as
    successful in its history are passed to mermaid as ``highlighted_nodes``.

    Returns:
        ``{"trace_id": ..., "mermaid": ..., "visited": [...]}``.

    Raises:
        HTTPException: 404 when the workflow does not exist, 403 when it
            belongs to another user, 500 when mermaid generation fails.
    """
    from kilostar.core.work.workflow.workflow_engine import workflow_graph

    postgres_database = ray_actor_hook("postgres_database").postgres_database
    wf = await postgres_database.get_workflow.remote(trace_id)
    if not wf:
        raise HTTPException(status_code=404, detail="Workflow not found")
    if getattr(wf, "user_id", None) != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    visited: list[str] = []
    record = await postgres_database.get_workflow_graph_state.remote(trace_id)
    if record is not None:
        visited = _successful_node_classes(getattr(record, "history", None))

    try:
        if visited:
            mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited)
        else:
            mermaid = workflow_graph.mermaid_code()
    except Exception as e:
        # Chain the cause so the real failure stays visible in server logs.
        raise HTTPException(status_code=500, detail=f"mermaid 生成失败: {e}") from e
    return {"trace_id": trace_id, "mermaid": mermaid, "visited": visited}
@@ -13,47 +13,129 @@
# limitations under the License.
import ray
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.postgres_database import PostgresDatabase
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from typing import Any, Dict, List, Optional, Tuple
from kilostar.core.global_state_machine.individual_manager import (
GlobalIndividualManager,
)
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
from kilostar.core.postgres_database import PostgresDatabase
@ray.remote
class GlobalStateMachine:
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual 四个注册表。
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual/MCP/CustomToolset 注册表。
其它 Actor 通过 ``ray.get_actor("global_state_machine")`` 拿到本实例,
再调用本类暴露的方法来读写各注册表,避免每个 Actor 各自维护一份状态
所有持久化都走 PostgresDatabase;启动时由 ``init_state_machine`` 一次性把
Provider / MCP / Tool config / Custom toolset 拉到内存,保证后续读操作零等待
"""
def __init__(self, postgres_database: PostgresDatabase):
    """Create the managers and empty in-memory registries.

    Args:
        postgres_database: Database handle; stored on the instance and also
            passed to ``ProviderManager``.
    """
    import sys
    # Startup progress traces written to stderr.
    print("GSM __init__ START", file=sys.stderr, flush=True)
    print(" event_dict done", file=sys.stderr, flush=True)
    self._global_provider_manager = ProviderManager(postgres_database)
    print(" provider_manager done", file=sys.stderr, flush=True)
    self._global_tool_manager = GlobalToolManager()
    print(" tool_manager done", file=sys.stderr, flush=True)
    self._global_skill_manager = GlobalSkillManager()
    print(" skill_manager done", file=sys.stderr, flush=True)
    self._global_individual_manager = GlobalIndividualManager()
    print(" individual_manager done", file=sys.stderr, flush=True)
    # In-memory registries (loaded from the DB by init_state_machine at startup).
    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
    self._tool_configs: Dict[str, Dict[str, Any]] = {}
    self._custom_toolsets: Dict[str, Dict[str, Any]] = {}
    # Config snapshot and version: each write bumps the version and ray.put()s
    # a new snapshot. Readers take the ref from current_config_ref and read it
    # themselves, which the author intends as a way around the actor's
    # single-threaded bottleneck.
    self._config_version: int = 0
    self._current_ref: Optional[ray.ObjectRef] = None
    self.postgres_database = postgres_database
    print("GSM __init__ DONE", file=sys.stderr, flush=True)
async def init_state_machine(self):
"""从数据库加载 Provider/Individual 注册表到内存。"""
"""启动期一次性把 Provider/Individual/MCP/ToolConfig/CustomToolset 拉到内存。"""
await self._global_provider_manager.init_provider_register(
self.postgres_database
)
await self._global_individual_manager.init_individual_register(
self.postgres_database
)
# MCP servers
rows = await self.postgres_database.list_mcp_servers_db.remote()
self._mcp_servers = {row["server_id"]: row for row in rows}
# Tool configs
cfg_rows = await self.postgres_database.list_tool_configs_db.remote()
self._tool_configs = {row["tool_name"]: row["config"] for row in cfg_rows}
# Custom toolsets
ts_rows = await self.postgres_database.list_custom_toolsets.remote()
self._custom_toolsets = {row["toolset_id"]: row for row in ts_rows}
# 让 tool_manager 立刻把 custom toolset 装配成 FunctionToolset
self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
# 启动期一次性发布 v1 快照,让等待中的读端立刻可用
self._publish_snapshot()
# ─── Snapshot 发布(Object Store 读路径) ────────────────────
def _build_snapshot(self) -> GSMSnapshot:
    """Pack the current in-memory state into a ``GSMSnapshot``.

    ``tool_funcs`` is flattened to ``{tool_name: callable}``; when the same
    name appears under several scopes, the scope iterated last wins.
    ``system_tools_by_scope`` keeps the per-scope tool-name lists so readers
    can rebuild scope-specific toolsets on their side.
    The caller is responsible for the in-memory state being up to date.
    """
    tool_manager = self._global_tool_manager
    funcs_by_scope = tool_manager._tool_funcs

    names_by_scope: Dict[str, List[str]] = {
        scope: list(funcs) for scope, funcs in funcs_by_scope.items()
    }
    merged_funcs: Dict[str, Any] = {}
    for funcs in funcs_by_scope.values():
        merged_funcs.update(funcs)

    mapper_copy = {}
    for scope, name_to_cls in tool_manager.tool_mapper.items():
        mapper_copy[scope] = dict(name_to_cls)

    return GSMSnapshot(
        version=self._config_version,
        providers=dict(self._global_provider_manager.provider_register),
        individuals=dict(self._global_individual_manager._individuals),
        mcp_servers=dict(self._mcp_servers),
        tool_configs=dict(self._tool_configs),
        custom_toolsets=dict(self._custom_toolsets),
        skills=dict(self._global_skill_manager.skill_mapper),
        tool_metadata=dict(tool_manager.tool_metadata),
        tool_funcs=merged_funcs,
        third_party_funcs=dict(tool_manager._third_party_funcs),
        tool_mapper=mapper_copy,
        system_tools_by_scope=names_by_scope,
    )
def _publish_snapshot(self) -> None:
    """Bump the config version and put a fresh snapshot into the Ray object store.

    The version is incremented first so the snapshot carries the new number;
    the resulting ``ObjectRef`` replaces the previously published one.
    """
    self._config_version = self._config_version + 1
    fresh_snapshot = self._build_snapshot()
    self._current_ref = ray.put(fresh_snapshot)
async def current_config_ref(self) -> Tuple[int, ray.ObjectRef]:
    """Return ``(version, ObjectRef)`` of the currently published snapshot.

    The ref, not the snapshot object, is returned on purpose so callers
    resolve it themselves. A snapshot is published first when none exists.
    """
    published = self._current_ref
    if published is None:
        self._publish_snapshot()
        published = self._current_ref
    return self._config_version, published
async def current_version(self) -> int:
    """Return only the current config version.

    Lightweight variant for readers that just need to check whether their
    locally cached snapshot is still current.
    """
    return self._config_version
# ─── Provider ──────────────────────────────────────────────
async def add_provider_wrap(
self,
@@ -64,7 +146,7 @@ class GlobalStateMachine:
provider_owner,
):
"""新增一个模型 Provider:内存注册 + 数据库持久化一并完成。"""
return await self._global_provider_manager.add_provider(
result = await self._global_provider_manager.add_provider(
provider_type=provider_type,
provider_title=provider_title,
provider_url=provider_url,
@@ -72,8 +154,9 @@ class GlobalStateMachine:
provider_owner=provider_owner,
postgres_database=self.postgres_database,
)
self._publish_snapshot()
return result
# Provider Manager Methods
def get_provider_list(self):
"""返回内存中已登记的全部 Provider。"""
return self._global_provider_manager.get_provider_list()
@@ -84,11 +167,14 @@ class GlobalStateMachine:
async def delete_provider(self, provider_title: str):
"""删除一个 Provider:内存注册 + 数据库持久化一并完成。"""
return await self._global_provider_manager.delete_provider(
result = await self._global_provider_manager.delete_provider(
provider_title, self.postgres_database
)
self._publish_snapshot()
return result
# ─── Tool / Toolset ────────────────────────────────────────
# Tool Manager Methods
def get_tool_mapper(self):
"""返回 agent_name -> {tool_name: callable} 的全量映射。"""
return self._global_tool_manager.tool_mapper
@@ -96,37 +182,152 @@ class GlobalStateMachine:
def get_tool_list(self, agent_name: str):
"""返回某个 agent 可用的工具集(其专属工具与 default 工具的并集)。"""
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
# also include default tools
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools}
return merged_tools
return {**default_tools, **tools}
def get_tool_categories(self):
    """Return tool metadata grouped several ways.

    Returns:
        A dict with keys ``system``, ``third_party``, ``by_category`` (one
        entry each for ``system``, ``search``, ``mcp``, ``other``) and ``all``.
    """
    manager = self._global_tool_manager
    system_tools = manager.get_system_tools()
    third_party_tools = manager.get_third_party_tools()
    grouped = {
        category: manager.get_tools_by_category(category)
        for category in ("system", "search", "mcp", "other")
    }
    return {
        "system": system_tools,
        "third_party": third_party_tools,
        "by_category": grouped,
        "all": manager.get_all_tools(),
    }
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the "system + custom" toolset list for a scope (MCP not included).

    Delegates to the tool manager's method of the same name.
    """
    return self._global_tool_manager.get_toolsets_for_scope(scope)
# ─── MCP Server Registry ───────────────────────────────────
async def add_mcp_server(self, server_id: str, config: Dict[str, Any]) -> bool:
    """Register an MCP server config: write the DB, then memory, then republish.

    Returns:
        Always ``True`` once the write completed.
    """
    database = self.postgres_database
    persisted = await database.upsert_mcp_server.remote(server_id, config)
    # Memory keeps whatever the DB layer returned, not the raw input.
    self._mcp_servers.update({server_id: persisted})
    self._publish_snapshot()
    return True
def get_mcp_server(self, server_id: str) -> Optional[Dict[str, Any]]:
    """Return the in-memory config of one MCP server, or ``None`` if unknown."""
    return self._mcp_servers.get(server_id)
def list_mcp_servers(self) -> List[Dict[str, Any]]:
    """Return every registered MCP server as a dict that includes ``server_id``.

    A ``server_id`` key inside the stored config takes precedence over the
    registry key.
    """
    listing: List[Dict[str, Any]] = []
    for server_id, stored in self._mcp_servers.items():
        entry: Dict[str, Any] = {"server_id": server_id}
        entry.update(stored)
        listing.append(entry)
    return listing
async def delete_mcp_server(self, server_id: str) -> bool:
    """Delete an MCP server from the DB and from memory, then republish.

    The in-memory entry is dropped regardless of what the DB reports.

    Returns:
        Truthiness of the DB delete result.
    """
    db_result = await self.postgres_database.delete_mcp_server_db.remote(server_id)
    if server_id in self._mcp_servers:
        del self._mcp_servers[server_id]
    self._publish_snapshot()
    return True if db_result else False
def get_mcp_server_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of the raw MCP server config registry.

    Intended for nodes that build their toolsets from these configs.
    """
    return dict(self._mcp_servers)
# ─── Tool ConfigTavily API key 等)─────────────────────
async def set_tool_config(self, tool_name: str, config: Dict[str, Any]) -> bool:
    """Replace the runtime config of one tool (DB first, then memory).

    Per the original note, sensitive fields are encrypted inside the DAO.
    Memory stores the ``"config"`` value of the row the DB returns.

    Returns:
        Always ``True`` once the write completed.
    """
    upsert = self.postgres_database.upsert_tool_config
    row = await upsert.remote(tool_name, config)
    stored_config = row["config"]
    self._tool_configs[tool_name] = stored_config
    self._publish_snapshot()
    return True
def get_tool_config(self, tool_name: str) -> Dict[str, Any]:
    """Return a shallow copy of a tool's config, or ``{}`` when none is stored."""
    if tool_name not in self._tool_configs:
        return {}
    return dict(self._tool_configs[tool_name])
def list_tool_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of all stored tool configs.

    Sensitive fields are included; callers must mask them themselves.
    """
    return dict(self._tool_configs)
async def delete_tool_config(self, tool_name: str) -> bool:
    """Delete a tool's config from the DB and from memory, then republish.

    The in-memory entry is dropped regardless of what the DB reports.

    Returns:
        Truthiness of the DB delete result.
    """
    db_result = await self.postgres_database.delete_tool_config_db.remote(tool_name)
    if tool_name in self._tool_configs:
        del self._tool_configs[tool_name]
    self._publish_snapshot()
    return True if db_result else False
# ─── Custom Toolset(用户自定义工具组)──────────────────
async def add_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: List[str],
    description: Optional[str] = None,
    owner_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create or update a custom toolset made of third-party tools only.

    Args:
        toolset_id: Id under which the toolset is stored.
        name: Toolset name.
        tools: Tool names; each must be accepted by the tool manager's
            ``is_third_party_tool``.
        description: Optional description.
        owner_id: Optional owner id.

    Returns:
        The record returned by the DB upsert.

    Raises:
        ValueError: When any tool name is not a third-party tool.
    """
    tool_manager = self._global_tool_manager

    # Validate before touching the DB.
    invalid = []
    for tool_name in tools:
        if not tool_manager.is_third_party_tool(tool_name):
            invalid.append(tool_name)
    if invalid:
        raise ValueError(
            f"自定义工具组只允许包含第三方工具,以下不合法:{invalid}"
        )

    record = await self.postgres_database.upsert_custom_toolset.remote(
        toolset_id=toolset_id,
        name=name,
        tools=list(tools),
        description=description,
        owner_id=owner_id,
    )
    self._custom_toolsets[toolset_id] = record
    # Keep the assembled toolsets and the published snapshot in sync.
    tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return record
def list_custom_toolsets(self) -> List[Dict[str, Any]]:
    """Return all in-memory custom toolset records as a list."""
    return list(self._custom_toolsets.values())
def get_custom_toolset(self, toolset_id: str) -> Optional[Dict[str, Any]]:
    """Return one in-memory custom toolset record, or ``None`` if unknown."""
    return self._custom_toolsets.get(toolset_id)
async def delete_custom_toolset(self, toolset_id: str) -> bool:
    """Delete a custom toolset from the DB and memory, rebuild, and republish.

    The in-memory entry is dropped regardless of what the DB reports.

    Returns:
        Truthiness of the DB delete result.
    """
    db_result = await self.postgres_database.delete_custom_toolset.remote(toolset_id)
    if toolset_id in self._custom_toolsets:
        del self._custom_toolsets[toolset_id]
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return True if db_result else False
# ─── Skill ────────────────────────────────────────────────
# Skill Manager Methods
def add_skill(self, skill_name: str):
"""注册一个新的 Skill 名称到 Skill 注册表。"""
return self._global_skill_manager.add_skill(skill_name)
result = self._global_skill_manager.add_skill(skill_name)
self._publish_snapshot()
return result
def get_skill_list(self):
"""返回全部已注册的 Skill 名称。"""
return self._global_skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
"""从注册表中移除一个 Skill。"""
return self._global_skill_manager.remove_skill(skill_name)
result = self._global_skill_manager.remove_skill(skill_name)
self._publish_snapshot()
return result
# ─── Individual ───────────────────────────────────────────
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
"""把一个 Worker Individual 的运行期配置加入注册表。"""
return self._global_individual_manager.add_individual(agent_id, config)
result = self._global_individual_manager.add_individual(agent_id, config)
self._publish_snapshot()
return result
def get_individual(self, agent_id: str):
"""按 agent_id 取出某个 Worker Individual 的配置。"""
return self._global_individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
"""从注册表中移除一个 Worker Individual。"""
return self._global_individual_manager.remove_individual(agent_id)
result = self._global_individual_manager.remove_individual(agent_id)
self._publish_snapshot()
return result
def list_individuals(self):
"""返回当前注册的全部 Worker Individual 列表。"""
return self._global_individual_manager.list_individuals()
@@ -0,0 +1,184 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GSM 快照对象与客户端拉取工具。
设计动机:把 GSM actor 内存里的"读路径配置"打包成可放进 Ray Object Store
的不可变快照。读端不再走 actor RPC,而是 ``fetch_snapshot()`` 一次拿到全量
当前配置,亚毫秒级共享内存读,绕开单 actor 的吞吐瓶颈。
GSM 仍然是 source of truth + 写入串行化器,但读路径解耦:
- 写入 → GSM 内存更新 → ``ray.put(snapshot)`` 拿到新 ObjectRef → 版本号 +1
- 读取 → ``current_config_ref()`` 拿 (version, ref) → ``ray.get(ref)`` 直读
旧的 ``get_provider / get_individual / ...`` 接口保留不动,是低频路径的兜底;
新代码(特别是 skill task 这种高并发热路径)应优先走 ``fetch_snapshot``。
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
import ray
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.logger import get_logger
_logger = get_logger("gsm_snapshot")
@dataclass
class GSMSnapshot:
    """Snapshot of the GSM configuration; every field must be cloudpickle-friendly.

    ``FunctionToolset`` instances are deliberately left out: per the original
    author's note, their picklability may vary between pydantic-ai versions,
    so readers assemble toolsets themselves from ``tool_funcs`` and
    ``tool_mapper``, which also keeps the snapshot smaller.

    NOTE(review): described as immutable by the author, but the dataclass is
    not declared ``frozen`` — treat it as read-only by convention.
    """

    version: int = 0
    providers: Dict[str, Provider] = field(default_factory=dict)
    individuals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    custom_toolsets: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    skills: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    tool_metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    third_party_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    tool_mapper: Dict[str, Dict[str, type]] = field(default_factory=dict)
    # ``{scope: [tool_name, ...]}``: system tool names kept per scope.
    # Readers rebuild FunctionToolsets in their own process from these names
    # plus ``tool_funcs``, instead of shipping toolset instances in the snapshot.
    system_tools_by_scope: Dict[str, List[str]] = field(default_factory=dict)
# Process-local single-slot cache: the version and snapshot last fetched by
# fetch_snapshot. ``version == -1`` / ``snapshot is None`` means "empty".
_local_cache: Dict[str, Any] = {"version": -1, "snapshot": None}
# Held by fetch_snapshot while it checks and refreshes the cache.
_cache_lock = asyncio.Lock()
async def fetch_snapshot(
    *,
    use_cache: bool = True,
    gsm_actor: Optional[Any] = None,
) -> GSMSnapshot:
    """Fetch the current GSM snapshot.

    With caching enabled the steps are:
      1. ask the actor for its current version (one lightweight RPC);
      2. return the cached snapshot when its version matches;
      3. otherwise fetch ``(version, ref)``, resolve the ref and refresh the
         cache.

    Args:
        use_cache: Use the process-local single-slot cache. Disable to force
            a fresh fetch (tests / diagnostics); the cache is then neither
            read nor updated.
        gsm_actor: Optional GSM actor handle; looked up through
            ``ray_actor_hook`` when omitted.

    Returns:
        The snapshot for the actor's current configuration.
    """
    if gsm_actor is None:
        from kilostar.utils.ray_hook import ray_actor_hook

        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine

    async def _pull() -> Tuple[int, GSMSnapshot]:
        version, ref = await gsm_actor.current_config_ref.remote()
        # Await the ObjectRef instead of calling blocking ``ray.get`` so the
        # event loop stays responsive while the object is resolved.
        return version, await ref

    if not use_cache:
        _, snapshot = await _pull()
        return snapshot

    async with _cache_lock:
        try:
            latest_version = await gsm_actor.current_version.remote()
        except Exception as e:
            # Best effort: fall through to a full fetch, but leave a trace
            # instead of swallowing the failure silently.
            _logger.warning(f"current_version probe failed, refetching: {e}")
            latest_version = None

        cached = _local_cache.get("snapshot")
        if (
            latest_version is not None
            and cached is not None
            and _local_cache.get("version") == latest_version
        ):
            return cached

        version, snapshot = await _pull()
        _local_cache["version"] = version
        _local_cache["snapshot"] = snapshot
        return snapshot
def reset_local_cache() -> None:
    """Reset the process-local snapshot cache to its empty state (for tests)."""
    _local_cache.update(version=-1, snapshot=None)
# ─── 客户端 helper:从快照重建本地视图 ─────────────────────────────
def build_toolsets_for_scope(
    snapshot: GSMSnapshot, scope: str
) -> List[Any]:
    """Assemble the ``FunctionToolset`` list for ``scope`` from a snapshot.

    Follows the same merge rule as ``GlobalToolManager.get_toolsets_for_scope``:
      - system toolsets: the ``default`` bucket plus the ``scope`` bucket;
      - custom toolsets: every entry of ``snapshot.custom_toolsets`` that
        resolves to at least one function in ``snapshot.third_party_funcs``.

    Toolsets are created in the calling process. A bucket or custom toolset
    that fails to build is logged and skipped.

    Args:
        snapshot: Snapshot to read tool names and functions from.
        scope: Scope whose system bucket is added to ``default``.

    Returns:
        The toolsets, or ``[]`` when ``pydantic_ai.toolsets`` cannot be
        imported.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        _logger.warning("pydantic_ai.toolsets unavailable; cannot build toolsets")
        return []

    result: List[Any] = []
    # ``dict.fromkeys`` de-duplicates while keeping order: with
    # scope == "default" the bucket must be built once, otherwise the caller
    # gets two toolsets with the same id and the same tools.
    for bucket in dict.fromkeys(("default", scope)):
        names = snapshot.system_tools_by_scope.get(bucket) or []
        funcs = [snapshot.tool_funcs[n] for n in names if n in snapshot.tool_funcs]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"system::{bucket}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build system toolset {bucket} failed: {e}")

    for toolset_id, defn in snapshot.custom_toolsets.items():
        names = defn.get("tools") or []
        funcs = [
            snapshot.third_party_funcs[n]
            for n in names
            if n in snapshot.third_party_funcs
        ]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"custom::{toolset_id}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build custom toolset {toolset_id} failed: {e}")
    return result
@@ -1,35 +1,42 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import importlib
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Type
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type
from kilostar.utils.logger import get_logger
logger = get_logger("tool_manager")
_SYSTEM_BUCKET = "system"
class GlobalToolManager:
"""工具注册表:扫描 ``kilostar/plugin/tool_plugin/`` 下所有 BaseToolData 子类,
按 ``action_scope`` 分桶到 ``tool_mapper[scope][plugin_name]``;无 scope 的归入 ``default``。"""
按 ``action_scope`` 打包成 ``FunctionToolset``。
三类 toolset
- **system**``is_system=True`` 的工具,按 scope 分组
- **custom**:用户自定义工具组(由 ``rebuild_custom_toolsets`` 动态构建)
- **mcp**:由 ``mcp_helper`` 独立管理,不经过本类
``category="mcp"`` 的工具不会被本类管理。
"""
tool_metadata: Dict[str, Dict[str, Any]]
_tool_funcs: Dict[str, Dict[str, Callable]]
_system_toolsets: Dict[str, Any]
_custom_toolsets: Dict[str, Any]
_third_party_funcs: Dict[str, Callable]
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self):
def __init__(self) -> None:
self.tool_metadata = {}
self._tool_funcs = defaultdict(dict)
self._system_toolsets = {}
self._custom_toolsets = {}
self._third_party_funcs = {}
self.tool_mapper = defaultdict(dict)
tool_plugin_dir = (
@@ -39,21 +46,154 @@ class GlobalToolManager:
return
for item in tool_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith("__"):
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
if not (item.is_dir() and not item.name.startswith("__")):
continue
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
# It's a valid tool class
action_scopes = obj.model_fields.get("action_scope").default
try:
module = importlib.import_module(module_name)
except Exception as e:
logger.warning(f"Failed to import tool plugin {plugin_name}: {e}")
continue
if not action_scopes:
self.tool_mapper["default"][plugin_name] = obj
else:
for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj
except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
tool_data_cls = self._find_tool_data_class(module)
if tool_data_cls is None:
continue
tool_func = getattr(module, plugin_name, None)
if not callable(tool_func):
logger.warning(
f"Tool plugin '{plugin_name}' has no callable named "
f"'{plugin_name}' in its module; skipped."
)
continue
action_scopes = (
tool_data_cls.model_fields.get("action_scope").default or []
)
is_system = bool(tool_data_cls.model_fields.get("is_system").default)
category_field = tool_data_cls.model_fields.get("category")
category = (category_field.default if category_field else "other") or "other"
self.tool_metadata[plugin_name] = {
"name": plugin_name,
"is_system": is_system,
"category": category,
"action_scope": list(action_scopes),
}
if category == "mcp":
continue
scopes = [s for s in action_scopes if s] or ["default"]
if is_system:
for scope in scopes:
self._tool_funcs[scope][plugin_name] = tool_func
self.tool_mapper[scope][plugin_name] = tool_data_cls
else:
self._third_party_funcs[plugin_name] = tool_func
for scope in scopes:
self.tool_mapper[scope][plugin_name] = tool_data_cls
self._build_system_toolsets()
def _build_system_toolsets(self) -> None:
    """Build one ``system::<scope>`` toolset per non-empty scope in ``_tool_funcs``.

    Does nothing when ``FunctionToolset`` cannot be imported. A scope whose
    toolset fails to build is logged and skipped.
    """
    toolset_cls = self._import_function_toolset()
    if toolset_cls is not None:
        for scope, funcs_by_name in self._tool_funcs.items():
            if funcs_by_name:
                try:
                    self._system_toolsets[scope] = toolset_cls(
                        tools=list(funcs_by_name.values()),
                        id=f"system::{scope}",
                    )
                except Exception as e:
                    logger.error(f"Failed to build system toolset {scope}: {e}")
def rebuild_custom_toolsets(self, custom_defs: Dict[str, Dict[str, Any]]) -> None:
    """Rebuild all ``custom::<id>`` toolsets from the given definitions.

    Tool names not present in ``_third_party_funcs`` are ignored; a
    definition that resolves to no function produces no toolset. The
    previous mapping is replaced wholesale (by an empty one when
    ``FunctionToolset`` cannot be imported).
    """
    toolset_cls = self._import_function_toolset()
    rebuilt: Dict[str, Any] = {}
    if toolset_cls is not None:
        for toolset_id, definition in custom_defs.items():
            resolved = []
            for tool_name in definition.get("tools") or []:
                if tool_name in self._third_party_funcs:
                    resolved.append(self._third_party_funcs[tool_name])
            if resolved:
                try:
                    rebuilt[toolset_id] = toolset_cls(
                        tools=resolved,
                        id=f"custom::{toolset_id}",
                    )
                except Exception as e:
                    logger.error(f"Failed to build custom toolset {toolset_id}: {e}")
    self._custom_toolsets = rebuilt
@staticmethod
def _import_function_toolset():
    """Return the ``FunctionToolset`` class, or ``None`` when it cannot be imported."""
    try:
        from pydantic_ai import toolsets as _toolsets_module
        from pydantic_ai.toolsets import FunctionToolset as toolset_cls
    except ImportError:
        logger.warning("pydantic_ai.toolsets unavailable")
        return None
    return toolset_cls
@staticmethod
def _find_tool_data_class(module) -> Type[BaseToolData] | None:
    """Return the first strict ``BaseToolData`` subclass found in ``module``.

    Classes are examined in the order ``inspect.getmembers`` yields them;
    ``None`` is returned when the module has no such class.
    """
    candidates = (
        member
        for _name, member in inspect.getmembers(module, inspect.isclass)
        if issubclass(member, BaseToolData) and member is not BaseToolData
    )
    return next(candidates, None)
# ─── Toolset accessors ───
def get_system_toolset(self, scope: str) -> Any | None:
    """Return the system toolset built for ``scope``, or ``None`` if there is none."""
    return self._system_toolsets.get(scope)
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the system toolsets for ``default`` + ``scope`` plus all custom toolsets.

    Args:
        scope: Scope whose system toolset is added after the ``default`` one.

    Returns:
        System toolsets first (``default``, then ``scope`` when it differs
        and exists), followed by every custom toolset.
    """
    result: List[Any] = []
    # ``dict.fromkeys`` de-duplicates while keeping order: with
    # scope == "default" the default toolset must appear once, otherwise the
    # same toolset (same id, same tools) would be handed to the agent twice.
    for s in dict.fromkeys(("default", scope)):
        ts = self._system_toolsets.get(s)
        if ts is not None:
            result.append(ts)
    result.extend(self._custom_toolsets.values())
    return result
# ─── Metadata accessors ───
def is_third_party_tool(self, tool_name: str) -> bool:
    """Return whether ``tool_name`` is registered in ``_third_party_funcs``."""
    return tool_name in self._third_party_funcs
def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``category`` equals ``category``."""
    matches: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("category") == category:
            matches.append(meta)
    return matches
def get_system_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``is_system`` flag is exactly ``True``."""
    system_only: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            system_only.append(meta)
    return system_only
def get_third_party_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of tools that are neither system tools nor ``mcp`` ones.

    A tool counts as system only when its ``is_system`` flag is exactly
    ``True``.
    """
    selected: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            continue
        if meta.get("category") == "mcp":
            continue
        selected.append(meta)
    return selected
def get_all_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every tool recorded in ``tool_metadata``."""
    return list(self.tool_metadata.values())
# 兼容旧接口
def get_non_system_tools(self) -> List[Dict[str, Any]]:
    """Backward-compatible alias of ``get_third_party_tools``."""
    return self.get_third_party_tools()
def get_personal_tools(self) -> List[Dict[str, Any]]:
    """Backward-compatible alias of ``get_third_party_tools``."""
    return self.get_third_party_tools()
@@ -29,6 +29,7 @@ from kilostar.core.global_state_machine.global_state_machine import GlobalStateM
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -45,22 +46,19 @@ class ConsciousnessNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
"你是系统的'高级规划师''架构师',负责处理监控节点分配过来的复杂任务。\n"
"你的主要工作场景包括:\n"
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
)
system_prompt: str = agent_prompt("consciousness_node", locale=locale)
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -72,6 +70,7 @@ class ConsciousnessNode:
deps_type=ConsciousnessNodeDeps,
agent_name="consciousness_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -95,6 +94,13 @@ class ConsciousnessNode:
开始进行工作流设计的交互过程(与用户通过 SSE 进行确认或直接生成)
目前简化为:直接根据 command 拆解并构建工作流,然后提交执行。
"""
from kilostar.utils.request_context import trace_id_scope
# 进入工作流域:把 trace_id 绑到 contextvars,本协程所有日志自动带上它
with trace_id_scope(trace_id):
await self._do_start_workflow_design(trace_id, command)
async def _do_start_workflow_design(self, trace_id: str, command: str):
self.logger.info(
f"ConsciousnessNode: 开始为 trace_id {trace_id} 设计工作流。原始命令:{command}"
)
@@ -116,11 +122,11 @@ class ConsciousnessNode:
original_command=command, available_skills=available_skills
)
# 通知 SSE 正在生成图结构
# 通知 SSE 正在生成图结构(pending 队列:节点端写入 → API SSE 读取,单向下行)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "正在为您构建并规划工作流任务节点,请稍候..."
)
@@ -131,17 +137,17 @@ class ConsciousnessNode:
workflow = result.workflow
workflow.trace_id = trace_id
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "工作流构建完成,即将开始执行!"
)
# 将生成的完整工作流提交执行
workflow_engine = ray_actor_hook(
"workflow_running_engine"
).workflow_running_engine
await workflow_engine.execute_workflow.remote(workflow)
# 直接以 ray task 形式 fire workflow,不再经过 WorkflowRunningEngine 这层中转:
# workflow 是一次性、有头有尾的执行,task 语义比常驻 actor 更贴。
from kilostar.core.work.workflow.workflow_engine import run_workflow_task
run_workflow_task.remote(workflow.model_dump(), trace_id)
else:
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "很抱歉,工作流生成失败。"
)
await postgres_database.update_workflow_status.remote(trace_id, "failed")
@@ -22,6 +22,7 @@ from kilostar.core.individual.control_node.template import (
ForWorkflowInput,
ControlNodeDeps,
)
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -44,6 +45,8 @@ class ControlNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
"""
create_agent方法,将agent对象装配到Control的属性内
@@ -54,25 +57,21 @@ class ControlNode:
global_state_machine: 全局状态机
provider_title: 供应商名
model_id: 模型id
locale: 语言代码(zh/en),控制system prompt语言
Returns:
无返回
"""
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
"你是系统的'执行者''车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
"你的工作职责是:\n"
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
)
system_prompt: str = agent_prompt("control_node", locale=locale)
output_type = ForWorkflow
from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -84,6 +83,7 @@ class ControlNode:
deps_type=ControlNodeDeps,
agent_name="control_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -24,6 +24,7 @@ from kilostar.core.individual.regulatory_node.template import (
MessageResponse
)
from pydantic_ai import RunContext, Agent
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -46,6 +47,8 @@ class RegulatoryNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
"""
create_agent方法,将agent对象装配到regulatoryNode的属性内
@@ -56,24 +59,21 @@ class RegulatoryNode:
provider_title: 供应商名
model_id: 模型id
tools_list: 工具列表
locale: 语言代码(zh/en),控制system prompt语言
Returns:
无返回
"""
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
"你是系统的'前台接待''大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
"请保持冷静、专业,并严格遵循上述路由规则。"
)
system_prompt: str = agent_prompt("regulatory_node", locale=locale)
output_type = Union[MessageResponse]
from kilostar.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
# 走 Object Store 快照而不是 actor RPC:高频读路径不再受单 actor 串行限制
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -85,6 +85,7 @@ class RegulatoryNode:
deps_type=RegulatoryNodeDeps,
agent_name="regulatory_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -112,16 +113,15 @@ class RegulatoryNode:
)
return prompt
async def working(self, payload: MessageRequest) -> str:
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并实现最终回复
async def working(self, payload: MessageRequest) -> Union[MessageResponse, None]:
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并返回最终回复
Args:
payload: 消息载荷,包含所有信息
Returns:
str,监控节点对用户的回复
MessageResponse 或 None监控节点对用户的结构化回复
"""
await self._run(payload)
return ""
return await self._run(payload)
async def _run(
self, payload: MessageRequest
@@ -140,7 +140,8 @@ class RegulatoryNode:
deps=deps,)
response: MessageResponse = agent_response.output
response.platform = platform
response.platform_id = MessageRequest.platform_id
response.platform_id = payload.platform_id
return response
except:
pass
except Exception as e:
self.logger.exception(f"RegulatoryNode._run failed: {e}")
return None
@@ -49,7 +49,7 @@ class MessageRequest(RequestModel):
MessageRequest类
任何消息渠道向regulatory_node发送消息请求的模型
"""
platform: Literal["client"]
platform: Literal["client", "onebot"]
user_name: str
platform_id: Optional[str]
message: str
@@ -59,6 +59,6 @@ class MessageResponse(RegulatoryNodeResponse):
MessageResponse类
regulatory_node回复的模型
"""
platform: Optional[Literal["client"]] = Field(description="系统自动填入的platform")
platform: Optional[Literal["client", "onebot"]] = Field(description="系统自动填入的platform")
platform_id: Optional[str] = Field(description="系统自动填入的platform_id")
reply_message: str = Field(...,description="模型回复的消息")
@@ -23,12 +23,16 @@ from kilostar.core.postgres_database.model.individual import (
from kilostar.core.postgres_database.model.workflow import (
Workflow,
WorkflowContextModel,
WorkflowGraphStateModel,
)
from kilostar.core.postgres_database.model.chat_history import (
ChatHistoryRegister,
ChatHistoryMessage,
)
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
# 兼容旧代码的别名
Provider = ProviderModel
@@ -49,9 +53,13 @@ __all__ = [
"SpecialIndividualModel",
"Workflow",
"WorkflowContextModel",
"WorkflowGraphStateModel",
"ChatHistoryRegister",
"ChatHistoryMessage",
"SystemNodeConfigModel",
"SystemNodeConfig",
"MCPServerModel",
"ToolConfigModel",
"CustomToolsetModel",
"AgentType",
]
@@ -0,0 +1,32 @@
from datetime import datetime
from typing import List, Optional
from sqlalchemy import String, Text, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class CustomToolsetModel(BaseDataModel):
    """ORM row for a user-defined toolset: a named list of tool names.

    ``tools`` stores the tool names as a JSONB list. According to the original
    author's note these are directory names under ``plugin/tool_plugin/`` and
    only non-system / non-MCP tools belong here; this model does not enforce
    that restriction, so callers must validate it.
    NOTE(review): the original note also says the GSM loads the listed tools
    into a single ``FunctionToolset`` at startup -- not visible here, confirm.
    """

    __tablename__ = "custom_toolset"

    # Primary key of the toolset.
    toolset_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Required display name.
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional free-text description.
    description: Mapped[Optional[str]] = mapped_column(Text)
    # Optional owner identifier; indexed for lookups by owner.
    owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    # JSONB list of tool names; ``default=list`` gives each new row its own list.
    tools: Mapped[List[str]] = mapped_column(
        JSONB, default=list, comment="工具名列表,仅允许非 system/非 mcp 的工具"
    )
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every ORM update via ``onupdate``.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,29 @@
from typing import Optional
from datetime import datetime
from sqlalchemy import String, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class MCPServerModel(BaseDataModel):
    """ORM row describing how to connect to one MCP server.

    The original note lists stdio / sse / http as the supported transports;
    the column itself is a free-form string and does not enforce that set.
    """

    __tablename__ = "mcp_server"

    # Primary key of the server registration.
    server_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Required display name.
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Transport kind (max 16 chars); not validated at the model level.
    transport: Mapped[str] = mapped_column(String(16), nullable=False)
    # Optional executable -- presumably only used by the stdio transport; confirm.
    command: Mapped[Optional[str]] = mapped_column(String(255))
    # JSONB list of command-line arguments; defaults to a fresh empty list.
    args: Mapped[list] = mapped_column(JSONB, default=list)
    # Optional endpoint URL -- presumably for the sse/http transports; confirm.
    url: Mapped[Optional[str]] = mapped_column(String(500))
    # Optional prefix for tool names exposed by this server.
    tool_prefix: Mapped[Optional[str]] = mapped_column(String(64))
    # JSONB environment mapping; the column comment states that sensitive
    # values are stored Fernet-encrypted (encryption happens in the DAO layer).
    env: Mapped[dict] = mapped_column(JSONB, default=dict, comment="敏感字段已 Fernet 加密")
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every ORM update via ``onupdate``.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import String, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class ToolConfigModel(BaseDataModel):
    """ORM row holding the runtime configuration of one tool, keyed by tool name.

    Per the original note, sensitive entries inside ``config`` (for example an
    API key) are stored Fernet-encrypted; the model itself performs no
    encryption.
    """

    __tablename__ = "tool_config"

    # Primary key: the tool's name.
    tool_name: Mapped[str] = mapped_column(String(100), primary_key=True)
    # JSONB configuration mapping; defaults to a fresh empty dict per row.
    config: Mapped[dict] = mapped_column(JSONB, default=dict)
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every ORM update via ``onupdate``.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -79,3 +79,28 @@ class WorkflowContextModel(BaseDataModel):
updated_at: Mapped[str] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class WorkflowGraphStateModel(BaseDataModel):
    """Storage table for the persisted pydantic_graph state of a workflow.

    Kept separate from ``workflow_context``: per the original note that table
    is for user-readable business display, while this one exists so the graph
    engine can restore its own state. One row per ``trace_id``; the JSONB
    column holds the full history.
    """

    __tablename__ = "workflow_graph_state"

    # Primary key: the workflow's trace id.
    trace_id: Mapped[str] = mapped_column(
        String(64), primary_key=True, comment="对应的工作流 Trace ID"
    )
    # JSON view of the graph history; defaults to a fresh empty list per row.
    history: Mapped[list] = mapped_column(
        JSONB,
        default=list,
        comment="pydantic_graph FullStatePersistence.history 的 JSON 序列化",
    )
    # NOTE(review): annotated ``Mapped[str]`` although the column type is
    # ``DateTime``; values read back are presumably ``datetime`` objects.
    # Same for ``updated_at`` below -- confirm and align the annotations.
    created_at: Mapped[str] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every ORM update via ``onupdate``.
    updated_at: Mapped[str] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
class CustomToolsetDatabase:
    """DAO for user-defined toolsets.

    ``tools`` is a plain list of tool names. This layer stores whatever it is
    given; restricting the list to non-system / non-MCP tools is the caller's
    responsibility.
    """

    def __init__(self, async_session_maker):
        # Factory producing async SQLAlchemy sessions; one session per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict (timestamps are not included)."""
        tool_names = list(row.tools or [])
        return dict(
            toolset_id=row.toolset_id,
            name=row.name,
            description=row.description,
            owner_id=row.owner_id,
            tools=tool_names,
        )

    @database_exception
    async def upsert(
        self,
        toolset_id: str,
        name: str,
        tools: List[str],
        description: Optional[str] = None,
        owner_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create the toolset, or overwrite every field of an existing one.

        Args:
            toolset_id: Primary key of the toolset.
            name: Display name.
            tools: Tool names; copied into a new list before being stored.
            description: Optional description (``None`` overwrites an old value).
            owner_id: Optional owner id (``None`` overwrites an old value).

        Returns:
            The stored row as a dict, read back after commit.
        """
        lookup = select(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                record = CustomToolsetModel(
                    toolset_id=toolset_id,
                    name=name,
                    description=description,
                    owner_id=owner_id,
                    tools=list(tools),
                )
                session.add(record)
            else:
                record.name = name
                record.description = description
                record.owner_id = owner_id
                record.tools = list(tools)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
        """Return the toolset as a dict, or ``None`` when it does not exist."""
        lookup = select(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored toolset as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(CustomToolsetModel))
            return [self._to_dict(item) for item in outcome.scalars().all()]

    @database_exception
    async def delete(self, toolset_id: str) -> bool:
        """Delete the toolset; return ``True`` when at least one row was removed."""
        removal = delete(CustomToolsetModel).where(
            CustomToolsetModel.toolset_id == toolset_id
        )
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            removed = outcome.rowcount
            return removed > 0
@@ -0,0 +1,79 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
class MCPServerDatabase:
    """DAO for MCP server configurations.

    ``env`` is passed through ``encrypt_dict_secrets`` before it is written and
    through ``decrypt_dict_secrets`` whenever a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        # Factory producing async SQLAlchemy sessions; one session per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``env`` decrypted."""
        plain_env = decrypt_dict_secrets(row.env or {})
        return dict(
            server_id=row.server_id,
            name=row.name,
            transport=row.transport,
            command=row.command,
            args=row.args or [],
            url=row.url,
            tool_prefix=row.tool_prefix,
            env=plain_env,
        )

    @database_exception
    async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Create or update the server registration from ``config``.

        On update, ``name`` and ``transport`` keep their stored value when the
        key is missing from ``config``. On insert they default to ``server_id``
        and ``"stdio"`` respectively.
        NOTE(review): ``command``, ``args``, ``url``, ``tool_prefix`` and ``env``
        are always overwritten, so omitting them on update clears the stored
        value -- confirm this replace semantics is intended.

        Returns:
            The stored row as a dict (``env`` decrypted), read back after commit.
        """
        env = encrypt_dict_secrets(config.get("env") or {})
        lookup = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            # Fields that are written unconditionally on both insert and update.
            replaced = {
                "command": config.get("command"),
                "args": config.get("args") or [],
                "url": config.get("url"),
                "tool_prefix": config.get("tool_prefix"),
                "env": env,
            }
            if not record:
                record = MCPServerModel(
                    server_id=server_id,
                    name=config.get("name", server_id),
                    transport=config.get("transport", "stdio"),
                    **replaced,
                )
                session.add(record)
            else:
                record.name = config.get("name", record.name)
                record.transport = config.get("transport", record.transport)
                for column, new_value in replaced.items():
                    setattr(record, column, new_value)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
        """Return the server configuration as a dict, or ``None`` if absent."""
        lookup = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored server configuration as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(MCPServerModel))
            return [self._to_dict(item) for item in outcome.scalars().all()]

    @database_exception
    async def delete(self, server_id: str) -> bool:
        """Delete the registration; return ``True`` when a row was removed."""
        removal = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            removed = outcome.rowcount
            return removed > 0
@@ -17,10 +17,37 @@ from typing import List
from kilostar.core.postgres_database.model.provider import ProviderModel
from sqlalchemy import select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.utils.crypto import (
CryptoError,
decrypt_secret,
encrypt_secret,
is_encrypted,
)
from kilostar.utils.logger import get_logger
logger = get_logger("provider_dao")
def _decrypt_apikey(value):
    """Return the plaintext for a stored API key, tolerating legacy plaintext.

    Falsy values and values that ``is_encrypted`` does not recognise are
    returned unchanged. If decryption raises ``CryptoError`` the failure is
    logged and the stored value is returned as-is.
    """
    if value and is_encrypted(value):
        try:
            return decrypt_secret(value)
        except CryptoError as e:
            logger.error(f"Provider apikey 解密失败: {e}")
    return value
def _encrypt_apikey(value):
    """Encrypt an API key for storage unless it is empty or already encrypted."""
    if value and not is_encrypted(value):
        return encrypt_secret(value)
    return value
class ProviderDatabase:
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
"""Provider 表的 DAO:模型 Provider 的增删查改``provider_apikey`` 透明 Fernet 加解密"""
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@@ -37,11 +64,10 @@ class ProviderDatabase:
provider_id=provider.provider_id,
provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_apikey=_decrypt_apikey(provider.provider_apikey),
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner,
provider_status=provider.provider_status,
is_active=provider.is_active,
)
for provider in results
@@ -50,7 +76,9 @@ class ProviderDatabase:
@database_exception
async def add_provider(self, **kwargs) -> None:
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel"""
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密"""
if "provider_apikey" in kwargs:
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
async with self.async_session_maker() as session:
provider = ProviderModel(**kwargs)
session.add(provider)
@@ -67,7 +95,9 @@ class ProviderDatabase:
@database_exception
async def update_provider(self, provider_id: str, **kwargs) -> None:
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象"""
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密"""
if "provider_apikey" in kwargs:
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
async with self.async_session_maker() as session:
provider = await session.get(ProviderModel, provider_id)
if provider is not None:
@@ -76,5 +106,7 @@ class ProviderDatabase:
session.add(provider)
await session.commit()
await session.refresh(provider)
# 解密返回,方便上游使用
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
return provider
return None
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
class ToolConfigDatabase:
    """DAO for per-tool runtime configuration.

    ``config`` is passed through ``encrypt_dict_secrets`` before it is written
    and through ``decrypt_dict_secrets`` whenever a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        # Factory producing async SQLAlchemy sessions; one session per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``config`` decrypted."""
        plain_config = decrypt_dict_secrets(row.config or {})
        return dict(tool_name=row.tool_name, config=plain_config)

    @database_exception
    async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Create the tool's configuration or replace it entirely.

        Returns:
            The stored row as a dict (``config`` decrypted), read back after commit.
        """
        encrypted = encrypt_dict_secrets(config or {})
        lookup = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if not record:
                record = ToolConfigModel(tool_name=tool_name, config=encrypted)
                session.add(record)
            else:
                record.config = encrypted
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Return the tool's configuration as a dict, or ``None`` if absent."""
        lookup = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
        async with self.async_session_maker() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored tool configuration as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(ToolConfigModel))
            return [self._to_dict(item) for item in outcome.scalars().all()]

    @database_exception
    async def delete(self, tool_name: str) -> bool:
        """Delete the tool's configuration; return ``True`` when a row was removed."""
        removal = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
        async with self.async_session_maker() as session:
            outcome = await session.execute(removal)
            await session.commit()
            removed = outcome.rowcount
            return removed > 0
@@ -17,6 +17,7 @@ from typing import List, Optional
from kilostar.core.postgres_database.model.workflow import (
Workflow,
WorkflowContextModel,
WorkflowGraphStateModel,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
from kilostar.core.postgres_database.database_exception import database_exception
@@ -101,3 +102,58 @@ class WorkflowDatabase:
)
results = await session.execute(statement)
return results.scalar_one_or_none()
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
@database_exception
async def upsert_workflow_graph_state(
    self, trace_id: str, history: list
) -> WorkflowGraphStateModel:
    """Store the JSON view of a workflow's graph history, replacing any old one.

    Only the latest version is kept: an existing row for ``trace_id`` has its
    ``history`` overwritten, otherwise a new row is inserted.

    Returns:
        The stored ORM record, refreshed after commit.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if not stored:
            stored = WorkflowGraphStateModel(trace_id=trace_id, history=history)
            session.add(stored)
        else:
            stored.history = history
        await session.commit()
        await session.refresh(stored)
        return stored
@database_exception
async def get_workflow_graph_state(
    self, trace_id: str
) -> Optional[WorkflowGraphStateModel]:
    """Return the persisted graph-state record for ``trace_id``, or ``None``."""
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        outcome = await session.execute(lookup)
        stored = outcome.scalar_one_or_none()
        return stored
@database_exception
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
    """Remove the persisted graph state of one workflow.

    Returns:
        ``True`` if a record existed and was deleted, ``False`` if none existed.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if stored is not None:
            await session.delete(stored)
            await session.commit()
            return True
        return False
+112
View File
@@ -38,6 +38,9 @@ from kilostar.core.postgres_database.model.chat_history import (
ChatHistoryMessage,
)
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
from .module.individual import IndividualDatabase
from .module.user import AuthDatabase
@@ -45,6 +48,9 @@ from .module.provider import ProviderDatabase
from .module.system_node import SystemNodeDatabase
from .module.workflow import WorkflowDatabase
from .module.chat_history import ChatHistoryDatabase
from .module.mcp_server import MCPServerDatabase
from .module.tool_config import ToolConfigDatabase
from .module.custom_toolset import CustomToolsetDatabase
@ray.remote
@@ -76,6 +82,9 @@ class PostgresDatabase:
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
self._workflow_database = WorkflowDatabase(self.async_session_maker)
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
self.ready_event = asyncio.Event()
@@ -91,6 +100,15 @@ class PostgresDatabase:
finally:
self.ready_event.set()
async def ping(self) -> bool:
    """Lightweight liveness probe: wait until ready, then run ``SELECT 1``.

    Returns:
        ``True`` when the statement executed; connection errors propagate.
    """
    from sqlalchemy import text

    await self.ready_event.wait()
    probe = text("SELECT 1")
    async with self.async_engine.connect() as connection:
        await connection.execute(probe)
    return True
# Auth Database Methods
async def add_user(self, user_name: str, hashed_password: str):
"""新建一名用户。"""
@@ -242,6 +260,24 @@ class PostgresDatabase:
await self.ready_event.wait()
return await self._workflow_database.get_workflow_context(trace_id)
# Workflow Graph State (pydantic_graph 持久化)
async def upsert_workflow_graph_state(self, trace_id: str, history: list):
    """Overwrite the persisted graph history for ``trace_id`` once the DB is ready."""
    await self.ready_event.wait()
    dao = self._workflow_database
    stored = await dao.upsert_workflow_graph_state(trace_id, history)
    return stored
async def get_workflow_graph_state(self, trace_id: str):
    """Read the persisted graph-state record for ``trace_id`` once the DB is ready."""
    await self.ready_event.wait()
    dao = self._workflow_database
    stored = await dao.get_workflow_graph_state(trace_id)
    return stored
async def delete_workflow_graph_state(self, trace_id: str):
    """Delete the persisted graph-state record for ``trace_id`` once the DB is ready."""
    await self.ready_event.wait()
    dao = self._workflow_database
    removed = await dao.delete_workflow_graph_state(trace_id)
    return removed
# Chat History Database Methods
async def create_chat_session(self, user_id: str, title: str = "新对话"):
"""新建一个聊天会话。"""
@@ -264,3 +300,79 @@ class PostgresDatabase:
"""返回某个聊天会话的全部消息。"""
await self.ready_event.wait()
return await self._chat_history_database.list_chat_messages(chat_id)
# MCP Server Database Methods
async def upsert_mcp_server(self, server_id: str, config: dict):
    """Insert or update one MCP server configuration via the MCP server DAO."""
    await self.ready_event.wait()
    dao = self._mcp_server_database
    saved = await dao.upsert(server_id, config)
    return saved
async def get_mcp_server_db(self, server_id: str):
    """Read one MCP server configuration via the MCP server DAO."""
    await self.ready_event.wait()
    dao = self._mcp_server_database
    found = await dao.get(server_id)
    return found
async def list_mcp_servers_db(self):
    """Read all MCP server configurations via the MCP server DAO."""
    await self.ready_event.wait()
    dao = self._mcp_server_database
    everything = await dao.list_all()
    return everything
async def delete_mcp_server_db(self, server_id: str):
    """Delete one MCP server configuration via the MCP server DAO."""
    await self.ready_event.wait()
    dao = self._mcp_server_database
    removed = await dao.delete(server_id)
    return removed
# Tool Config Database Methods
async def upsert_tool_config(self, tool_name: str, config: dict):
    """Insert or replace one tool's runtime configuration via the tool-config DAO."""
    await self.ready_event.wait()
    dao = self._tool_config_database
    saved = await dao.upsert(tool_name, config)
    return saved
async def get_tool_config_db(self, tool_name: str):
    """Read one tool's runtime configuration via the tool-config DAO."""
    await self.ready_event.wait()
    dao = self._tool_config_database
    found = await dao.get(tool_name)
    return found
async def list_tool_configs_db(self):
    """Read every tool's runtime configuration via the tool-config DAO."""
    await self.ready_event.wait()
    dao = self._tool_config_database
    everything = await dao.list_all()
    return everything
async def delete_tool_config_db(self, tool_name: str):
    """Delete one tool's runtime configuration via the tool-config DAO."""
    await self.ready_event.wait()
    dao = self._tool_config_database
    removed = await dao.delete(tool_name)
    return removed
# Custom Toolset Database Methods
async def upsert_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: list,
    description: str = None,
    owner_id: str = None,
):
    """Insert or update a user-defined toolset via the custom-toolset DAO.

    All arguments are forwarded to the DAO by keyword.
    """
    await self.ready_event.wait()
    fields = dict(
        toolset_id=toolset_id,
        name=name,
        tools=tools,
        description=description,
        owner_id=owner_id,
    )
    return await self._custom_toolset_database.upsert(**fields)
async def get_custom_toolset(self, toolset_id: str):
    """Read one user-defined toolset by id via the custom-toolset DAO."""
    await self.ready_event.wait()
    dao = self._custom_toolset_database
    found = await dao.get(toolset_id)
    return found
async def list_custom_toolsets(self):
    """Read all user-defined toolsets via the custom-toolset DAO."""
    await self.ready_event.wait()
    dao = self._custom_toolset_database
    everything = await dao.list_all()
    return everything
async def delete_custom_toolset(self, toolset_id: str):
    """Delete one user-defined toolset via the custom-toolset DAO."""
    await self.ready_event.wait()
    dao = self._custom_toolset_database
    removed = await dao.delete(toolset_id)
    return removed
@@ -0,0 +1,191 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Postgres 后端的 ``BaseStatePersistence`` 实现,让 graph 跨进程 resume。
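Design notes:
- **Reuse the in-memory semantics of ``FullStatePersistence``.** Its
  snapshot / record_run / load_next implementations already handle the
  details (NodeSnapshot state machine, deep copies, type adapter), so they are
  not rewritten here. After every hook that can change the history, the
  in-memory history is serialized to JSON and written to the
  ``workflow_graph_state`` table.
- **Best-effort writes.** Every hook awaits its write, but a failed write is
  only logged as a warning and does not stop the graph. The exception is
  ``snapshot_end``, where a failed write is re-raised so the final history is
  known to be on disk before the run returns.
- **Resume entry point.** Read the history JSON from the DB, restore it with
  ``load_json``, then let ``Graph.iter_from_persistence`` run the remaining
  nodes.
This layer does not hold a SQLAlchemy session itself; IO is injected as two
awaitables, which keeps it easy to test.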
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Awaitable, Callable, Optional
from pydantic_graph import BaseNode, End
from pydantic_graph.persistence import (
BaseStatePersistence,
NodeSnapshot,
Snapshot,
)
from pydantic_graph.persistence.in_mem import FullStatePersistence
from kilostar.utils.logger import get_logger
_logger = get_logger("graph_persistence")
# IO 注入签名:写一份 history JSON / 读一份 history JSONNone=没有)
WriteHistory = Callable[[str, Any], Awaitable[None]]
ReadHistory = Callable[[str], Awaitable[Optional[Any]]]
@dataclass
class PostgresStatePersistence(BaseStatePersistence):
    """Graph persistence that keeps state in memory and mirrors it to Postgres.

    All snapshot bookkeeping is delegated to an inner ``FullStatePersistence``.
    After each hook that can change the history, the history is serialized and
    handed to ``write_history``. Write failures are logged and swallowed so the
    graph keeps running, except in ``snapshot_end`` where they are re-raised.

    Attributes:
        trace_id: Workflow trace id used as the storage key.
        write_history: Awaitable ``(trace_id, history_json)`` that stores it.
        read_history: Awaitable ``(trace_id)`` returning stored history or ``None``.
    """

    trace_id: str
    write_history: WriteHistory
    read_history: ReadHistory
    _inner: FullStatePersistence = field(default_factory=FullStatePersistence)

    # --- BaseStatePersistence interface ---------------------------------

    async def snapshot_node(self, state, next_node):
        """Record a node snapshot in memory, then flush (best effort)."""
        await self._inner.snapshot_node(state, next_node)
        await self._flush()

    async def snapshot_node_if_new(self, snapshot_id, state, next_node):
        """Record a node snapshot unless it already exists, then flush (best effort)."""
        await self._inner.snapshot_node_if_new(snapshot_id, state, next_node)
        await self._flush()

    async def snapshot_end(self, state, end):
        """Record the end snapshot and require the final flush to succeed."""
        await self._inner.snapshot_end(state, end)
        # The graph has finished: the final history must be stored before returning.
        await self._flush(must_succeed=True)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Wrap a node run and flush the updated snapshot status afterwards.

        The flush sits in ``finally`` so the history is also written when the
        node run raises. Previously the flush was skipped on that path, so the
        failed status of the snapshot never reached the database.
        """
        try:
            async with self._inner.record_run(snapshot_id):
                yield
        finally:
            # Best effort: _flush swallows storage errors here, so it cannot
            # mask an exception raised by the node run itself.
            await self._flush()

    async def load_next(self) -> Optional[NodeSnapshot]:
        """Return the next snapshot to run, as decided by the inner persistence."""
        return await self._inner.load_next()

    async def load_all(self) -> list[Snapshot]:
        """Return every snapshot held by the inner persistence."""
        return await self._inner.load_all()

    def should_set_types(self) -> bool:
        """Report whether the inner persistence still needs ``set_types``."""
        return self._inner.should_set_types()

    def set_types(self, state_type, run_end_type):
        """Forward the graph's state and end types to the inner persistence."""
        self._inner.set_types(state_type, run_end_type)

    # --- serialization / deserialization --------------------------------

    async def hydrate(self) -> bool:
        """Load the stored history into memory.

        ``set_types`` must have been called beforehand; otherwise ``load_json``
        asserts, and that ``AssertionError`` is deliberately propagated so the
        caller notices the wrong call order.

        Returns:
            ``True`` when history was found and loaded, ``False`` when nothing
            was stored or when reading / parsing failed (logged as a warning).
        """
        try:
            raw = await self.read_history(self.trace_id)
        except Exception as e:  # pragma: no cover - defensive
            _logger.warning(f"hydrate read failed: {e}")
            return False
        if not raw:
            return False
        try:
            self._inner.load_json(_to_json_bytes(raw))
        except AssertionError:
            # Types were not registered yet: a programming error, not a data error.
            raise
        except Exception as e:
            _logger.warning(f"hydrate load_json failed: {e}")
            return False
        return True

    async def _flush(self, *, must_succeed: bool = False) -> None:
        """Serialize the in-memory history and await the injected writer.

        Args:
            must_succeed: When ``False`` serialization and storage errors are
                only logged. When ``True`` they are logged and re-raised, so
                callers such as ``snapshot_end`` can detect that the history
                was not stored.
        """
        if self._inner._snapshots_type_adapter is None:
            # Types are not registered yet, so the history cannot be dumped.
            return
        try:
            blob = self._inner.dump_json()
        except Exception as e:  # pragma: no cover
            _logger.warning(f"dump history failed: {e}")
            if must_succeed:
                # A failed dump means nothing was stored; honour the contract.
                raise
            return
        try:
            await self.write_history(self.trace_id, _from_json_bytes(blob))
        except Exception as e:
            _logger.warning(f"persist history failed: {e}")
            if must_succeed:
                raise
def _to_json_bytes(value: Any) -> bytes:
"""把 DB 读出的 ``list[dict]`` / ``str`` / ``bytes`` 都规范成 bytes 喂 load_json。"""
if isinstance(value, (bytes, bytearray)):
return bytes(value)
if isinstance(value, str):
return value.encode("utf-8")
# 假定是 list/dict(来自 JSONB),转回 JSON 字符串
import json as _json
return _json.dumps(value).encode("utf-8")
def _from_json_bytes(blob: bytes) -> Any:
"""把 ``dump_json`` 出来的 bytes 转成 list/dict 以便 JSONB 友好存储。"""
import json as _json
return _json.loads(blob.decode("utf-8"))
def build_postgres_persistence(trace_id: str) -> PostgresStatePersistence:
    """Build a ``PostgresStatePersistence`` wired to the postgres actor handle.

    Args:
        trace_id: Workflow trace id; used as the key for both read and write.

    Returns:
        A persistence whose ``write_history`` / ``read_history`` callbacks call
        ``upsert_workflow_graph_state`` / ``get_workflow_graph_state`` on the
        handle obtained from ``ray_actor_hook("postgres_database")``.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    async def _write(tid: str, history: Any) -> None:
        await postgres_database.upsert_workflow_graph_state.remote(tid, history)

    async def _read(tid: str) -> Optional[Any]:
        record = await postgres_database.get_workflow_graph_state.remote(tid)
        if record is None:
            return None
        if hasattr(record, "history"):
            # Row object: return its history even when empty or None. The old
            # `getattr(...) or record` fell back to the row object itself in
            # that case, which is not a history payload.
            return record.history
        # Plain dict / list payloads are passed through unchanged.
        return record

    return PostgresStatePersistence(
        trace_id=trace_id,
        write_history=_write,
        read_history=_read,
    )
+13 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from pydantic import BaseModel, Field, model_validator
from typing import Optional, Union, List, Dict, Any
from typing import Literal, Optional, Union, List, Dict, Any
from .model import LogicGate, WorkflowMetadata, WorkStepStatus, WorkflowStatus
from ulid import ULID
from datetime import datetime
@@ -61,10 +61,22 @@ class WorkflowStep(BaseModel):
default=None, description="前置依赖输出"
)
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
node: Literal["skill_individual", "consciousness_node"] = Field(
default="skill_individual",
description=(
"执行此步的节点类别:\n"
"- skill_individualtask 内现起一个专家子个体执行(一次性)\n"
"- consciousness_node:远程调用全局 ConsciousnessNode actor"
),
)
agent_id: Optional[str] = Field(
default=None,
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
)
require_approval: bool = Field(
default=False,
description="该步执行前是否需要人工审批;启用时会暂停工作流并通过 SSE 等待用户回执",
)
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
+478 -123
View File
@@ -12,166 +12,521 @@
# See the License for the specific language governing permissions and
# limitations under the License.
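"""Workflow engine: a state machine driven by ``pydantic_graph``.

Only two kinds of step nodes remain on the dispatch path:
- ``skill_individual``: a ``SkillIndividual`` is created inside the current ray
  task process to run the step and is discarded afterwards; no actor is used.
- ``consciousness_node``: a remote call to the ``working`` method of the global
  ConsciousnessNode actor (mid-run guidance / review style work).

A step can set ``require_approval=True`` to route through the ``HumanApproval``
node before it runs: push an SSE message, block on ``get_received`` until the
user replies, then decide whether to continue or abort.

Persistence: ``run_workflow_graph`` falls back to an in-memory
``FullStatePersistence`` when none is supplied. The ``run_workflow_task`` entry
point passes the persistence built by ``build_postgres_persistence`` and resumes
from it when ``hydrate`` finds stored history. The ``upsert_workflow_context``
write path is still used alongside it via ``_persist_context``.
"""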
from __future__ import annotations
import asyncio
import datetime
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional
import ray
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
from typing import Dict, Any, List
from pydantic import BaseModel, Field
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from pydantic_graph.persistence import BaseStatePersistence
from pydantic_graph.persistence.in_mem import FullStatePersistence
from kilostar.core.work.workflow.model import WorkflowStatus
@ray.remote
def run_workflow_task(workflow_data: dict, trace_id: str):
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.core.work.workflow.model import WorkflowStatus
import datetime
from pydantic import BaseModel
# ─── State / Deps ─────────────────────────────────────────────────────────
# State passed through graph nodes
class WorkflowGraphState(BaseModel):
trace_id: str
blackboard: Dict[str, Any]
work_link: List[Dict[str, Any]]
current_step_index: int = 0
status: str = "running"
logs: List[Dict[str, Any]] = []
async def save_context(state: WorkflowGraphState):
postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.upsert_workflow_context.remote(
state.trace_id,
workflow_pointer=state.current_step_index,
blackboard=state.blackboard,
work_link=state.work_link,
workflow_status={str(datetime.datetime.now()): state.status},
workflow_log=state.logs,
)
await postgres_database.update_workflow_status.remote(
state.trace_id, state.status
)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.put_received.remote(
state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
class WorkflowGraphState(BaseModel):
    """State shared across nodes for the duration of one graph run."""

    trace_id: str
    # Step outputs, written by the step executor skeleton under each step's output key.
    blackboard: Dict[str, Any] = Field(default_factory=dict)
    # Ordered step dicts; `current_step_index` indexes into this list.
    work_link: List[Dict[str, Any]] = Field(default_factory=list)
    current_step_index: int = 0
    final_status: str = WorkflowStatus.RUNNING.value
    logs: List[Dict[str, Any]] = Field(default_factory=list)
    original_command: str = ""
    # Step indices for which HumanApproval has already sent put_pending, so a
    # resumed run does not push the same message again.
    # A list (not a set) keeps the state JSON-friendly when pydantic_graph
    # serializes the history.
    approvals_notified: List[int] = Field(default_factory=list)
# 业务侧执行入口:把 step + state 喂进去,拿到 (output_text, success_bool)
StepExecutor = Callable[
[Dict[str, Any], "WorkflowGraphState"], Awaitable[tuple[str, bool]]
]
@dataclass
class WorkflowDeps:
    """Runtime dependencies of the graph nodes.

    All external IO goes through these fields, which makes the nodes easy to
    test with mocks. Every field is an awaitable callable whose signature
    mirrors the corresponding ``.remote()`` call. In production the instance is
    assembled by ``build_default_deps`` from real actor handles; unit tests can
    pass any ``AsyncMock``.

    ``run_skill`` / ``run_consciousness`` are the entry points that dispatch a
    step to a concrete executor. Keeping them here lets the graph nodes stay
    pure logic and lets tests run without starting a real SkillIndividual.
    """

    upsert_workflow_context: Callable[..., Awaitable[Any]]
    update_workflow_status: Callable[[str, str], Awaitable[Any]]
    put_pending: Callable[[str, str], Awaitable[Any]]
    get_received: Callable[[str], Awaitable[str]]
    run_skill: StepExecutor
    run_consciousness: StepExecutor
# ─── 节点 ─────────────────────────────────────────────────────────────────
@dataclass
class Initialize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Graph entry node: mark the workflow RUNNING and persist the initial context.

    NOTE(review): the previous docstring also promised a first SSE message, but
    this node never calls ``put_pending``; confirm whether one is still wanted.
    """

    async def run(
        self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
    ) -> "Dispatch":
        await ctx.deps.update_workflow_status(
            ctx.state.trace_id, WorkflowStatus.RUNNING.value
        )
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()
async def execute_step(state: WorkflowGraphState):
"""执行单一工作流节点逻辑"""
@dataclass
class Dispatch(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""读取当前 step,按 ``node`` / ``require_approval`` 字段选择下一节点。"""
async def run(
self,
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
) -> "HumanApproval | SkillStep | ConsciousnessStep | Finalize":
state = ctx.state
if state.current_step_index >= len(state.work_link):
state.status = WorkflowStatus.COMPLETED
return state
return Finalize(status=WorkflowStatus.COMPLETED.value)
step = state.work_link[state.current_step_index]
step.get("node", "")
action = step.get("action", "")
step_data = state.work_link[state.current_step_index]
if step_data.get("require_approval"):
return HumanApproval()
# 记录开始状态
node_type = step_data.get("node") or "skill_individual"
if node_type == "consciousness_node":
return ConsciousnessStep()
if node_type == "skill_individual":
return SkillStep()
# 未识别的 node 类型按失败处理(保守:不静默吞)
state.logs.append(
{
str(state.current_step_index): [
str(datetime.datetime.now()),
"working",
f"开始执行: {step.get('name', '未命名步骤')}",
"failed",
f"未知节点类型: {node_type}",
]
}
)
await save_context(state)
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
return Finalize(status=WorkflowStatus.FAILED.value)
try:
# TODO: 实际对接不同节点执行逻辑 (例如: control_node, agent 技能)
# 这里是简化版,向控制节点或指定 skill 发送指令
# ... 模拟执行逻辑 ...
await asyncio.sleep(2)
@dataclass
class HumanApproval(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Human-approval gate: pause the graph until the user replies.

    Reply protocol (lightweight, reuses the existing SSE channel):

    - The reply is lower-cased and split into alphanumeric words.
    - It counts as approval when it contains an approval word (``approve``,
      ``approved``, ``yes``, ``ok``, ``okay``) and no rejection word
      (``reject``, ``rejected``, ``no``, ``not``, ``abort``, ``deny``,
      ``cancel``). Control then returns to ``Dispatch``.
    - Anything else is a rejection and the workflow ends as FAILED.

    Whole-word matching replaces the earlier substring test, which approved
    replies such as "do not approve" or "not ok".
    """

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        state = ctx.state
        step_data = state.work_link[state.current_step_index]

        # Idempotent push: send put_pending only if this step has not been
        # announced yet, so re-entering the node on resume sends no duplicate.
        if state.current_step_index not in state.approvals_notified:
            await ctx.deps.put_pending(
                state.trace_id,
                f"步骤 {state.current_step_index + 1} ({step_data.get('name', '')}) "
                f"需要人工审批,请回复 approve / reject。",
            )
            state.approvals_notified.append(state.current_step_index)
        # Record the HANGUP status before blocking on the reply.
        await _persist_context(ctx, status=WorkflowStatus.HANGUP.value)

        reply = (await ctx.deps.get_received(state.trace_id) or "").strip().lower()

        if _reply_is_approval(reply):
            # Clear the flag so Dispatch does not route back here for this step.
            step_data["require_approval"] = False
            await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
            return Dispatch()

        state.logs.append(
            {
                str(state.current_step_index): [
                    str(datetime.datetime.now()),
                    "rejected",
                    f"用户拒绝执行该步骤: {reply!r}",
                ]
            }
        )
        await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
        return Finalize(status=WorkflowStatus.FAILED.value)


def _reply_is_approval(reply: str) -> bool:
    """Return True when ``reply`` has an approval word and no rejection word."""
    approve_words = {"approve", "approved", "yes", "ok", "okay"}
    reject_words = {"reject", "rejected", "no", "not", "abort", "deny", "cancel"}
    # Replace punctuation with spaces so "ok." or "yes!" still split into words.
    normalized = "".join(ch if ch.isalnum() else " " for ch in reply.lower())
    words = set(normalized.split())
    return bool(words & approve_words) and not (words & reject_words)
state = WorkflowGraphState(
trace_id=trace_id,
blackboard={},
work_link=workflow_data.get("work_link", []),
)
await save_context(state)
# 简单的图执行驱动 (模拟 pydantic-ai.graph.run 行为,直至 Graph 库正式稳定)
while state.status == WorkflowStatus.RUNNING and state.current_step_index < len(
state.work_link
):
state = await execute_step(state)
@dataclass
class SkillStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""skill_individual 路径:当前进程内拉一个专家子个体执行该步。"""
await postgres_database.update_workflow_status.remote(trace_id, state.status)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
async def run(
self,
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
) -> "Dispatch | Finalize":
return await _execute_step(ctx, executor=ctx.deps.run_skill)
@dataclass
class ConsciousnessStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Node for ``consciousness_node`` steps.

    Runs the shared ``_execute_step`` skeleton with ``ctx.deps.run_consciousness``
    as the executor; the production executor calls the ConsciousnessNode actor
    remotely.
    """

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        return await _execute_step(ctx, executor=ctx.deps.run_consciousness)
@dataclass
class Finalize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""收尾节点:写最终状态、推送最终 SSE,把 workflow 终态作为 graph 输出。"""
status: str
async def run(
self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
) -> End[str]:
ctx.state.final_status = self.status
await ctx.deps.update_workflow_status(ctx.state.trace_id, self.status)
msg = (
"工作流执行完成!"
if state.status == WorkflowStatus.COMPLETED
if self.status == WorkflowStatus.COMPLETED.value
else "工作流执行失败。"
)
await global_workflow_manager.put_received.remote(trace_id, msg)
await ctx.deps.put_pending(ctx.state.trace_id, msg)
return End(self.status)
asyncio.run(_run())
# ─── 内部 helper ─────────────────────────────────────────────────────────
async def _persist_context(
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps], *, status: str
) -> None:
"""把当前 state 落库到 workflow_context 表(覆盖式写入)。"""
await ctx.deps.upsert_workflow_context(
ctx.state.trace_id,
workflow_pointer=ctx.state.current_step_index,
blackboard=ctx.state.blackboard,
work_link=ctx.state.work_link,
workflow_status={str(datetime.datetime.now()): status},
workflow_log=ctx.state.logs,
)
async def _execute_step(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    *,
    executor: StepExecutor,
) -> "Dispatch | Finalize":
    """Shared execution skeleton for ``SkillStep`` and ``ConsciousnessStep``.

    Handles what is common to both node kinds: logging, the progress message,
    the blackboard update and the ``logic_gate`` routing. How the step actually
    runs is left to ``executor``.

    Args:
        ctx: Run context holding the graph state and the deps.
        executor: Awaitable taking ``(step_data, state)`` and returning
            ``(output_text, success)``. An exception raised by it is treated as
            a failure whose output is the exception text.

    Returns:
        ``Dispatch`` to continue with the next (or jumped-to) step, or
        ``Finalize`` with COMPLETED / FAILED.

    Fixes over the previous version:
        - A step whose ``outputs`` is present but ``None`` no longer writes the
          result under a ``None`` blackboard key; the ``step_<i>_result``
          fallback is used.
        - An ``if_fail`` target that is not a string, has no numeric suffix, or
          points outside ``work_link`` is ignored, so the step ends as FAILED.
          Previously such a target could raise, index from the end of the list,
          or let a failed run finish as COMPLETED.
    """
    state = ctx.state
    step_data = state.work_link[state.current_step_index]

    state.logs.append(
        {
            str(state.current_step_index): [
                str(datetime.datetime.now()),
                "working",
                f"开始执行: {step_data.get('name', '未命名步骤')}",
            ]
        }
    )
    await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
    await ctx.deps.put_pending(
        state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
    )

    try:
        output_text, success = await executor(step_data, state)
    except Exception as e:  # an executor exception takes the failure branch
        output_text, success = str(e), False

    logic_gate = step_data.get("logic_gate") or {}

    if success:
        # `or` (not a .get default) so an explicit None falls back as well.
        output_key = (
            step_data.get("outputs") or f"step_{state.current_step_index}_result"
        )
        state.blackboard[output_key] = output_text
        state.logs[-1][str(state.current_step_index)] = [
            str(datetime.datetime.now()),
            "completed",
            f"成功: {step_data.get('action', '')}",
        ]
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)

        if logic_gate.get("if_pass") == "exit":
            return Finalize(status=WorkflowStatus.COMPLETED.value)

        state.current_step_index += 1
        if state.current_step_index >= len(state.work_link):
            return Finalize(status=WorkflowStatus.COMPLETED.value)
        return Dispatch()

    # Failure: a valid if_fail jump takes precedence over finishing as FAILED.
    state.logs[-1][str(state.current_step_index)] = [
        str(datetime.datetime.now()),
        "failed",
        output_text,
    ]
    target_step = _resolve_fail_jump(logic_gate.get("if_fail"), len(state.work_link))
    if target_step is not None:
        # NOTE(review): jumps are not counted here, so a step that keeps
        # failing and jumps to itself would loop; confirm a retry cap exists.
        state.current_step_index = target_step
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()

    await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
    return Finalize(status=WorkflowStatus.FAILED.value)


def _resolve_fail_jump(fail_target: Any, step_count: int) -> Optional[int]:
    """Turn a ``jump_to_step_<n>`` target into a 0-based index, or None if unusable."""
    if not isinstance(fail_target, str) or "jump_to_step_" not in fail_target:
        return None
    try:
        # Targets are 1-based in the workflow definition.
        target_step = int(fail_target.split("_")[-1]) - 1
    except ValueError:
        return None
    if 0 <= target_step < step_count:
        return target_step
    return None
# ─── 图定义 ───────────────────────────────────────────────────────────────
# Module-level graph instance shared by the run and resume entry points below.
# It registers every node class of the workflow and declares that a run ends
# with a plain string (the final workflow status).
workflow_graph: Graph[WorkflowGraphState, WorkflowDeps, str] = Graph(
    nodes=[Initialize, Dispatch, HumanApproval, SkillStep, ConsciousnessStep, Finalize],
    state_type=WorkflowGraphState,
    run_end_type=str,
)
# ─── 默认执行器:把 step 派发到 SkillIndividual / ConsciousnessNode ────
async def _default_skill_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``skill_individual`` steps.

    Looks up the agent configuration for the step's ``agent_id`` in the GSM
    snapshot, creates a ``SkillIndividual`` from it in the current process and
    runs the step with it.

    Returns:
        ``(output_text, True)`` when the individual ran; ``(message, False)``
        when ``agent_id`` is missing or has no matching configuration.
    """
    from kilostar.worker_individual.skill_individual import SkillIndividual
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    agent_id = step_data.get("agent_id")
    if not agent_id:
        return "skill_individual 步骤缺少 agent_id", False

    agent_config = (await fetch_snapshot()).individuals.get(agent_id)
    if not agent_config:
        return f"未找到 agent_id={agent_id} 的专家子个体", False

    outcome = await SkillIndividual(dict(agent_config)).run(
        {
            "step": step_data,
            "blackboard": state.blackboard,
            "original_command": state.original_command,
        }
    )

    if isinstance(outcome, dict):
        text = outcome.get("output", "")
    else:
        text = str(outcome)
    # An empty output is replaced by a placeholder.
    return text or "(empty)", True
async def _default_consciousness_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``consciousness_node`` steps.

    Validates the step dict into a ``WorkflowStep``, sends it with the original
    command to ``ConsciousnessNode.working`` and maps the reply to
    ``(output_text, success)``. Only a ``ForWorkflow`` reply counts as success.
    """
    from kilostar.utils.ray_hook import ray_actor_hook
    from kilostar.core.individual.consciousness_node.template import (
        ForWorkflow,
        ForWorkflowInput,
    )
    from kilostar.core.work.workflow.workflow import WorkflowStep

    node_handle = ray_actor_hook("consciousness_node").consciousness_node
    step_model = WorkflowStep.model_validate(step_data)
    reply = await node_handle.working.remote(
        ForWorkflowInput(
            workflow_step=step_model,
            original_command=state.original_command,
        )
    )

    if reply is None:
        return "ConsciousnessNode 返回 None", False
    if not isinstance(reply, ForWorkflow):
        return f"ConsciousnessNode 返回未知类型: {type(reply).__name__}", False
    return reply.output, True
def build_default_deps() -> WorkflowDeps:
    """Assemble the production ``WorkflowDeps`` from ray actor handles.

    Each actor method is wrapped in a small coroutine so the nodes can simply
    ``await`` it. Living in one function lets the ray task entry point and
    tests share the same wrapping logic.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    db = ray_actor_hook("postgres_database").postgres_database
    manager = ray_actor_hook("global_workflow_manager").global_workflow_manager

    async def _upsert_workflow_context(trace_id: str, **kwargs: Any) -> Any:
        return await db.upsert_workflow_context.remote(trace_id, **kwargs)

    async def _update_workflow_status(trace_id: str, status: str) -> Any:
        return await db.update_workflow_status.remote(trace_id, status)

    async def _put_pending(trace_id: str, message: str) -> Any:
        return await manager.put_pending.remote(trace_id, message)

    async def _get_received(trace_id: str) -> str:
        return await manager.get_received.remote(trace_id)

    return WorkflowDeps(
        upsert_workflow_context=_upsert_workflow_context,
        update_workflow_status=_update_workflow_status,
        put_pending=_put_pending,
        get_received=_get_received,
        run_skill=_default_skill_executor,
        run_consciousness=_default_consciousness_executor,
    )
async def run_workflow_graph(
    workflow_data: Dict[str, Any],
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Run the workflow graph once in the current event loop.

    Args:
        workflow_data: Dict produced by ``KiloStarWorkflow.model_dump()``.
        trace_id: Workflow trace id.
        deps: Built with ``build_default_deps`` when omitted (production path).
        persistence: A fresh ``FullStatePersistence`` when omitted; pass one in
            to share the same persistence with the caller.

    Returns:
        The final workflow status string produced by the graph.
    """
    deps = build_default_deps() if deps is None else deps
    persistence = FullStatePersistence() if persistence is None else persistence

    steps = workflow_data.get("work_link", []) or []
    metadata = workflow_data.get("workflow_metadata") or {}
    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=list(steps),
        original_command=metadata.get("command", "") or "",
    )

    run_result = await workflow_graph.run(
        Initialize(),
        state=initial_state,
        deps=deps,
        persistence=persistence,
    )
    return run_result.output
async def resume_workflow_graph(
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Resume a workflow from persistence and run the remaining nodes to ``End``.

    ``persistence`` must already be hydrated by the caller; this function only
    drives ``Graph.iter_from_persistence`` until an ``End`` node appears.

    Returns:
        The data of the ``End`` node, or the RUNNING status value if the
        iteration finishes without yielding one.

    Raises:
        ValueError: if ``persistence`` is not supplied.
    """
    if persistence is None:
        raise ValueError("resume 必须显式传入 persistence")
    if deps is None:
        deps = build_default_deps()

    outcome: str = WorkflowStatus.RUNNING.value
    async with workflow_graph.iter_from_persistence(persistence, deps=deps) as run:
        async for node in run:
            if not isinstance(node, End):
                continue
            outcome = node.data
            break
    return outcome
@ray.remote
class WorkflowRunningEngine:
def __init__(
self, consciousness_node=None, control_node=None, regulatory_node=None
):
self.consciousness_node = consciousness_node
self.control_node = control_node
self.regulatory_node = regulatory_node
self.events_queue = asyncio.Queue()
def run_workflow_task(workflow_data: dict, trace_id: str):
"""workflow 的 ray task 入口:一次性执行,跑完即销毁。
async def put_event(self, event):
await self.events_queue.put(event)
生产路径下持久化交给 ``PostgresStatePersistence`` —— 即便进程崩溃,再 fire
一次相同 ``trace_id`` 的任务(或调 ``/workflow/{trace_id}/resume``)即可
续跑。同时为了支持 fresh start,先尝试 ``hydrate``
- hydrate 拿到内容 → 走 resume 路径
- hydrate 没拿到 → 走全新路径
async def run(self):
"""引擎循环提取事件"""
while True:
await self.events_queue.get()
await asyncio.sleep(1)
ray task 是新进程,contextvars 不会从 caller 传过来,所以入口先 bind 一次
``trace_id``,让节点内的日志自动带上它。
"""
from kilostar.utils.request_context import trace_id_scope
from kilostar.core.work.workflow.graph_persistence import (
build_postgres_persistence,
)
async def execute_workflow(self, workflow: KiloStarWorkflow):
# 这个方法可以由意识节点调用来提交一个完整的运行任务
workflow_dict = workflow.model_dump()
trace_id = workflow.trace_id
run_workflow_task.remote(workflow_dict, trace_id)
async def _entry() -> None:
with trace_id_scope(trace_id):
persistence = build_postgres_persistence(trace_id)
persistence.set_graph_types(workflow_graph)
recovered = False
try:
recovered = await persistence.hydrate()
except Exception: # pragma: no cover - 防御
recovered = False
if recovered:
await resume_workflow_graph(trace_id, persistence=persistence)
else:
await run_workflow_graph(
workflow_data, trace_id, persistence=persistence
)
asyncio.run(_entry())
@@ -28,10 +28,10 @@ class ApprovalToolData(BaseToolData):
"regulatory_node",
"growth_node",
"",
"",
]
] = ["control_node", "consciousness_node"]
config_args: Dict[str, str] = {}
category: str = "system"
async def approval(message: str, trace_id: str) -> str:
+2 -1
View File
@@ -29,7 +29,8 @@ class BaseToolData(BaseModel):
"regulatory_node",
"growth_node",
"",
"",
]
] = []
config_args: Dict[str, str] = {}
category: str = "other"
"""工具分类:system(系统内置)、search(搜索)、mcp(MCP 服务器)、other(其他)"""
@@ -12,6 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .file_reader import FileReaderData, file_reader
from .file_reader import FileReaderToolData, file_reader
__all__ = ["FileReaderData", "file_reader"]
__all__ = ["FileReaderToolData", "file_reader"]
@@ -12,36 +12,45 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from pydantic_ai import RunContext
"""File Reader Tool Plugin for KiloStar.
Reads the contents of a file from the local filesystem.
"""
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
import os
from typing import List, Literal, Dict
class FileReaderData(BaseToolData):
"""``file_reader`` 工具的元数据:声明工具的名称、描述与是否系统级别"""
class FileReaderToolData(BaseToolData):
"""``file_reader`` 工具的元数据。"""
is_system: bool = True
name: str = "file_reader"
description: str = "读取本地文件的内容"
action_scope: List[
Literal[
"control_node",
"consciousness_node",
"regulatory_node",
"growth_node",
"",
]
] = ["control_node"]
config_args: Dict[str, str] = {}
category: str = "system"
def file_reader(ctx: RunContext, filepath: str) -> str:
"""读取本地文件内容的工具
async def file_reader(file_path: str) -> str:
"""读取本地文件内容。
Args:
filepath: 目标文件的绝对路径或相对路径
file_path: 文件的绝对路径或相对路径
Returns:
如果文件存在并可读,返回文件内容;否则返回错误信息
文件内容文本,若文件不存在则返回错误信息
"""
if not os.path.exists(filepath):
return f"Error: 文件 {filepath} 不存在。"
if not os.path.isfile(filepath):
return f"Error: {filepath} 不是一个文件。"
try:
with open(filepath, "r", encoding="utf-8") as f:
content = f.read()
return content
with open(file_path, "r", encoding="utf-8") as f:
return f.read()
except FileNotFoundError:
return f"[Error] File not found: {file_path}"
except Exception as e:
return f"Error: 读取文件失败,原因:{str(e)}"
return f"[Error] Failed to read file: {str(e)}"
@@ -0,0 +1,122 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tavily Web Search Tool Plugin for KiloStar.
Provides intelligent web search capabilities via Tavily API.
API key 取值优先级:调用参数 > GlobalStateMachine 中 ``tavily_search`` 工具配置 >
环境变量 ``TAVILY_API_KEY``。
"""
import os
from typing import List, Literal, Dict, Optional
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
from tavily import AsyncTavilyClient
class TavilySearchToolData(BaseToolData):
    """Metadata for the Tavily search tool.

    The default ``action_scope`` enables the tool for ``control_node``,
    ``consciousness_node`` and ``regulatory_node``. ``growth_node`` is allowed
    by the ``Literal`` but is not part of the default list.
    """

    is_system: bool = False
    action_scope: List[
        Literal[
            "control_node",
            "consciousness_node",
            "regulatory_node",
            "growth_node",
        ]
    ] = ["control_node", "consciousness_node", "regulatory_node"]
    config_args: Dict[str, str] = {
        "api_key": "",
        "max_results": "5",
        "search_depth": "basic",
        "include_answer": "true",
    }
    category: str = "search"
async def _resolve_api_key(explicit: Optional[str]) -> Optional[str]:
"""按优先级解析 Tavily API key:显式参数 > GSM 配置 > 环境变量。"""
if explicit:
return explicit
try:
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
# 工具调用是高频热路径,走 Object Store 快照而不是 actor RPC
snapshot = await fetch_snapshot()
cfg = snapshot.tool_configs.get("tavily_search") or {}
if isinstance(cfg, dict) and cfg.get("api_key"):
return cfg["api_key"]
except Exception:
pass
return os.environ.get("TAVILY_API_KEY")
async def tavily_search(
    query: str,
    max_results: int = 5,
    search_depth: str = "basic",
    include_answer: bool = True,
    api_key: Optional[str] = None,
) -> str:
    """Run a web search through Tavily and format the results as text.

    Args:
        query: Search query.
        max_results: Maximum number of results; clamped to the range 1-10.
        search_depth: ``"basic"`` or ``"advanced"``.
        include_answer: Whether to include the AI-generated answer summary.
        api_key: Optional key; when omitted it is resolved by ``_resolve_api_key``.

    Returns:
        Formatted text with the optional AI answer followed by title, URL and
        a snippet (at most 300 characters) per result. Errors are returned as
        an ``[Error] ...`` string rather than raised.

    Fixes over the previous version:
        - ``max_results`` is clamped at the lower bound too, so 0 or a negative
          value is no longer sent to the API.
        - When there are no results but an AI answer exists, the answer is
          returned instead of being dropped.
        - A result whose ``content`` is ``None`` no longer fails the whole search.
    """
    resolved_key = await _resolve_api_key(api_key)
    if not resolved_key:
        return (
            "[Error] Tavily API key 未配置。"
            "请在 ``/api/v1/resource/tool/config`` 写入或设置环境变量 ``TAVILY_API_KEY``。"
        )

    try:
        client = AsyncTavilyClient(api_key=resolved_key)
        result = await client.search(
            query=query,
            max_results=max(1, min(max_results, 10)),
            search_depth=search_depth,
            include_answer=include_answer,
        )

        lines = []
        if include_answer and result.get("answer"):
            lines.append(f"【AI 摘要】{result['answer']}\n")

        results = result.get("results") or []
        if not results:
            if lines:
                # Keep the AI answer even though no individual results came back.
                return "\n".join(lines).rstrip()
            return "No results found for the query."

        lines.append("【搜索结果】")
        for i, item in enumerate(results, 1):
            title = item.get("title", "Untitled")
            url = item.get("url", "")
            # `or ""` covers a present-but-None content field.
            content = (item.get("content") or "").strip()
            lines.append(f"\n{i}. {title}")
            lines.append(f"   URL: {url}")
            if content:
                lines.append(f"   {content[:300]}{'...' if len(content) > 300 else ''}")

        return "\n".join(lines)
    except Exception as e:
        return f"[Error] Tavily search failed: {str(e)}"
+9 -3
View File
@@ -12,10 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Annotated
from fastapi import Depends, HTTPException
from fastapi import Depends, HTTPException, Request
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.postgres_database.model import UserAuthority
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.i18n import t
def _user_not_found_detail(request: Request | None = None) -> str:
    """Return the ``user_not_found`` message for the request's language.

    The ``accept-language`` header is passed to ``t`` when a request is given;
    otherwise ``None`` is passed.
    """
    if request:
        accept_language = request.headers.get("accept-language")
    else:
        accept_language = None
    return t("user_not_found", accept_language=accept_language)
async def get_authority(user_id: str) -> UserAuthority:
@@ -29,12 +35,12 @@ async def get_authority(user_id: str) -> UserAuthority:
)
return user_authority
except UserNotExistError:
raise HTTPException(status_code=401, detail="用户不存在或已被删除,请重新登录")
raise HTTPException(status_code=401, detail=t("user_not_found"))
except Exception as e:
# Check if it's a RayTaskError wrapping UserNotExistError
if "UserNotExistError" in str(e):
raise HTTPException(
status_code=401, detail="用户不存在或已被删除,请重新登录"
status_code=401, detail=t("user_not_found")
)
raise
+87
View File
@@ -0,0 +1,87 @@
import os
from functools import lru_cache
from cryptography.fernet import Fernet, InvalidToken
from kilostar.utils.logger import get_logger
logger = get_logger("crypto")
_VERSION_PREFIX = "v1:"
_SENSITIVE_KEYS = {"key", "token", "secret", "password", "apikey", "api_key"}
class CryptoError(Exception):
    """Raised by this module when the secret key is missing or invalid, or decryption fails."""

    pass
@lru_cache(maxsize=1)
def _get_fernet() -> Fernet:
    """Build the ``Fernet`` instance from ``KILOSTAR_SECRET_KEY`` and cache it.

    Raises:
        CryptoError: if the environment variable is unset or empty, or if
            ``Fernet`` rejects its value.
    """
    raw = os.environ.get("KILOSTAR_SECRET_KEY", "")
    if raw:
        try:
            key_material = raw.encode() if isinstance(raw, str) else raw
            return Fernet(key_material)
        except Exception as e:
            raise CryptoError(f"KILOSTAR_SECRET_KEY 格式无效: {e}") from e
    raise CryptoError(
        "环境变量 KILOSTAR_SECRET_KEY 未设置,无法进行加解密。"
        "请生成一个密钥:python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
    )
def encrypt_secret(plaintext: str) -> str:
    """Encrypt ``plaintext`` and return it with the version prefix prepended.

    Empty or falsy input is returned unchanged without touching the key.
    """
    if plaintext:
        token = _get_fernet().encrypt(plaintext.encode("utf-8"))
        return _VERSION_PREFIX + token.decode("utf-8")
    return plaintext
def decrypt_secret(ciphertext: str) -> str:
    """Decrypt a value produced by ``encrypt_secret``.

    Empty input and values without the version prefix are returned unchanged.

    Raises:
        CryptoError: if the token cannot be decrypted with the current key.
    """
    if not ciphertext or not ciphertext.startswith(_VERSION_PREFIX):
        return ciphertext

    token = ciphertext[len(_VERSION_PREFIX):]
    cipher = _get_fernet()
    try:
        plaintext = cipher.decrypt(token.encode("utf-8"))
    except InvalidToken as e:
        raise CryptoError("解密失败:密钥不匹配或密文已损坏") from e
    return plaintext.decode("utf-8")
def is_encrypted(value: str) -> bool:
    """Return True when ``value`` is a string starting with the version prefix."""
    if not isinstance(value, str):
        return False
    return value.startswith(_VERSION_PREFIX)
def _is_sensitive_key(key: str) -> bool:
lower = key.lower()
return any(s in lower for s in _SENSITIVE_KEYS)
def encrypt_dict_secrets(data: dict) -> dict:
    """Return a copy of ``data`` with sensitive string values encrypted.

    A value is encrypted only when its key is sensitive and the value is a
    non-empty string that is not already encrypted. Only top-level entries are
    inspected. Non-dict input is returned as is.
    """
    if not isinstance(data, dict):
        return data
    return {
        k: (
            encrypt_secret(v)
            if _is_sensitive_key(k) and isinstance(v, str) and v and not is_encrypted(v)
            else v
        )
        for k, v in data.items()
    }
def decrypt_dict_secrets(data: dict) -> dict:
    """Return a copy of ``data`` with encrypted sensitive values decrypted.

    A value is decrypted only when its key is sensitive and the value is an
    encrypted string. If decryption fails the error is logged and the original
    value is kept. Non-dict input is returned as is.
    """
    if not isinstance(data, dict):
        return data

    out: dict = {}
    for k, v in data.items():
        # Default to the stored value; replaced below only on successful decryption.
        out[k] = v
        if not (_is_sensitive_key(k) and isinstance(v, str) and is_encrypted(v)):
            continue
        try:
            out[k] = decrypt_secret(v)
        except CryptoError as e:
            logger.error(f"字段 {k} 解密失败: {e}")
    return out
+92 -27
View File
@@ -12,68 +12,133 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""KiloStar 统一异常体系。
class RetryableError(Exception):
"""基类:所有可重试错误(如网络断开、抖动等临时性故障)。"""
设计原则:所有自定义异常归到两条主轴下。
pass
- ``BusinessError``:业务可预期错误,HTTP 层映射 4xx;前端可读、可展示给用户。
- ``InfraError``:系统/基础设施失败错误,HTTP 层映射 5xx;通常需要落日志告警。
其下再细分为 ``RetryableError``(瞬时故障,可由 ``retry_on_retryable_error`` 自动重试)
与 ``NonRetryableError``(确定性失败,重试无意义)。
注意:用 ``InfraError`` 而非 ``SystemError`` 是为了避免与 Python 内置的
``SystemError`` 冲突。
每个异常类都带 ``http_status`` 与 ``code`` 类属性,``api/__init__.py`` 的统一
handler 根据它们直接生成结构化响应,避免业务代码里硬编码状态码。
"""
from __future__ import annotations
class NonRetryableError(Exception):
"""基类:所有不可重试错误(如数据验证失败、类型错误等业务逻辑故障)"""
class KiloStarError(Exception):
"""KiloStar 所有自定义异常的总根"""
pass
http_status: int = 500
code: str = "kilostar_error"
class DemandError(NonRetryableError):
# ─── 主轴 1:业务可预期错误(4xx) ───────────────────────────────────────────
class BusinessError(KiloStarError):
"""业务层可预期错误的基类,HTTP 层默认 400。"""
http_status = 400
code = "business_error"
class DemandError(BusinessError):
"""需求/任务参数不合法或不满足前置条件时抛出。"""
pass
http_status = 400
code = "demand_error"
class ModelNotExistError(Exception):
"""请求了一个未在 Provider 中注册的模型 ID 时抛出。"""
pass
# 用户域 ─────────────────────────────────────────
class UserError(Exception):
"""用户相关错误的基类HTTP 层会被统一映射为 4xx"""
class UserError(BusinessError):
"""用户错误的基类。"""
pass
http_status = 400
code = "user_error"
class UserNotExistError(UserError):
"""按用户名/ID 查询时用户不存在。"""
pass
http_status = 404
code = "user_not_exist"
class UserPasswordError(UserError):
"""口令校验失败(旧密码错误、登录密码错误等)。"""
pass
http_status = 401
code = "user_password_error"
class ProviderError(Exception):
"""模型 Provider 相关错误的基类。"""
# Provider 域 ─────────────────────────────────────
pass
class ProviderError(BusinessError):
"""模型 Provider 域错误的基类。"""
http_status = 400
code = "provider_error"
class ProviderNotExistError(ProviderError):
"""请求了一个未注册的 Provider 时抛出。"""
pass
http_status = 404
code = "provider_not_exist"
class WorkflowError(Exception):
"""工作流执行期错误的基类,HTTP 层会被统一映射为 5xx"""
class ModelNotExistError(BusinessError):
"""请求了一个未在 Provider 中注册的模型 ID 时抛出"""
pass
http_status = 404
code = "model_not_exist"
class WorkflowExit(WorkflowError):
"""工作流被显式终止(用户取消、上游决策跳出等)时抛出,是预期内的退出信号。"""
# Workflow 域 ─────────────────────────────────────
pass
class WorkflowExit(BusinessError):
"""工作流被显式终止(用户取消、上游决策跳出等),是预期内的退出信号。"""
http_status = 400
code = "workflow_exit"
# ─── 主轴 2:系统/基础设施失败错误(5xx) ────────────────────────────────────
class InfraError(KiloStarError):
"""系统/基础设施失败错误的基类,HTTP 层默认 500。"""
http_status = 500
code = "infra_error"
class RetryableError(InfraError):
"""瞬时故障(如网络抖动),可由 ``retry_on_retryable_error`` 自动重试。"""
http_status = 503
code = "retryable_error"
class NonRetryableError(InfraError):
"""确定性的系统失败,重试无意义。"""
http_status = 500
code = "non_retryable_error"
class WorkflowError(InfraError):
"""工作流执行期错误的基类,HTTP 层映射为 5xx。"""
http_status = 500
code = "workflow_error"
+4 -2
View File
@@ -33,9 +33,11 @@ def _get_tool_func(tool_name: str) -> Callable | None:
if func:
return func
app_root = "/app"
tool_plugin_dir = os.path.join(
app_root, "kilostar", "plugin", "tool_plugin", tool_name
os.path.dirname(os.path.dirname(os.path.abspath(__file__))),
"plugin",
"tool_plugin",
tool_name,
)
if not os.path.exists(tool_plugin_dir) or not os.path.isdir(tool_plugin_dir):
+183
View File
@@ -0,0 +1,183 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""KiloStar 轻量级国际化工具。
设计原则:
- 纯内存字典,无文件 IO,Ray 远程序列化零成本。
- 支持环境变量 ``KILOSTAR_LANG`` 作为全局默认语言。
- Agent system prompt 按 ``{locale}`` 分桶,调用方显式传入 locale。
- API 层通过请求头 ``Accept-Language`` 解析首选语言。
当前支持:``zh`` (简体中文), ``en`` (English)。
"""
from __future__ import annotations
import os
from typing import Dict
_DEFAULT_LOCALE: str = os.getenv("KILOSTAR_LANG", "zh")
# ─── Agent System Prompts ──────────────────────────────────────────────────
_PROMPTS: Dict[str, Dict[str, str]] = {
"regulatory_node": {
"zh": (
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
"你是系统的'前台接待''大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
"请保持冷静、专业,并严格遵循上述路由规则。"
),
"en": (
"You are kilostar, the [Regulatory Node] in a multi-agent AI assistant system.\n"
"You are the system's 'front desk' and 'cerebral cortex', responsible for receiving user requests and final workflow reports.\n"
"Your core duty is [intent recognition and routing]. Please read the user's request carefully:\n"
"1. If the user is simply greeting, chatting, or asking very basic questions, generate a friendly reply directly in the ForUser format.\n"
"2. If the user presents a complex task (e.g., writing code, multi-step planning, data processing), you must classify it as a workflow-requiring task "
" and hand it over to the Consciousness Node using the ForConsciousnessNode format.\n"
"3. If you receive a TerminationMessage (indicating the workflow is complete and a report has been generated), convert the report into a user-friendly reply in the ForUser format.\n"
"Please remain calm, professional, and strictly follow the routing rules above."
),
},
"consciousness_node": {
"zh": (
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
"你是系统的'高级规划师''架构师',负责处理监控节点分配过来的复杂任务。\n"
"你的主要工作场景包括:\n"
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
),
"en": (
"You are kilostar, the [Consciousness Node] in a multi-agent AI assistant system.\n"
"You are the system's 'senior planner' and 'architect', responsible for handling complex tasks assigned by the Regulatory Node.\n"
"Your main scenarios include:\n"
"1. Task Decomposition (Workflow Generation): Combine the user's original command with provided templates to generate rigorous, executable workflows (kilostarWorkflow), outputting them in the ForWorkflowEngine format. Steps should be clear and coherent.\n"
"2. Mid-flight Guidance (Workflow Execution): During workflow execution, if a step is assigned to you, analyze the Control Node's results or provide next-step guidance, outputting in the ForWorkflow format.\n"
"3. Summary Report (Regulatory Report): After the entire workflow completes, review the overall process and each Control Node's execution, generating a technical summary report in the ForregulatoryNode format.\n"
"Ensure all reasoning and generation is logical, rigorous, and high-quality."
),
},
"control_node": {
"zh": (
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
"你是系统的'执行者''车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
"你的工作职责是:\n"
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
),
"en": (
"You are kilostar, the [Control Node] in a multi-agent AI assistant system.\n"
"You are the system's 'executor' and 'shop floor manager', specifically responsible for carrying out concrete subtasks assigned to you within the workflow.\n"
"Your duties are:\n"
"1. Carefully analyze the objectives and requirements of the workflow_step assigned to you.\n"
"2. Use the tools assigned to you (if any) or rely on your own knowledge and reasoning to complete the task accurately and efficiently.\n"
"3. Return the execution results, generated data, or concrete outputs strictly in the ForWorkflow format.\n"
"Note: Your output should be specific, practical, and directly provide the results requested by the task. Avoid excessive irrelevant pleasantries."
),
},
}
# ─── API / 通用消息 ────────────────────────────────────────────────────────
# Message catalogue. The OUTER key is the message id and the INNER key is the
# locale code ("zh" / "en"); a lookup therefore has to select the message id
# first and the locale second. Values may contain ``str.format`` placeholders
# such as ``{provider_title}``.
# NOTE(review): the "zh" texts of "api_not_found" and "frontend_not_found" are
# English — confirm whether that is intentional or a translation is missing.
_MESSAGES: Dict[str, Dict[str, str]] = {
    "internal_error": {
        "zh": "服务内部错误,请稍后重试",
        "en": "Internal server error, please try again later.",
    },
    "user_not_found": {
        "zh": "用户不存在或已被删除,请重新登录",
        "en": "User does not exist or has been deleted. Please log in again.",
    },
    "provider_not_registered": {
        "zh": "Provider {provider_title} 未注册",
        "en": "Provider {provider_title} is not registered.",
    },
    "model_not_exist": {
        "zh": "模型不存在",
        "en": "Model does not exist.",
    },
    "api_not_found": {
        "zh": "API endpoint not found",
        "en": "API endpoint not found",
    },
    "frontend_not_found": {
        "zh": "Frontend build not found",
        "en": "Frontend build not found",
    },
}
# ─── 工具函数 ──────────────────────────────────────────────────────────────
def _resolve_locale(locale: str | None = None, accept_language: str | None = None) -> str:
    """Pick the locale code to use for a lookup.

    Precedence: explicit ``locale`` argument, then the ``Accept-Language``
    header, then the module-level default locale.

    Args:
        locale: Explicit locale code. Only the exact values ``"zh"`` and
            ``"en"`` are honoured; any other non-empty value resolves to the
            module default (the header is not consulted in that case).
        accept_language: Raw ``Accept-Language`` header value. Only the first
            comma-separated entry is inspected, with any ``;q=...`` suffix
            removed.

    Returns:
        The resolved locale code.
    """
    if locale:
        if locale in ("zh", "en"):
            return locale
        return _DEFAULT_LOCALE

    if not accept_language:
        return _DEFAULT_LOCALE

    # Keep only the first language range and strip its quality parameter.
    leading_entry = accept_language.partition(",")[0]
    tag = leading_entry.partition(";")[0].strip().lower()

    # Substring match, "zh" checked before "en".
    for code in ("zh", "en"):
        if code in tag:
            return code
    return _DEFAULT_LOCALE
def t(key: str, locale: str | None = None, accept_language: str | None = None, **kwargs) -> str:
    """Translate a catalogue message into the resolved locale.

    Args:
        key: Message id, e.g. ``internal_error``.
        locale: Explicit locale code (``zh`` / ``en``).
        accept_language: Raw ``Accept-Language`` header sent by the client.
        **kwargs: Values interpolated into the message via ``str.format``.

    Returns:
        The translated text. Falls back to the default-locale text when the
        resolved locale has no entry, and to ``key`` itself when the message
        id is unknown.

    Raises:
        KeyError: If ``kwargs`` is non-empty but lacks a placeholder that the
            selected text requires (raised by ``str.format``).
    """
    loc = _resolve_locale(locale, accept_language)
    # The catalogue is keyed by message id first and locale second, so the
    # message entry must be selected before the locale. Indexing by locale
    # first never matches and would always yield the bare key.
    variants = _MESSAGES.get(key, {})
    text = variants.get(loc) or variants.get(_DEFAULT_LOCALE) or key
    return text.format(**kwargs) if kwargs else text
def agent_prompt(agent_name: str, locale: str | None = None, accept_language: str | None = None) -> str:
    """Return an agent's system prompt with a reply-language instruction appended.

    Args:
        agent_name: Prompt id, one of ``regulatory_node`` /
            ``consciousness_node`` / ``control_node``.
        locale: Explicit locale code.
        accept_language: Raw ``Accept-Language`` header value.

    Returns:
        The prompt text for the resolved locale (or the default-locale text
        when that is missing, or an empty string when neither exists),
        followed by the language instruction for the resolved locale.
    """
    loc = _resolve_locale(locale, accept_language)

    per_locale = _PROMPTS.get(agent_name, {})
    prompt = per_locale.get(loc) or per_locale.get(_DEFAULT_LOCALE, "")

    # Trailing instruction that pins the language the model answers in.
    if loc == "zh":
        suffix = "\n\n【重要】请始终使用简体中文进行思考和回复。"
    elif loc == "en":
        suffix = "\n\n[Important] Please always think and reply in English."
    else:
        suffix = ""
    return prompt + suffix
+72 -9
View File
@@ -12,24 +12,83 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import sys
from loguru import logger
from rich.logging import RichHandler
from loguru._logger import Logger
from kilostar.utils.request_context import get_request_id, get_trace_id
def _is_json_mode() -> bool:
"""根据环境变量决定是否启用 JSON 结构化日志。
支持开关:``KILOSTAR_LOG_FORMAT=json`` 或 ``KILOSTAR_LOG_JSON=1/true``。
"""
fmt = os.environ.get("KILOSTAR_LOG_FORMAT", "").lower()
if fmt == "json":
return True
flag = os.environ.get("KILOSTAR_LOG_JSON", "").lower()
return flag in {"1", "true", "yes", "on"}
def _ctx_patcher(record):
    """Fill ``trace_id`` / ``request_id`` on a log record from the context.

    A value already present in ``record["extra"]`` (for example one set
    through ``bind``) is kept; only a missing or empty field is replaced by
    the value returned from the matching context getter.
    """
    extra = record["extra"]
    context_getters = (
        ("trace_id", get_trace_id),
        ("request_id", get_request_id),
    )
    for field, getter in context_getters:
        if not extra.get(field):
            extra[field] = getter()
def setup_logger() -> Logger:
"""初始化全局 loguru logger,输出格式为 ``actor:(...) | trace_id:(...) : message``。"""
"""初始化全局 loguru logger
- 默认(开发模式):``RichHandler`` 彩色输出,格式 ``actor:(...) | request_id:(...) | trace_id:(...) : message``
- JSON 模式(``KILOSTAR_LOG_FORMAT=json``):写到 stdout,每行一条 JSON,便于 ELK/Loki 采集
request_id / trace_id 来自 ``kilostar.utils.request_context``,由 FastAPI middleware
或工作流入口绑定到 contextvars,本模块通过 ``patcher`` 透明注入。
"""
logger.remove()
log_level = os.environ.get("KILOSTAR_LOG_LEVEL", "DEBUG").upper()
if _is_json_mode():
logger.configure(
extra={"actor_name": "System", "trace_id": "", "request_id": ""},
patcher=_ctx_patcher,
)
logger.add(
sys.stdout,
serialize=True,
level=log_level,
enqueue=True,
)
return logger
def format_record(record):
# Format string for rich handler
actor = record["extra"].get("actor_name", "System")
trace_id = record["extra"].get("trace_id", "")
request_id = record["extra"].get("request_id", "")
ids = []
if request_id:
ids.append(f"request_id:({request_id})")
if trace_id:
ids.append(f"trace_id:({trace_id})")
ids_str = " | " + " | ".join(ids) if ids else ""
return f"actor:({actor}){ids_str} : {record['message']}"
trace_str = f" | trace_id:({trace_id})" if trace_id else ""
return f"actor:({actor}){trace_str} : {record['message']}"
logger.configure(extra={"actor_name": "System", "trace_id": ""})
logger.configure(
extra={"actor_name": "System", "trace_id": "", "request_id": ""},
patcher=_ctx_patcher,
)
logger.add(
RichHandler(
@@ -40,8 +99,8 @@ def setup_logger() -> Logger:
show_path=False,
),
format=format_record,
level="DEBUG",
enqueue=True, # 异步记录
level=log_level,
enqueue=True,
)
return logger
@@ -51,5 +110,9 @@ global_logger = setup_logger()
def get_logger(actor_name: str, trace_id: str = "") -> Logger:
    """Return a logger bound to ``actor_name`` and an optional ``trace_id``.

    Binding groups log lines per actor / request. An empty ``trace_id`` is
    bound as an empty string, which leaves room for the logger's record
    patcher to supply the current context value; a non-empty value passed
    here takes precedence over that injection.
    """
    bound_fields = {"actor_name": actor_name, "trace_id": trace_id}
    return global_logger.bind(**bound_fields)
+180
View File
@@ -0,0 +1,180 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""MCP 辅助模块:根据全局状态机中的配置创建 pydantic-ai MCPServer 实例。"""
from typing import Dict, List, Any, Sequence
from kilostar.utils.logger import get_logger
logger = get_logger("mcp_helper")
# Optional dependency: import pydantic_ai.mcp defensively so that this module
# can still be imported when the MCP package is not installed.
try:
    from pydantic_ai.mcp import (
        MCPServerStdio,
        MCPServerSSE,
        MCPServerHTTP,
    )

    # Checked by the helpers in this module before any MCP class is used.
    _MCP_AVAILABLE = True
except ImportError:
    # NOTE(review): an ImportError for ANY of the three names above disables
    # MCP support entirely. Presumably all three are exported by the pinned
    # pydantic-ai version — confirm, in particular for MCPServerHTTP.
    _MCP_AVAILABLE = False
    logger.warning("MCP package not installed. MCP servers will not be available.")
def build_mcp_toolsets(configs: Dict[str, Dict[str, Any]]) -> List[Any]:
    """Instantiate one MCP server object per configuration entry.

    Args:
        configs: Mapping ``{server_id: {"name": ..., "transport": ..., ...}}``.
            ``transport`` defaults to ``"stdio"``; ``"sse"`` and ``"http"``
            are the other accepted values.

    Returns:
        The successfully built server objects, in ``configs`` order. Entries
        with an unsupported transport are skipped with a warning; entries
        whose construction raises are skipped with an error log. An empty
        list is returned when MCP support is unavailable.
    """
    if not _MCP_AVAILABLE:
        return []

    built: List[Any] = []
    for server_id, cfg in configs.items():
        try:
            transport = cfg.get("transport", "stdio")
            name = cfg.get("name", server_id)
            # Keyword arguments shared by every transport.
            shared = {"tool_prefix": cfg.get("tool_prefix"), "id": server_id}

            if transport == "stdio":
                server = MCPServerStdio(
                    command=cfg.get("command", ""),
                    args=cfg.get("args", []),
                    env=cfg.get("env"),
                    **shared,
                )
            elif transport in ("sse", "http"):
                # Both URL-based transports take the same arguments.
                url_server_cls = MCPServerSSE if transport == "sse" else MCPServerHTTP
                server = url_server_cls(url=cfg.get("url", ""), **shared)
            else:
                logger.warning(f"Unsupported MCP transport: {transport} for server {name}")
                continue

            built.append(server)
            logger.info(f"MCP server '{name}' ({transport}) registered as toolset")
        except Exception as e:
            # Best effort: one broken entry must not prevent the others.
            logger.error(f"Failed to build MCP server '{server_id}': {e}")

    return built
async def get_mcp_toolsets_from_gsm() -> List[Any]:
    """Build MCP toolsets from the ``mcp_servers`` section of the GSM snapshot.

    Returns:
        The built toolsets, or an empty list when MCP support is unavailable
        or when fetching / building fails (the failure is logged).
    """
    if _MCP_AVAILABLE:
        try:
            from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

            # Read through the snapshot helper rather than a direct actor call.
            snapshot = await fetch_snapshot()
            return build_mcp_toolsets(snapshot.mcp_servers)
        except Exception as e:
            logger.error(f"Failed to load MCP configs from GSM: {e}")
    return []
async def get_all_toolsets_for_scope(scope: str) -> List[Any]:
    """Collect every toolset available to ``scope``: local ones, then MCP.

    The order is stable: toolsets rebuilt locally from the snapshot come
    first, MCP toolsets are appended after them. A failure while loading the
    local toolsets is logged and does not prevent the MCP toolsets from being
    returned.

    Args:
        scope: Scope identifier forwarded to ``build_toolsets_for_scope``.

    Returns:
        The combined list of toolsets.
    """
    collected: List[Any] = []

    try:
        from kilostar.core.global_state_machine.gsm_snapshot import (
            build_toolsets_for_scope,
            fetch_snapshot,
        )

        # A single snapshot is the source for all local toolsets of the scope.
        local = build_toolsets_for_scope(await fetch_snapshot(), scope)
        if local:
            collected.extend(local)
    except Exception as e:
        logger.error(f"Failed to load local toolsets from GSM ({scope}): {e}")

    mcp_toolsets = await get_mcp_toolsets_from_gsm()
    collected.extend(mcp_toolsets)
    return collected
async def list_mcp_tools_for_configs(
    configs: Dict[str, Dict[str, Any]],
) -> List[Dict[str, Any]]:
    """Connect to each configured MCP server in turn and list its tool names.

    Each server is entered with ``async with`` and queried once. A failing
    server does not affect the others: its entry keeps ``tools=[]`` and gains
    an ``error`` field with the exception text.

    Args:
        configs: Mapping ``{server_id: config_dict}`` as accepted by
            ``build_mcp_toolsets``.

    Returns:
        One dict per built server with the keys ``server_id``, ``name``,
        ``transport``, ``tool_prefix``, ``tools`` and, on failure, ``error``.
        Empty when MCP support is unavailable.
    """
    if not _MCP_AVAILABLE:
        return []

    report: List[Dict[str, Any]] = []
    for server in build_mcp_toolsets(configs):
        server_id = getattr(server, "id", None)
        cfg = configs.get(server_id, {}) if server_id else {}
        name = cfg.get("name", server_id or "unknown")
        transport = cfg.get("transport", "stdio")
        entry: Dict[str, Any] = {
            "server_id": server_id,
            "name": name,
            "transport": transport,
            "tool_prefix": cfg.get("tool_prefix"),
            "tools": [],
        }
        try:
            async with server:
                # NOTE(review): get_tools() is called without arguments here.
                # Presumably the pinned pydantic-ai version allows that; if its
                # toolset API requires a run context, this raises and every
                # server is reported with an error — confirm.
                names = []
                for tool in await server.get_tools():
                    names.append(
                        getattr(tool, "name", None) or getattr(tool, "tool_name", str(tool))
                    )
                entry["tools"] = names
        except Exception as e:
            entry["error"] = str(e)
            logger.warning(f"MCP server '{name}' list_tools failed: {e}")
        report.append(entry)
    return report
async def list_mcp_tools_from_gsm() -> List[Dict[str, Any]]:
    """List MCP tools for the servers configured in the GSM snapshot.

    Returns:
        The result of :func:`list_mcp_tools_for_configs` for the snapshot's
        ``mcp_servers``; an empty list when MCP support is unavailable or the
        lookup fails (the failure is logged).
    """
    if _MCP_AVAILABLE:
        try:
            from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

            snapshot = await fetch_snapshot()
            listing = await list_mcp_tools_for_configs(snapshot.mcp_servers)
            return listing
        except Exception as e:
            logger.error(f"Failed to list MCP tools from GSM: {e}")
    return []
+130
View File
@@ -0,0 +1,130 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""请求/工作流上下文:基于 ``contextvars`` 的双层 ID 传播。
设计上把"一次用户请求""一次重型工作流"区分开:
- ``request_id``:会话域。所有进 API 的请求都要带,由 middleware 在入口生成或
从 ``X-Request-Id`` 头继承。chat 这条同步链路靠它走完一生。
- ``trace_id``:工作流域。只有 ``ConsciousnessNode`` 决定启动重型任务时才生成,
挂到 ``KiloStarWorkflow`` 上。trace_id 应能追溯回触发它的 request_id(前者
通过显式参数传入,后者从 contextvars 读取)。
为什么用 ``contextvars`` 而不是参数透传:
1. ``contextvars`` 在 ``asyncio`` 协程间天然继承,不会跨协程串味;
2. ``loguru`` 的 ``patcher`` 钩子可以把它变成日志切面,业务代码不需要在每条
``logger.info`` 上手动 ``.bind(trace_id=...)``
3. Ray 跨进程调用时 contextvars 不会自动传播 —— 这是有意为之,避免不同 actor
间的上下文意外串联。跨 actor 边界要走显式参数,由接收方再 ``bind_*`` 一次。
"""
from __future__ import annotations
import uuid
from contextlib import contextmanager
from contextvars import ContextVar, Token
from typing import Iterator, Optional
# Context-local holders for the two IDs; both default to an empty string so
# that reading an unbound ID never raises LookupError.
_request_id_var: ContextVar[str] = ContextVar("kilostar_request_id", default="")
_trace_id_var: ContextVar[str] = ContextVar("kilostar_trace_id", default="")
def get_request_id() -> str:
    """Return the ``request_id`` of the current context ("" when unbound)."""
    current = _request_id_var.get()
    return current
def get_trace_id() -> str:
    """Return the ``trace_id`` of the current context ("" when unbound)."""
    current = _trace_id_var.get()
    return current
def bind_request_id(request_id: str) -> Token:
    """Bind ``request_id`` to the current context.

    Returns:
        The ``Token`` produced by ``ContextVar.set``; pass it to
        ``reset_request_id`` to restore the previous value. The token must be
        used in the same context that created it, otherwise ``reset`` raises
        ``ValueError``. Prefer ``request_id_scope`` for scoped binding.
    """
    restore_token = _request_id_var.set(request_id)
    return restore_token
def bind_trace_id(trace_id: str) -> Token:
    """Bind ``trace_id`` to the current context.

    Returns:
        The ``Token`` produced by ``ContextVar.set``, usable with
        ``reset_trace_id`` to restore the previous value.
    """
    restore_token = _trace_id_var.set(trace_id)
    return restore_token
def reset_request_id(token: Token) -> None:
    """Restore ``request_id`` to the value it had before ``token`` was created."""
    _request_id_var.reset(token)
    return None
def reset_trace_id(token: Token) -> None:
    """Restore ``trace_id`` to the value it had before ``token`` was created."""
    _trace_id_var.reset(token)
    return None
@contextmanager
def request_id_scope(request_id: str) -> Iterator[str]:
    """Bind ``request_id`` for the duration of a ``with`` block.

    Yields:
        The bound ``request_id``. The previous value is restored on exit,
        including when the block raises.
    """
    restore_token = _request_id_var.set(request_id)
    try:
        yield request_id
    finally:
        _request_id_var.reset(restore_token)
@contextmanager
def trace_id_scope(trace_id: str) -> Iterator[str]:
    """Bind ``trace_id`` for the duration of a ``with`` block.

    Yields:
        The bound ``trace_id``. The previous value is restored on exit,
        including when the block raises.
    """
    restore_token = _trace_id_var.set(trace_id)
    try:
        yield trace_id
    finally:
        _trace_id_var.reset(restore_token)
def new_request_id(prefix: str = "req") -> str:
    """Generate a fresh request id of the form ``<prefix>-<uuid4 hex>``."""
    unique_part = uuid.uuid4().hex
    return "{}-{}".format(prefix, unique_part)
def snapshot() -> dict[str, str]:
    """Capture the current context IDs as a plain dict.

    Returns:
        ``{"request_id": ..., "trace_id": ...}`` with the values bound in the
        current context, suitable for passing explicitly across an
        actor / task boundary.
    """
    return dict(
        request_id=_request_id_var.get(),
        trace_id=_trace_id_var.get(),
    )
@contextmanager
def apply_snapshot(snap: Optional[dict[str, str]]) -> Iterator[None]:
    """Make an externally supplied ID snapshot effective inside a ``with`` block.

    Args:
        snap: Dict as produced by ``snapshot()``; ``None`` or an empty dict
            turns the context manager into a no-op. Only non-empty
            ``request_id`` / ``trace_id`` entries are bound.

    On exit the bound variables are reset in reverse binding order; a reset
    that fails with ``ValueError`` or ``LookupError`` is ignored.
    """
    if not snap:
        yield
        return

    tokens: list[Token] = []
    for key, var in (("request_id", _request_id_var), ("trace_id", _trace_id_var)):
        value = snap.get(key)
        if value:
            tokens.append(var.set(value))

    try:
        yield
    finally:
        # Undo in LIFO order.
        while tokens:
            tok = tokens.pop()
            try:
                tok.var.reset(tok)
            except (ValueError, LookupError):
                # The token may no longer be valid in this context; tolerate it.
                pass
+5 -1
View File
@@ -59,10 +59,14 @@ class WorkerCluster:
self._active_workers.move_to_end(agent_id)
return self._active_workers[agent_id]
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
global_state_machine = ray_actor_hook(
"global_state_machine"
).global_state_machine
agent_config = await global_state_machine.get_individual.remote(agent_id)
# 走快照读,避开 GSM actor RPC:高频唤醒路径不再是单 actor 瓶颈
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
agent_config = snapshot.individuals.get(agent_id)
if not agent_config:
raise ValueError(f"无法唤醒 Agent {agent_id}:数据库中不存在该档案")
+14 -4
View File
@@ -52,17 +52,21 @@ class BaseIndividual:
self.agent_id = agent_config.get("agent_id")
self.agent: Agent | None = None
async def _init_agent(self, agent_name: str, system_prompt: str):
async def _init_agent(self, agent_name: str, system_prompt: str, toolsets=None):
"""根据 agent_config 拉起一个 pydantic-ai Agent 实例。
从 GlobalStateMachine 取出 Provider,按 agent_config 中的 provider_title
和 model_id 选择模型,加载工具列表,并把 system_prompt 注册为动态提示词。
若调用方未显式提供 ``toolsets``,会自动从全局状态机拉取 MCP toolsets 注入。
Args:
agent_name: Agent 的人类可读名称,用于日志与展示。
system_prompt: 该 Agent 的基础系统提示词,会和 task_event 拼接成动态提示词。
toolsets: 显式传入的外部工具集;为 ``None`` 时会自动拉取 MCP toolsets。
"""
from kilostar.utils.get_tool import load_tools_from_list
from kilostar.utils.mcp_helper import get_all_toolsets_for_scope
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
global_state_machine = ray_actor_hook(
"global_state_machine"
@@ -73,13 +77,18 @@ class BaseIndividual:
model_id = self.agent_config.get("model_id", "gpt-4o") # default fallback
tools_list = self.agent_config.get("tools", None)
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
# 直读快照,避开 actor RPC 单线程串行
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
raise ValueError(f"Provider {provider_title!r} 未注册")
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
if toolsets is None:
toolsets = await get_all_toolsets_for_scope(agent_name)
self.agent = agent_factory.create_agent(
provider=provider,
model_id=model_id,
@@ -88,6 +97,7 @@ class BaseIndividual:
deps_type=WorkerIndividualDeps,
agent_name=agent_name,
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt