feat(system):优化后端

1.新增后端测试
2.增加了后端的加密
3.增加了i18n(国际化)
This commit is contained in:
2026-05-31 15:39:34 +00:00
parent affe460180
commit 99520c69d7
118 changed files with 8174 additions and 1491 deletions
@@ -13,47 +13,129 @@
# limitations under the License.
import ray
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.postgres_database import PostgresDatabase
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from typing import Any, Dict, List, Optional, Tuple
from kilostar.core.global_state_machine.individual_manager import (
GlobalIndividualManager,
)
from kilostar.core.global_state_machine.provider_manager import ProviderManager
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
from kilostar.core.postgres_database import PostgresDatabase
@ray.remote
class GlobalStateMachine:
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual 四个注册表。
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual/MCP/CustomToolset 注册表。
其它 Actor 通过 ``ray.get_actor("global_state_machine")`` 拿到本实例,
再调用本类暴露的方法来读写各注册表,避免每个 Actor 各自维护一份状态
所有持久化都走 PostgresDatabase;启动时由 ``init_state_machine`` 一次性把
Provider / MCP / Tool config / Custom toolset 拉到内存,保证后续读操作零等待
"""
def __init__(self, postgres_database: PostgresDatabase):
    """Create the sub-managers and the empty in-memory registries.

    A progress marker is printed to stderr (and flushed) after each step.
    The registries start out empty; ``init_state_machine`` is the method
    that fills them from the database.

    Args:
        postgres_database: Database handle. It is handed to
            ``ProviderManager`` and kept on ``self.postgres_database``.
    """
    import sys

    def _mark(text: str) -> None:
        # Flushed immediately so the marker is visible even if startup stalls.
        print(text, file=sys.stderr, flush=True)

    _mark("GSM __init__ START")
    _mark(" event_dict done")

    self._global_provider_manager = ProviderManager(postgres_database)
    _mark(" provider_manager done")

    self._global_tool_manager = GlobalToolManager()
    _mark(" tool_manager done")

    self._global_skill_manager = GlobalSkillManager()
    _mark(" skill_manager done")

    self._global_individual_manager = GlobalIndividualManager()
    _mark(" individual_manager done")

    # In-memory registries; populated later by init_state_machine.
    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
    self._tool_configs: Dict[str, Dict[str, Any]] = {}
    self._custom_toolsets: Dict[str, Dict[str, Any]] = {}

    # Snapshot bookkeeping: _publish_snapshot bumps the version and stores
    # a fresh object-store ref here; readers fetch that ref and resolve it
    # themselves instead of reading through this actor.
    self._config_version: int = 0
    self._current_ref: Optional[ray.ObjectRef] = None

    self.postgres_database = postgres_database
    _mark("GSM __init__ DONE")
async def init_state_machine(self):
"""从数据库加载 Provider/Individual 注册表到内存。"""
"""启动期一次性把 Provider/Individual/MCP/ToolConfig/CustomToolset 拉到内存。"""
await self._global_provider_manager.init_provider_register(
self.postgres_database
)
await self._global_individual_manager.init_individual_register(
self.postgres_database
)
# MCP servers
rows = await self.postgres_database.list_mcp_servers_db.remote()
self._mcp_servers = {row["server_id"]: row for row in rows}
# Tool configs
cfg_rows = await self.postgres_database.list_tool_configs_db.remote()
self._tool_configs = {row["tool_name"]: row["config"] for row in cfg_rows}
# Custom toolsets
ts_rows = await self.postgres_database.list_custom_toolsets.remote()
self._custom_toolsets = {row["toolset_id"]: row for row in ts_rows}
# 让 tool_manager 立刻把 custom toolset 装配成 FunctionToolset
self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
# 启动期一次性发布 v1 快照,让等待中的读端立刻可用
self._publish_snapshot()
# ─── Snapshot 发布(Object Store 读路径) ────────────────────
def _build_snapshot(self) -> GSMSnapshot:
    """Pack the current in-memory state into a ``GSMSnapshot``.

    The caller is expected to have brought the in-memory state up to date
    before calling this.

    ``tool_funcs`` is flattened to ``{tool_name: callable}``: the per-scope
    bucketing is not exposed at snapshot level. If the same tool name
    appears in several scopes, the scope iterated last wins.

    ``system_tools_by_scope`` keeps the per-scope tool-name lists so a
    client can rebuild scope-specific toolsets from the flattened
    ``tool_funcs``.

    Returns:
        A ``GSMSnapshot`` stamped with the current ``_config_version`` and
        holding shallow copies of every registry.
    """
    tool_manager = self._global_tool_manager
    scoped_funcs = tool_manager._tool_funcs

    # Tool names per scope, in the registration order of each bucket.
    names_per_scope: Dict[str, List[str]] = {
        scope: list(funcs) for scope, funcs in scoped_funcs.items()
    }

    # Flattened name -> callable view; later scopes overwrite earlier ones.
    merged_funcs: Dict[str, Any] = {}
    for funcs in scoped_funcs.values():
        merged_funcs.update(funcs)

    # Copy one level deep so the snapshot does not share the inner dicts.
    mapper_copy = {}
    for scope, name_to_cls in tool_manager.tool_mapper.items():
        mapper_copy[scope] = dict(name_to_cls)

    return GSMSnapshot(
        version=self._config_version,
        providers=dict(self._global_provider_manager.provider_register),
        individuals=dict(self._global_individual_manager._individuals),
        mcp_servers=dict(self._mcp_servers),
        tool_configs=dict(self._tool_configs),
        custom_toolsets=dict(self._custom_toolsets),
        skills=dict(self._global_skill_manager.skill_mapper),
        tool_metadata=dict(tool_manager.tool_metadata),
        tool_funcs=merged_funcs,
        third_party_funcs=dict(tool_manager._third_party_funcs),
        tool_mapper=mapper_copy,
        system_tools_by_scope=names_per_scope,
    )
def _publish_snapshot(self) -> None:
    """Bump the version and put a fresh snapshot into the Ray object store.

    The version is incremented before the snapshot is built, so the
    snapshot carries the new version number. Rebinding ``_current_ref``
    drops this actor's reference to the previously published snapshot.
    """
    self._config_version = self._config_version + 1
    fresh_snapshot = self._build_snapshot()
    self._current_ref = ray.put(fresh_snapshot)
async def current_config_ref(self) -> Tuple[int, ray.ObjectRef]:
    """Return ``(version, ObjectRef)`` for the current snapshot.

    Only the reference is returned, never the snapshot itself, so that
    callers resolve it from the object store in their own process. If no
    snapshot has been published yet, one is published first.
    """
    nothing_published_yet = self._current_ref is None
    if nothing_published_yet:
        self._publish_snapshot()
    version, ref = self._config_version, self._current_ref
    return version, ref
async def current_version(self) -> int:
    """Return only the current version number.

    Lightweight companion to ``current_config_ref``: readers use it to
    decide whether their locally cached snapshot is still current.
    """
    version = self._config_version
    return version
# ─── Provider ──────────────────────────────────────────────
async def add_provider_wrap(
self,
@@ -64,7 +146,7 @@ class GlobalStateMachine:
provider_owner,
):
"""新增一个模型 Provider:内存注册 + 数据库持久化一并完成。"""
return await self._global_provider_manager.add_provider(
result = await self._global_provider_manager.add_provider(
provider_type=provider_type,
provider_title=provider_title,
provider_url=provider_url,
@@ -72,8 +154,9 @@ class GlobalStateMachine:
provider_owner=provider_owner,
postgres_database=self.postgres_database,
)
self._publish_snapshot()
return result
# Provider Manager Methods
def get_provider_list(self):
    """Return the providers currently held by the provider manager."""
    provider_manager = self._global_provider_manager
    return provider_manager.get_provider_list()
@@ -84,11 +167,14 @@ class GlobalStateMachine:
async def delete_provider(self, provider_title: str):
"""删除一个 Provider:内存注册 + 数据库持久化一并完成。"""
return await self._global_provider_manager.delete_provider(
result = await self._global_provider_manager.delete_provider(
provider_title, self.postgres_database
)
self._publish_snapshot()
return result
# ─── Tool / Toolset ────────────────────────────────────────
# Tool Manager Methods
def get_tool_mapper(self):
    """Return the tool manager's ``tool_mapper`` object as-is.

    The mapping is keyed by scope; each value maps a plugin name to the
    tool data class registered for it.
    """
    tool_manager = self._global_tool_manager
    return tool_manager.tool_mapper
@@ -96,37 +182,152 @@ class GlobalStateMachine:
def get_tool_list(self, agent_name: str):
"""返回某个 agent 可用的工具集(其专属工具与 default 工具的并集)。"""
tools = self._global_tool_manager.tool_mapper.get(agent_name, {})
# also include default tools
default_tools = self._global_tool_manager.tool_mapper.get("default", {})
merged_tools = {**default_tools, **tools}
return merged_tools
return {**default_tools, **tools}
def get_tool_categories(self):
    """Return tool metadata grouped several ways.

    Returns:
        A dict with four keys: ``system`` and ``third_party`` (the tool
        manager's respective lists), ``by_category`` (one list per category
        name: system / search / mcp / other) and ``all`` (every tool).
    """
    manager = self._global_tool_manager

    system_tools = manager.get_system_tools()
    third_party_tools = manager.get_third_party_tools()

    grouped = {}
    for label in ("system", "search", "mcp", "other"):
        grouped[label] = manager.get_tools_by_category(label)

    return {
        "system": system_tools,
        "third_party": third_party_tools,
        "by_category": grouped,
        "all": manager.get_all_tools(),
    }
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the system and custom toolsets for ``scope`` (MCP excluded).

    Pure delegation to the tool manager's method of the same name.
    """
    tool_manager = self._global_tool_manager
    return tool_manager.get_toolsets_for_scope(scope)
# ─── MCP Server Registry ───────────────────────────────────
async def add_mcp_server(self, server_id: str, config: Dict[str, Any]) -> bool:
    """Register an MCP server config: database first, then memory.

    The value returned by the database upsert (not the ``config`` argument)
    is what gets stored in memory. A new snapshot is published afterwards.

    Returns:
        Always ``True``; a failing upsert propagates as an exception.
    """
    persisted = await self.postgres_database.upsert_mcp_server.remote(
        server_id, config
    )
    self._mcp_servers.update({server_id: persisted})
    self._publish_snapshot()
    return True
def get_mcp_server(self, server_id: str) -> Optional[Dict[str, Any]]:
    """Return the in-memory config for ``server_id``, or ``None`` if unknown."""
    registry = self._mcp_servers
    if server_id in registry:
        return registry[server_id]
    return None
def list_mcp_servers(self) -> List[Dict[str, Any]]:
    """Return every registered MCP server as a flat dict with ``server_id``.

    Each entry starts with ``server_id`` set to the registry key and is
    then overlaid with the stored config, so a ``server_id`` inside the
    config takes precedence.
    """
    listing: List[Dict[str, Any]] = []
    for sid, cfg in self._mcp_servers.items():
        entry: Dict[str, Any] = {"server_id": sid}
        entry.update(cfg)
        listing.append(entry)
    return listing
async def delete_mcp_server(self, server_id: str) -> bool:
    """Delete an MCP server from the database and from memory.

    The in-memory entry is dropped and a snapshot is published regardless
    of what the database call reported.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_mcp_server_db.remote(server_id)
    if server_id in self._mcp_servers:
        del self._mcp_servers[server_id]
    self._publish_snapshot()
    return True if db_result else False
def get_mcp_server_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of the raw ``server_id -> config`` registry.

    Intended for nodes that build their toolsets from these configs.
    """
    configs_copy = dict(self._mcp_servers)
    return configs_copy
# ─── Tool Config(Tavily API key 等)─────────────────────
async def set_tool_config(self, tool_name: str, config: Dict[str, Any]) -> bool:
    """Replace the whole runtime config of one tool.

    The config is upserted in the database; the ``"config"`` entry of the
    row that comes back is what gets stored in memory, and a new snapshot
    is published.

    NOTE(review): the original comment says sensitive fields are encrypted
    inside the DAO; that is not visible from this method.

    Returns:
        Always ``True``; a failing upsert propagates as an exception.
    """
    persisted = await self.postgres_database.upsert_tool_config.remote(
        tool_name, config
    )
    self._tool_configs.update({tool_name: persisted["config"]})
    self._publish_snapshot()
    return True
def get_tool_config(self, tool_name: str) -> Dict[str, Any]:
    """Return a copy of the config stored for ``tool_name``.

    An empty dict is returned when the tool has no stored config.
    """
    if tool_name not in self._tool_configs:
        return {}
    return dict(self._tool_configs[tool_name])
def list_tool_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of all stored tool configs.

    The values are returned exactly as stored; per the original note they
    include sensitive fields, so callers must mask them before exposing.
    """
    configs_copy = dict(self._tool_configs)
    return configs_copy
async def delete_tool_config(self, tool_name: str) -> bool:
    """Delete a tool's config from the database and from memory.

    The in-memory entry is dropped and a snapshot is published regardless
    of what the database call reported.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_tool_config_db.remote(tool_name)
    if tool_name in self._tool_configs:
        del self._tool_configs[tool_name]
    self._publish_snapshot()
    return True if db_result else False
# ─── Custom Toolset(用户自定义工具组)──────────────────
async def add_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: List[str],
    description: Optional[str] = None,
    owner_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create or update a user-defined toolset.

    Every name in ``tools`` must be one the tool manager recognises as a
    third-party tool. On success the definition is upserted in the
    database, stored in memory, the custom toolsets are rebuilt and a new
    snapshot is published.

    Args:
        toolset_id: Identifier of the toolset to create or update.
        name: Name of the toolset.
        tools: Names of the tools the toolset should contain.
        description: Optional description.
        owner_id: Optional owner identifier.

    Returns:
        The row returned by the database upsert.

    Raises:
        ValueError: If any name in ``tools`` is not a third-party tool.
    """
    # Validate before touching the database or memory.
    invalid = []
    for tool_name in tools:
        if not self._global_tool_manager.is_third_party_tool(tool_name):
            invalid.append(tool_name)
    if invalid:
        raise ValueError(
            f"自定义工具组只允许包含第三方工具,以下不合法:{invalid}"
        )

    upsert_kwargs = {
        "toolset_id": toolset_id,
        "name": name,
        "tools": list(tools),
        "description": description,
        "owner_id": owner_id,
    }
    saved = await self.postgres_database.upsert_custom_toolset.remote(
        **upsert_kwargs
    )

    self._custom_toolsets[toolset_id] = saved
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return saved
def list_custom_toolsets(self) -> List[Dict[str, Any]]:
    """Return all custom toolset definitions held in memory, as a list."""
    definitions = [defn for defn in self._custom_toolsets.values()]
    return definitions
def get_custom_toolset(self, toolset_id: str) -> Optional[Dict[str, Any]]:
    """Return the definition stored for ``toolset_id``, or ``None`` if unknown."""
    registry = self._custom_toolsets
    if toolset_id in registry:
        return registry[toolset_id]
    return None
async def delete_custom_toolset(self, toolset_id: str) -> bool:
    """Delete a custom toolset from the database and from memory.

    The in-memory entry is dropped, the custom toolsets are rebuilt and a
    snapshot is published regardless of what the database call reported.

    Returns:
        The database call's result coerced to ``bool``.
    """
    db_result = await self.postgres_database.delete_custom_toolset.remote(toolset_id)
    if toolset_id in self._custom_toolsets:
        del self._custom_toolsets[toolset_id]
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return True if db_result else False
# ─── Skill ────────────────────────────────────────────────
# Skill Manager Methods
def add_skill(self, skill_name: str):
"""注册一个新的 Skill 名称到 Skill 注册表。"""
return self._global_skill_manager.add_skill(skill_name)
result = self._global_skill_manager.add_skill(skill_name)
self._publish_snapshot()
return result
def get_skill_list(self):
    """Return whatever the skill manager reports as its skill list."""
    skill_manager = self._global_skill_manager
    return skill_manager.get_skill_list()
def remove_skill(self, skill_name: str):
"""从注册表中移除一个 Skill。"""
return self._global_skill_manager.remove_skill(skill_name)
result = self._global_skill_manager.remove_skill(skill_name)
self._publish_snapshot()
return result
# ─── Individual ───────────────────────────────────────────
# Individual Manager Methods
def add_individual(self, agent_id: str, config):
"""把一个 Worker Individual 的运行期配置加入注册表。"""
return self._global_individual_manager.add_individual(agent_id, config)
result = self._global_individual_manager.add_individual(agent_id, config)
self._publish_snapshot()
return result
def get_individual(self, agent_id: str):
    """Look up one Worker Individual's config by ``agent_id``.

    Pure delegation to the individual manager.
    """
    individual_manager = self._global_individual_manager
    return individual_manager.get_individual(agent_id)
def remove_individual(self, agent_id: str):
"""从注册表中移除一个 Worker Individual。"""
return self._global_individual_manager.remove_individual(agent_id)
result = self._global_individual_manager.remove_individual(agent_id)
self._publish_snapshot()
return result
def list_individuals(self):
    """Return the individual manager's list of registered Worker Individuals."""
    individual_manager = self._global_individual_manager
    return individual_manager.list_individuals()
@@ -0,0 +1,184 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""GSM 快照对象与客户端拉取工具。
设计动机:把 GSM actor 内存里的"读路径配置"打包成可放进 Ray Object Store
的不可变快照。读端不再走 actor RPC,而是 ``fetch_snapshot()`` 一次拿到全量
当前配置,亚毫秒级共享内存读,绕开单 actor 的吞吐瓶颈。
GSM 仍然是 source of truth + 写入串行化器,但读路径解耦:
- 写入 → GSM 内存更新 → ``ray.put(snapshot)`` 拿到新 ObjectRef → 版本号 +1
- 读取 → ``current_config_ref()`` 拿 (version, ref) → ``ray.get(ref)`` 直读
旧的 ``get_provider / get_individual / ...`` 接口保留不动,是低频路径的兜底;
新代码(特别是 skill task 这种高并发热路径)应优先走 ``fetch_snapshot``。
"""
from __future__ import annotations
import asyncio
from dataclasses import dataclass, field
from typing import Any, Callable, Dict, List, Optional, Tuple
import ray
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.utils.logger import get_logger
_logger = get_logger("gsm_snapshot")
@dataclass
class GSMSnapshot:
    """Snapshot of the GSM configuration; every field must be cloudpickle-friendly.

    The snapshot is meant to be treated as immutable, although the dataclass
    itself is not declared frozen.

    ``FunctionToolset`` instances are deliberately not stored here. Per the
    original design note, the serializability of pydantic-ai toolsets may
    change between versions, so the task side assembles its own toolsets from
    ``tool_funcs`` + ``tool_mapper``; this isolates that risk and keeps the
    snapshot smaller.
    """
    version: int = 0
    providers: Dict[str, Provider] = field(default_factory=dict)
    individuals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    custom_toolsets: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    skills: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    tool_metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    tool_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    third_party_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    tool_mapper: Dict[str, Dict[str, type]] = field(default_factory=dict)
    # ``{scope: [tool_name, ...]}``: per-scope list of system tool names.
    # Clients rebuild FunctionToolset objects in their own process from these
    # names plus ``tool_funcs``, which avoids putting toolset instances
    # (possibly unserializable or version-coupled) into the snapshot.
    system_tools_by_scope: Dict[str, List[str]] = field(default_factory=dict)
# Single-slot, process-local cache holding the most recently fetched snapshot.
_local_cache: Dict[str, Any] = {"version": -1, "snapshot": None}
# Serializes the check-and-refresh sequence so concurrent callers in the same
# process do not all re-fetch the same snapshot.
_cache_lock = asyncio.Lock()


async def _load_snapshot(gsm_actor: Any) -> Tuple[int, Any]:
    """Ask the GSM actor for ``(version, ref)`` and resolve the ref."""
    version, ref = await gsm_actor.current_config_ref.remote()
    # Await the ObjectRef rather than calling the blocking ``ray.get``:
    # this coroutine runs on the caller's event loop (often an async actor),
    # and a blocking get would stall every other coroutine on that loop.
    snapshot = await ref
    return version, snapshot


async def fetch_snapshot(
    *,
    use_cache: bool = True,
    gsm_actor: Optional[Any] = None,
) -> GSMSnapshot:
    """Fetch the current GSM snapshot.

    With caching enabled the steps are:

    1. Call ``gsm.current_version.remote()`` (one lightweight RPC).
    2. If the locally cached snapshot has that version, return it.
    3. Otherwise call ``gsm.current_config_ref.remote()``, await the
       returned ObjectRef, and store the result in the local cache.

    If the version check itself fails, the failure is logged and a full
    fetch is attempted; errors from the full fetch propagate to the caller.

    Args:
        use_cache: Use the process-local single-slot cache (default). Pass
            ``False`` in tests or diagnostics to force a fresh fetch; the
            cache is then neither read nor updated.
        gsm_actor: Optional GSM actor handle. When omitted, the handle is
            obtained through ``ray_actor_hook``.

    Returns:
        The snapshot for the GSM's current version.
    """
    if gsm_actor is None:
        from kilostar.utils.ray_hook import ray_actor_hook

        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine

    if not use_cache:
        _, snapshot = await _load_snapshot(gsm_actor)
        return snapshot

    async with _cache_lock:
        try:
            latest_version = await gsm_actor.current_version.remote()
        except Exception as exc:
            # Best effort: a failed version probe only costs us the cache
            # shortcut, so record it and fall through to the full fetch.
            _logger.warning(f"GSM version check failed, refetching snapshot: {exc}")
            latest_version = None

        cached = _local_cache.get("snapshot")
        if (
            latest_version is not None
            and cached is not None
            and _local_cache.get("version") == latest_version
        ):
            return cached

        version, snapshot = await _load_snapshot(gsm_actor)
        _local_cache["version"] = version
        _local_cache["snapshot"] = snapshot
        return snapshot
def reset_local_cache() -> None:
    """Reset this process's snapshot cache to its initial empty state (for tests)."""
    _local_cache.update(version=-1, snapshot=None)
# ─── 客户端 helper:从快照重建本地视图 ─────────────────────────────
def build_toolsets_for_scope(
    snapshot: GSMSnapshot, scope: str
) -> List[Any]:
    """Assemble the FunctionToolset list for ``scope`` from ``snapshot``.

    Client-side counterpart of the tool manager's toolset lookup:

    - system toolsets: one for the ``default`` bucket and one for ``scope``,
      each built from ``system_tools_by_scope`` + ``tool_funcs``;
    - custom toolsets: one per entry of ``custom_toolsets`` that references
      at least one function present in ``third_party_funcs``.

    When ``scope`` is ``"default"`` the default bucket is built only once;
    building it twice would yield two toolsets with the same id and the same
    tools. Tool names that have no matching function are skipped, and a
    bucket or custom toolset left with no functions is omitted.

    Args:
        snapshot: Snapshot to read tool names and functions from.
        scope: Scope whose system tools should be included besides
            ``default``.

    Returns:
        The toolsets built in this process, or an empty list when
        ``pydantic_ai.toolsets`` cannot be imported.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        _logger.warning("pydantic_ai.toolsets unavailable; cannot build toolsets")
        return []

    result: List[Any] = []

    # dict.fromkeys keeps order while collapsing scope == "default".
    for bucket in dict.fromkeys(("default", scope)):
        names = snapshot.system_tools_by_scope.get(bucket) or []
        funcs = [snapshot.tool_funcs[n] for n in names if n in snapshot.tool_funcs]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"system::{bucket}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build system toolset {bucket} failed: {e}")

    for toolset_id, defn in snapshot.custom_toolsets.items():
        names = defn.get("tools") or []
        funcs = [
            snapshot.third_party_funcs[n]
            for n in names
            if n in snapshot.third_party_funcs
        ]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"custom::{toolset_id}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build custom toolset {toolset_id} failed: {e}")

    return result
@@ -1,35 +1,42 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pathlib
import importlib
import inspect
from collections import defaultdict
from typing import Any, Callable, Dict, List, Type
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
from typing import Dict, Type
from kilostar.utils.logger import get_logger
logger = get_logger("tool_manager")
_SYSTEM_BUCKET = "system"
class GlobalToolManager:
"""工具注册表:扫描 ``kilostar/plugin/tool_plugin/`` 下所有 BaseToolData 子类,
按 ``action_scope`` 分桶到 ``tool_mapper[scope][plugin_name]``;无 scope 的归入 ``default``。"""
按 ``action_scope`` 打包成 ``FunctionToolset``。
三类 toolset:
- **system**:``is_system=True`` 的工具,按 scope 分组
- **custom**:用户自定义工具组(由 ``rebuild_custom_toolsets`` 动态构建)
- **mcp**:由 ``mcp_helper`` 独立管理,不经过本类
``category="mcp"`` 的工具不会被本类管理。
"""
tool_metadata: Dict[str, Dict[str, Any]]
_tool_funcs: Dict[str, Dict[str, Callable]]
_system_toolsets: Dict[str, Any]
_custom_toolsets: Dict[str, Any]
_third_party_funcs: Dict[str, Callable]
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
def __init__(self):
def __init__(self) -> None:
self.tool_metadata = {}
self._tool_funcs = defaultdict(dict)
self._system_toolsets = {}
self._custom_toolsets = {}
self._third_party_funcs = {}
self.tool_mapper = defaultdict(dict)
tool_plugin_dir = (
@@ -39,21 +46,154 @@ class GlobalToolManager:
return
for item in tool_plugin_dir.iterdir():
if item.is_dir() and not item.name.startswith("__"):
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
if not (item.is_dir() and not item.name.startswith("__")):
continue
plugin_name = item.name
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
try:
module = importlib.import_module(module_name)
for name, obj in inspect.getmembers(module, inspect.isclass):
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
# It's a valid tool class
action_scopes = obj.model_fields.get("action_scope").default
try:
module = importlib.import_module(module_name)
except Exception as e:
logger.warning(f"Failed to import tool plugin {plugin_name}: {e}")
continue
if not action_scopes:
self.tool_mapper["default"][plugin_name] = obj
else:
for scope in action_scopes:
self.tool_mapper[scope][plugin_name] = obj
except Exception as e:
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
tool_data_cls = self._find_tool_data_class(module)
if tool_data_cls is None:
continue
tool_func = getattr(module, plugin_name, None)
if not callable(tool_func):
logger.warning(
f"Tool plugin '{plugin_name}' has no callable named "
f"'{plugin_name}' in its module; skipped."
)
continue
action_scopes = (
tool_data_cls.model_fields.get("action_scope").default or []
)
is_system = bool(tool_data_cls.model_fields.get("is_system").default)
category_field = tool_data_cls.model_fields.get("category")
category = (category_field.default if category_field else "other") or "other"
self.tool_metadata[plugin_name] = {
"name": plugin_name,
"is_system": is_system,
"category": category,
"action_scope": list(action_scopes),
}
if category == "mcp":
continue
scopes = [s for s in action_scopes if s] or ["default"]
if is_system:
for scope in scopes:
self._tool_funcs[scope][plugin_name] = tool_func
self.tool_mapper[scope][plugin_name] = tool_data_cls
else:
self._third_party_funcs[plugin_name] = tool_func
for scope in scopes:
self.tool_mapper[scope][plugin_name] = tool_data_cls
self._build_system_toolsets()
def _build_system_toolsets(self) -> None:
    """Build one ``FunctionToolset`` per scope from ``_tool_funcs``.

    Does nothing when ``FunctionToolset`` cannot be imported. Scopes with
    no functions are skipped. A failure while building one scope is logged
    and the remaining scopes are still processed.
    """
    toolset_cls = self._import_function_toolset()
    if toolset_cls is None:
        return

    non_empty = (
        (scope, funcs) for scope, funcs in self._tool_funcs.items() if funcs
    )
    for scope, funcs in non_empty:
        try:
            built = toolset_cls(
                tools=list(funcs.values()),
                id=f"system::{scope}",
            )
            self._system_toolsets[scope] = built
        except Exception as e:
            logger.error(f"Failed to build system toolset {scope}: {e}")
def rebuild_custom_toolsets(self, custom_defs: Dict[str, Dict[str, Any]]) -> None:
    """Rebuild every custom ``FunctionToolset`` from the given definitions.

    ``_custom_toolsets`` is replaced wholesale. Tool names that are not
    known third-party functions are ignored, and a definition left with no
    functions produces no toolset. If ``FunctionToolset`` cannot be
    imported, the custom toolsets are cleared.

    Args:
        custom_defs: ``toolset_id -> definition``; the tool names are read
            from each definition's ``"tools"`` entry.
    """
    toolset_cls = self._import_function_toolset()
    if toolset_cls is None:
        self._custom_toolsets = {}
        return

    known_funcs = self._third_party_funcs
    rebuilt: Dict[str, Any] = {}
    for toolset_id, defn in custom_defs.items():
        requested = defn.get("tools") or []
        resolved = []
        for tool_name in requested:
            if tool_name in known_funcs:
                resolved.append(known_funcs[tool_name])
        if resolved:
            try:
                rebuilt[toolset_id] = toolset_cls(
                    tools=resolved,
                    id=f"custom::{toolset_id}",
                )
            except Exception as e:
                logger.error(f"Failed to build custom toolset {toolset_id}: {e}")
    self._custom_toolsets = rebuilt
@staticmethod
def _import_function_toolset():
    """Return pydantic-ai's ``FunctionToolset`` class, or ``None`` if unavailable.

    The import is done lazily here; an ``ImportError`` is logged as a
    warning instead of being raised.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        logger.warning("pydantic_ai.toolsets unavailable")
        return None
    else:
        return FunctionToolset
@staticmethod
def _find_tool_data_class(module) -> Type[BaseToolData] | None:
    """Return the first proper ``BaseToolData`` subclass found in ``module``.

    Classes are examined in the order ``inspect.getmembers`` yields them;
    ``BaseToolData`` itself never counts. Returns ``None`` when the module
    contains no such class.
    """
    candidates = (
        obj
        for _, obj in inspect.getmembers(module, inspect.isclass)
        if obj is not BaseToolData and issubclass(obj, BaseToolData)
    )
    return next(candidates, None)
# ─── Toolset accessors ───
def get_system_toolset(self, scope: str) -> Any | None:
    """Return the system toolset built for ``scope``, or ``None`` if there is none."""
    built = self._system_toolsets
    if scope in built:
        return built[scope]
    return None
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the system toolsets for ``default`` + ``scope`` plus all custom toolsets.

    The ``default`` system toolset comes first, then the one for ``scope``,
    then every custom toolset. When ``scope`` is ``"default"`` the default
    toolset is included once only; listing the same toolset twice would
    hand the agent duplicate tools.

    Args:
        scope: Scope whose system toolset is added after ``default``.

    Returns:
        A new list; scopes without a built toolset contribute nothing.
    """
    result: List[Any] = []
    # dict.fromkeys keeps order while collapsing scope == "default".
    for bucket in dict.fromkeys(("default", scope)):
        toolset = self._system_toolsets.get(bucket)
        if toolset is not None:
            result.append(toolset)
    result.extend(self._custom_toolsets.values())
    return result
# ─── Metadata accessors ───
def is_third_party_tool(self, tool_name: str) -> bool:
    """Tell whether ``tool_name`` is registered as a third-party tool function."""
    registered = self._third_party_funcs
    return tool_name in registered
def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``category`` equals ``category``."""
    matching: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("category") == category:
            matching.append(meta)
    return matching
def get_system_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every tool whose ``is_system`` flag is exactly ``True``."""
    system_only: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            system_only.append(meta)
    return system_only
def get_third_party_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of tools that are neither system tools nor MCP tools.

    A tool is excluded when its ``is_system`` flag is exactly ``True`` or
    its ``category`` is ``"mcp"``.
    """
    third_party: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            continue
        if meta.get("category") == "mcp":
            continue
        third_party.append(meta)
    return third_party
def get_all_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata of every registered tool, as a new list."""
    return [meta for meta in self.tool_metadata.values()]
# 兼容旧接口
def get_non_system_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias kept for older callers; same result as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
def get_personal_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias kept for older callers; same result as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
@@ -29,6 +29,7 @@ from kilostar.core.global_state_machine.global_state_machine import GlobalStateM
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -45,22 +46,19 @@ class ConsciousnessNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
"你是系统的'高级规划师''架构师',负责处理监控节点分配过来的复杂任务。\n"
"你的主要工作场景包括:\n"
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
)
system_prompt: str = agent_prompt("consciousness_node", locale=locale)
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -72,6 +70,7 @@ class ConsciousnessNode:
deps_type=ConsciousnessNodeDeps,
agent_name="consciousness_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -95,6 +94,13 @@ class ConsciousnessNode:
开始进行工作流设计的交互过程(与用户通过 SSE 进行确认或直接生成)
目前简化为:直接根据 command 拆解并构建工作流,然后提交执行。
"""
from kilostar.utils.request_context import trace_id_scope
# 进入工作流域:把 trace_id 绑到 contextvars,本协程所有日志自动带上它
with trace_id_scope(trace_id):
await self._do_start_workflow_design(trace_id, command)
async def _do_start_workflow_design(self, trace_id: str, command: str):
self.logger.info(
f"ConsciousnessNode: 开始为 trace_id {trace_id} 设计工作流。原始命令:{command}"
)
@@ -116,11 +122,11 @@ class ConsciousnessNode:
original_command=command, available_skills=available_skills
)
# 通知 SSE 正在生成图结构
# 通知 SSE 正在生成图结构(pending 队列:节点端写入 → API SSE 读取,单向下行)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "正在为您构建并规划工作流任务节点,请稍候..."
)
@@ -131,17 +137,17 @@ class ConsciousnessNode:
workflow = result.workflow
workflow.trace_id = trace_id
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "工作流构建完成,即将开始执行!"
)
# 将生成的完整工作流提交执行
workflow_engine = ray_actor_hook(
"workflow_running_engine"
).workflow_running_engine
await workflow_engine.execute_workflow.remote(workflow)
# 直接以 ray task 形式 fire workflow,不再经过 WorkflowRunningEngine 这层中转:
# workflow 是一次性、有头有尾的执行,task 语义比常驻 actor 更贴。
from kilostar.core.work.workflow.workflow_engine import run_workflow_task
run_workflow_task.remote(workflow.model_dump(), trace_id)
else:
await global_workflow_manager.put_received.remote(
await global_workflow_manager.put_pending.remote(
trace_id, "很抱歉,工作流生成失败。"
)
await postgres_database.update_workflow_status.remote(trace_id, "failed")
@@ -22,6 +22,7 @@ from kilostar.core.individual.control_node.template import (
ForWorkflowInput,
ControlNodeDeps,
)
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -44,6 +45,8 @@ class ControlNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
"""
create_agent方法,将agent对象装配到Control的属性内
@@ -54,25 +57,21 @@ class ControlNode:
global_state_machine: 全局状态机
provider_title: 供应商名
model_id: 模型id
locale: 语言代码(zh/en),控制system prompt语言
Returns:
无返回
"""
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
"你是系统的'执行者''车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
"你的工作职责是:\n"
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
)
system_prompt: str = agent_prompt("control_node", locale=locale)
output_type = ForWorkflow
from kilostar.utils.get_tool import load_tools_from_list
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -84,6 +83,7 @@ class ControlNode:
deps_type=ControlNodeDeps,
agent_name="control_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -24,6 +24,7 @@ from kilostar.core.individual.regulatory_node.template import (
MessageResponse
)
from pydantic_ai import RunContext, Agent
from kilostar.utils.i18n import agent_prompt
@ray.remote
@@ -46,6 +47,8 @@ class RegulatoryNode:
provider_title: str,
model_id: str,
tools_list: list[str] = None,
toolsets=None,
locale: str | None = None,
) -> None:
"""
create_agent方法,将agent对象装配到regulatoryNode的属性内
@@ -56,24 +59,21 @@ class RegulatoryNode:
provider_title: 供应商名
model_id: 模型id
tools_list: 工具列表
locale: 语言代码(zh/en),控制system prompt语言
Returns:
无返回
"""
system_prompt: str = (
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
"你是系统的'前台接待''大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
"请保持冷静、专业,并严格遵循上述路由规则。"
)
system_prompt: str = agent_prompt("regulatory_node", locale=locale)
output_type = Union[MessageResponse]
from kilostar.utils.get_tool import load_tools_from_list
provider: Provider = await global_state_machine.get_provider.remote(
provider_title
)
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
# 走 Object Store 快照而不是 actor RPC:高频读路径不再受单 actor 串行限制
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
provider: Provider = snapshot.providers.get(provider_title)
if provider is None:
from kilostar.utils.i18n import t
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
agent_factory = AgentFactory()
callables = load_tools_from_list(tools_list)
@@ -85,6 +85,7 @@ class RegulatoryNode:
deps_type=RegulatoryNodeDeps,
agent_name="regulatory_node",
tools=callables,
toolsets=toolsets,
)
@self.agent.system_prompt
@@ -112,16 +113,15 @@ class RegulatoryNode:
)
return prompt
async def working(self, payload: MessageRequest) -> str:
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并实现最终回复
async def working(self, payload: MessageRequest) -> Union[MessageResponse, None]:
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并返回最终回复
Args:
payload: 消息载荷,包含所有信息
Returns:
str,监控节点对用户的回复
MessageResponse 或 None监控节点对用户的结构化回复
"""
await self._run(payload)
return ""
return await self._run(payload)
async def _run(
self, payload: MessageRequest
@@ -140,7 +140,8 @@ class RegulatoryNode:
deps=deps,)
response: MessageResponse = agent_response.output
response.platform = platform
response.platform_id = MessageRequest.platform_id
response.platform_id = payload.platform_id
return response
except:
pass
except Exception as e:
self.logger.exception(f"RegulatoryNode._run failed: {e}")
return None
@@ -49,7 +49,7 @@ class MessageRequest(RequestModel):
MessageRequest类
任何消息渠道向regulatory_node发送消息请求的模型
"""
platform: Literal["client"]
platform: Literal["client", "onebot"]
user_name: str
platform_id: Optional[str]
message: str
@@ -59,6 +59,6 @@ class MessageResponse(RegulatoryNodeResponse):
MessageResponse类
regulatory_node回复的模型
"""
platform: Optional[Literal["client"]] = Field(description="系统自动填入的platform")
platform: Optional[Literal["client", "onebot"]] = Field(description="系统自动填入的platform")
platform_id: Optional[str] = Field(description="系统自动填入的platform_id")
reply_message: str = Field(...,description="模型回复的消息")
@@ -23,12 +23,16 @@ from kilostar.core.postgres_database.model.individual import (
from kilostar.core.postgres_database.model.workflow import (
Workflow,
WorkflowContextModel,
WorkflowGraphStateModel,
)
from kilostar.core.postgres_database.model.chat_history import (
ChatHistoryRegister,
ChatHistoryMessage,
)
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
# 兼容旧代码的别名
Provider = ProviderModel
@@ -49,9 +53,13 @@ __all__ = [
"SpecialIndividualModel",
"Workflow",
"WorkflowContextModel",
"WorkflowGraphStateModel",
"ChatHistoryRegister",
"ChatHistoryMessage",
"SystemNodeConfigModel",
"SystemNodeConfig",
"MCPServerModel",
"ToolConfigModel",
"CustomToolsetModel",
"AgentType",
]
@@ -0,0 +1,32 @@
from datetime import datetime
from typing import List, Optional
from sqlalchemy import String, Text, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class CustomToolsetModel(BaseDataModel):
    """User-defined toolset: bundles several non-system / non-MCP tool plugins.

    ``tools`` stores a list of tool names (per the original author's note,
    these are directory names under ``plugin/tool_plugin/``). The author also
    notes that GSM loads the matching tool functions into one
    ``FunctionToolset`` at startup; that loader is not part of this model.
    """

    __tablename__ = "custom_toolset"

    # Primary key of the toolset.
    toolset_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Required display name.
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional free-form description.
    description: Mapped[Optional[str]] = mapped_column(Text)
    # Optional owner reference; indexed for lookups by owner.
    owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
    # JSONB list of tool names; defaults to an empty list for new rows.
    tools: Mapped[List[str]] = mapped_column(
        JSONB, default=list, comment="工具名列表,仅允许非 system/非 mcp 的工具"
    )
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every UPDATE issued through the ORM.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,29 @@
from typing import Optional
from datetime import datetime
from sqlalchemy import String, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class MCPServerModel(BaseDataModel):
    """Registry table for MCP servers; stores each server's connection settings.

    The original note lists stdio / sse / http as the supported transports;
    this model stores ``transport`` as a plain string and does not enforce
    that set.
    """

    __tablename__ = "mcp_server"

    # Primary key of the server entry.
    server_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Required display name.
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Transport kind (required, at most 16 characters).
    transport: Mapped[str] = mapped_column(String(16), nullable=False)
    # Optional command — presumably the executable for a stdio transport; verify against callers.
    command: Mapped[Optional[str]] = mapped_column(String(255))
    # JSONB list of arguments; defaults to an empty list.
    args: Mapped[list] = mapped_column(JSONB, default=list)
    # Optional URL — presumably used by the network transports; verify against callers.
    url: Mapped[Optional[str]] = mapped_column(String(500))
    # Optional prefix associated with this server's tools.
    tool_prefix: Mapped[Optional[str]] = mapped_column(String(64))
    # JSONB environment mapping; the column comment states sensitive fields are stored encrypted.
    env: Mapped[dict] = mapped_column(JSONB, default=dict, comment="敏感字段已 Fernet 加密")
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every UPDATE issued through the ORM.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,22 @@
from datetime import datetime
from sqlalchemy import String, DateTime, func
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from .base import BaseDataModel
class ToolConfigModel(BaseDataModel):
    """Runtime configuration of a tool (the original note gives a Tavily API key as an example).

    Per the original note, sensitive fields inside ``config`` are stored
    Fernet-encrypted; this model itself performs no encryption.
    """

    __tablename__ = "tool_config"

    # Primary key: one configuration row per tool name.
    tool_name: Mapped[str] = mapped_column(String(100), primary_key=True)
    # JSONB configuration mapping; defaults to an empty dict.
    config: Mapped[dict] = mapped_column(JSONB, default=dict)
    # Set by the database when the row is inserted.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # Set on insert and refreshed on every UPDATE issued through the ORM.
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -79,3 +79,28 @@ class WorkflowContextModel(BaseDataModel):
updated_at: Mapped[str] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
class WorkflowGraphStateModel(BaseDataModel):
    """Storage table for the pydantic_graph persistence blob.

    Kept separate from ``workflow_context``: per the original note, that table
    is for business display / user-readable data, while this one holds the
    graph engine's own state for recovery. One row per trace_id; the JSONB
    column stores the whole history.
    """

    __tablename__ = "workflow_graph_state"

    # Primary key: the workflow trace id this history belongs to.
    trace_id: Mapped[str] = mapped_column(
        String(64), primary_key=True, comment="对应的工作流 Trace ID"
    )
    # Serialized history; defaults to an empty list for new rows.
    history: Mapped[list] = mapped_column(
        JSONB,
        default=list,
        comment="pydantic_graph FullStatePersistence.history 的 JSON 序列化",
    )
    # NOTE(review): annotated Mapped[str] although the column type is DateTime — confirm the intended Python type.
    created_at: Mapped[str] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # NOTE(review): same Mapped[str] / DateTime mismatch as created_at; refreshed on ORM UPDATE.
    updated_at: Mapped[str] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )
@@ -0,0 +1,81 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
class CustomToolsetDatabase:
    """DAO for user-defined toolsets.

    ``tools`` is a list of tool names. This class does not validate the
    names; per the original note the business layer is responsible for
    keeping system / MCP tools out of the list.
    """

    def __init__(self, async_session_maker):
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict (timestamps are not included)."""
        payload: Dict[str, Any] = {
            column: getattr(row, column)
            for column in ("toolset_id", "name", "description", "owner_id")
        }
        # Copy the list so callers never share the ORM-tracked object; None becomes [].
        payload["tools"] = list(row.tools or [])
        return payload

    @database_exception
    async def upsert(
        self,
        toolset_id: str,
        name: str,
        tools: List[str],
        description: Optional[str] = None,
        owner_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create the toolset, or overwrite every field of an existing one.

        Returns:
            The stored row as a dict, read back after commit.
        """
        async with self.async_session_maker() as session:
            query = select(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            record = (await session.execute(query)).scalar_one_or_none()
            changes = {
                "name": name,
                "description": description,
                "owner_id": owner_id,
                "tools": list(tools),
            }
            if record:
                for column, new_value in changes.items():
                    setattr(record, column, new_value)
            else:
                record = CustomToolsetModel(toolset_id=toolset_id, **changes)
                session.add(record)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
        """Return the toolset as a dict, or ``None`` when the id is unknown."""
        async with self.async_session_maker() as session:
            query = select(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            record = (await session.execute(query)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored toolset as a list of dicts."""
        async with self.async_session_maker() as session:
            fetched = await session.execute(select(CustomToolsetModel))
            return list(map(self._to_dict, fetched.scalars().all()))

    @database_exception
    async def delete(self, toolset_id: str) -> bool:
        """Delete the toolset; return ``True`` when at least one row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            outcome = await session.execute(removal)
            await session.commit()
            removed_rows = outcome.rowcount
            return removed_rows > 0
@@ -0,0 +1,79 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
class MCPServerDatabase:
    """DAO for MCP server configuration.

    ``env`` is passed through ``encrypt_dict_secrets`` before it is written
    and through ``decrypt_dict_secrets`` when a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``env`` decrypted."""
        plain_env = decrypt_dict_secrets(row.env or {})
        return dict(
            server_id=row.server_id,
            name=row.name,
            transport=row.transport,
            command=row.command,
            args=row.args or [],
            url=row.url,
            tool_prefix=row.tool_prefix,
            env=plain_env,
        )

    @database_exception
    async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Create or update the server entry from ``config``.

        On update, ``name`` and ``transport`` keep their stored value when
        absent from ``config``; every other field is overwritten (missing
        keys become ``None`` / empty). On create, ``name`` defaults to
        ``server_id`` and ``transport`` to ``"stdio"``.
        """
        # Fields that are written identically on both the create and update paths.
        shared = {
            "command": config.get("command"),
            "args": config.get("args") or [],
            "url": config.get("url"),
            "tool_prefix": config.get("tool_prefix"),
            "env": encrypt_dict_secrets(config.get("env") or {}),
        }
        async with self.async_session_maker() as session:
            query = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
            record = (await session.execute(query)).scalar_one_or_none()
            if record:
                record.name = config.get("name", record.name)
                record.transport = config.get("transport", record.transport)
                for column, new_value in shared.items():
                    setattr(record, column, new_value)
            else:
                record = MCPServerModel(
                    server_id=server_id,
                    name=config.get("name", server_id),
                    transport=config.get("transport", "stdio"),
                    **shared,
                )
                session.add(record)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
        """Return the server entry as a dict, or ``None`` when the id is unknown."""
        async with self.async_session_maker() as session:
            query = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
            record = (await session.execute(query)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored server entry as a list of dicts."""
        async with self.async_session_maker() as session:
            fetched = await session.execute(select(MCPServerModel))
            return list(map(self._to_dict, fetched.scalars().all()))

    @database_exception
    async def delete(self, server_id: str) -> bool:
        """Delete the server entry; return ``True`` when at least one row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
            outcome = await session.execute(removal)
            await session.commit()
            removed_rows = outcome.rowcount
            return removed_rows > 0
@@ -17,10 +17,37 @@ from typing import List
from kilostar.core.postgres_database.model.provider import ProviderModel
from sqlalchemy import select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.utils.crypto import (
CryptoError,
decrypt_secret,
encrypt_secret,
is_encrypted,
)
from kilostar.utils.logger import get_logger
logger = get_logger("provider_dao")
def _decrypt_apikey(value):
    """Return the plaintext API key for ``value``.

    Empty values and values that ``is_encrypted`` does not recognise are
    returned unchanged. If decryption raises ``CryptoError`` the error is
    logged and the original (still encrypted) value is returned.
    """
    if not value or not is_encrypted(value):
        return value
    try:
        plaintext = decrypt_secret(value)
    except CryptoError as e:
        logger.error(f"Provider apikey 解密失败: {e}")
        return value
    return plaintext
def _encrypt_apikey(value):
    """Encrypt ``value`` unless it is empty or already recognised as encrypted."""
    if value and not is_encrypted(value):
        return encrypt_secret(value)
    return value
class ProviderDatabase:
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
"""Provider 表的 DAO:模型 Provider 的增删查改``provider_apikey`` 透明 Fernet 加解密"""
def __init__(self, async_session_maker):
self.async_session_maker = async_session_maker
@@ -37,11 +64,10 @@ class ProviderDatabase:
provider_id=provider.provider_id,
provider_title=provider.provider_title,
provider_url=provider.provider_url,
provider_apikey=provider.provider_apikey,
provider_apikey=_decrypt_apikey(provider.provider_apikey),
provider_models=provider.provider_models,
provider_type=provider.provider_type,
provider_owner=provider.provider_owner,
provider_status=provider.provider_status,
is_active=provider.is_active,
)
for provider in results
@@ -50,7 +76,9 @@ class ProviderDatabase:
@database_exception
async def add_provider(self, **kwargs) -> None:
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel"""
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密"""
if "provider_apikey" in kwargs:
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
async with self.async_session_maker() as session:
provider = ProviderModel(**kwargs)
session.add(provider)
@@ -67,7 +95,9 @@ class ProviderDatabase:
@database_exception
async def update_provider(self, provider_id: str, **kwargs) -> None:
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象"""
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密"""
if "provider_apikey" in kwargs:
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
async with self.async_session_maker() as session:
provider = await session.get(ProviderModel, provider_id)
if provider is not None:
@@ -76,5 +106,7 @@ class ProviderDatabase:
session.add(provider)
await session.commit()
await session.refresh(provider)
# 解密返回,方便上游使用
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
return provider
return None
@@ -0,0 +1,58 @@
from typing import Any, Dict, List, Optional
from sqlalchemy import delete, select
from kilostar.core.postgres_database.database_exception import database_exception
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
class ToolConfigDatabase:
    """DAO for per-tool runtime configuration.

    ``config`` is passed through ``encrypt_dict_secrets`` before it is written
    and through ``decrypt_dict_secrets`` when a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
        """Convert an ORM row into ``{"tool_name", "config"}`` with config decrypted."""
        plain_config = decrypt_dict_secrets(row.config or {})
        return dict(tool_name=row.tool_name, config=plain_config)

    @database_exception
    async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Create the tool's config row or replace its ``config`` entirely."""
        stored_config = encrypt_dict_secrets(config or {})
        async with self.async_session_maker() as session:
            query = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            record = (await session.execute(query)).scalar_one_or_none()
            if not record:
                record = ToolConfigModel(tool_name=tool_name, config=stored_config)
                session.add(record)
            else:
                record.config = stored_config
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Return the tool's config as a dict, or ``None`` when nothing is stored."""
        async with self.async_session_maker() as session:
            query = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            record = (await session.execute(query)).scalar_one_or_none()
            if record:
                return self._to_dict(record)
            return None

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored tool config as a list of dicts."""
        async with self.async_session_maker() as session:
            fetched = await session.execute(select(ToolConfigModel))
            return list(map(self._to_dict, fetched.scalars().all()))

    @database_exception
    async def delete(self, tool_name: str) -> bool:
        """Delete the tool's config; return ``True`` when at least one row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            outcome = await session.execute(removal)
            await session.commit()
            removed_rows = outcome.rowcount
            return removed_rows > 0
@@ -17,6 +17,7 @@ from typing import List, Optional
from kilostar.core.postgres_database.model.workflow import (
Workflow,
WorkflowContextModel,
WorkflowGraphStateModel,
)
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
from kilostar.core.postgres_database.database_exception import database_exception
@@ -101,3 +102,58 @@ class WorkflowDatabase:
)
results = await session.execute(statement)
return results.scalar_one_or_none()
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
@database_exception
async def upsert_workflow_graph_state(
    self, trace_id: str, history: list
) -> WorkflowGraphStateModel:
    """Store the JSON view of the graph history for ``trace_id``.

    The write is an overwrite: an existing row has its ``history`` replaced,
    otherwise a new row is inserted. Only the latest version is kept here.

    Returns:
        The stored row, refreshed after commit.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if not stored:
            stored = WorkflowGraphStateModel(trace_id=trace_id, history=history)
            session.add(stored)
        else:
            stored.history = history
        await session.commit()
        await session.refresh(stored)
        return stored
@database_exception
async def get_workflow_graph_state(
    self, trace_id: str
) -> Optional[WorkflowGraphStateModel]:
    """Return the persisted graph-state row for ``trace_id``, or ``None`` if absent."""
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        fetched = await session.execute(lookup)
        return fetched.scalar_one_or_none()
@database_exception
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
    """Delete the persisted graph state of one workflow (explicit cleanup).

    Returns:
        ``True`` when a row existed and was deleted, ``False`` otherwise.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        stored = (await session.execute(lookup)).scalar_one_or_none()
        if stored is not None:
            await session.delete(stored)
            await session.commit()
            return True
        return False
+112
View File
@@ -38,6 +38,9 @@ from kilostar.core.postgres_database.model.chat_history import (
ChatHistoryMessage,
)
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
from .module.individual import IndividualDatabase
from .module.user import AuthDatabase
@@ -45,6 +48,9 @@ from .module.provider import ProviderDatabase
from .module.system_node import SystemNodeDatabase
from .module.workflow import WorkflowDatabase
from .module.chat_history import ChatHistoryDatabase
from .module.mcp_server import MCPServerDatabase
from .module.tool_config import ToolConfigDatabase
from .module.custom_toolset import CustomToolsetDatabase
@ray.remote
@@ -76,6 +82,9 @@ class PostgresDatabase:
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
self._workflow_database = WorkflowDatabase(self.async_session_maker)
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
self.ready_event = asyncio.Event()
@@ -91,6 +100,15 @@ class PostgresDatabase:
finally:
self.ready_event.set()
async def ping(self) -> bool:
    """Lightweight liveness probe: wait for readiness, then run ``SELECT 1``.

    Returns:
        ``True`` once the query has executed; any database error propagates.
    """
    from sqlalchemy import text

    await self.ready_event.wait()
    probe = text("SELECT 1")
    async with self.async_engine.connect() as connection:
        await connection.execute(probe)
    return True
# Auth Database Methods
async def add_user(self, user_name: str, hashed_password: str):
"""新建一名用户。"""
@@ -242,6 +260,24 @@ class PostgresDatabase:
await self.ready_event.wait()
return await self._workflow_database.get_workflow_context(trace_id)
# Workflow Graph State (pydantic_graph 持久化)
async def upsert_workflow_graph_state(self, trace_id: str, history: list):
    """Overwrite the persisted graph history for ``trace_id``.

    Waits for ``ready_event`` and delegates to the workflow DAO.
    """
    await self.ready_event.wait()
    workflow_dao = self._workflow_database
    stored = await workflow_dao.upsert_workflow_graph_state(trace_id, history)
    return stored
async def get_workflow_graph_state(self, trace_id: str):
    """Read the persisted graph-state record for ``trace_id`` (used for resume).

    Waits for ``ready_event`` and delegates to the workflow DAO.
    """
    await self.ready_event.wait()
    workflow_dao = self._workflow_database
    stored = await workflow_dao.get_workflow_graph_state(trace_id)
    return stored
async def delete_workflow_graph_state(self, trace_id: str):
    """Explicitly remove the persisted graph-state record for ``trace_id``.

    Waits for ``ready_event`` and delegates to the workflow DAO.
    """
    await self.ready_event.wait()
    workflow_dao = self._workflow_database
    removed = await workflow_dao.delete_workflow_graph_state(trace_id)
    return removed
# Chat History Database Methods
async def create_chat_session(self, user_id: str, title: str = "新对话"):
"""新建一个聊天会话。"""
@@ -264,3 +300,79 @@ class PostgresDatabase:
"""返回某个聊天会话的全部消息。"""
await self.ready_event.wait()
return await self._chat_history_database.list_chat_messages(chat_id)
# MCP Server Database Methods
async def upsert_mcp_server(self, server_id: str, config: dict):
    """Insert or update one MCP server configuration.

    Waits for ``ready_event`` and delegates to the MCP server DAO.
    """
    await self.ready_event.wait()
    mcp_dao = self._mcp_server_database
    saved = await mcp_dao.upsert(server_id, config)
    return saved
async def get_mcp_server_db(self, server_id: str):
    """Read one MCP server configuration.

    Waits for ``ready_event`` and delegates to the MCP server DAO.
    """
    await self.ready_event.wait()
    mcp_dao = self._mcp_server_database
    found = await mcp_dao.get(server_id)
    return found
async def list_mcp_servers_db(self):
    """Read every MCP server configuration.

    Waits for ``ready_event`` and delegates to the MCP server DAO.
    """
    await self.ready_event.wait()
    mcp_dao = self._mcp_server_database
    everything = await mcp_dao.list_all()
    return everything
async def delete_mcp_server_db(self, server_id: str):
    """Delete one MCP server configuration.

    Waits for ``ready_event`` and delegates to the MCP server DAO.
    """
    await self.ready_event.wait()
    mcp_dao = self._mcp_server_database
    removed = await mcp_dao.delete(server_id)
    return removed
# Tool Config Database Methods
async def upsert_tool_config(self, tool_name: str, config: dict):
    """Insert or update the runtime configuration of one tool.

    Waits for ``ready_event`` and delegates to the tool-config DAO.
    """
    await self.ready_event.wait()
    tool_config_dao = self._tool_config_database
    saved = await tool_config_dao.upsert(tool_name, config)
    return saved
async def get_tool_config_db(self, tool_name: str):
    """Read the runtime configuration of one tool.

    Waits for ``ready_event`` and delegates to the tool-config DAO.
    """
    await self.ready_event.wait()
    tool_config_dao = self._tool_config_database
    found = await tool_config_dao.get(tool_name)
    return found
async def list_tool_configs_db(self):
    """Read the runtime configuration of every tool.

    Waits for ``ready_event`` and delegates to the tool-config DAO.
    """
    await self.ready_event.wait()
    tool_config_dao = self._tool_config_database
    everything = await tool_config_dao.list_all()
    return everything
async def delete_tool_config_db(self, tool_name: str):
    """Delete the runtime configuration of one tool.

    Waits for ``ready_event`` and delegates to the tool-config DAO.
    """
    await self.ready_event.wait()
    tool_config_dao = self._tool_config_database
    removed = await tool_config_dao.delete(tool_name)
    return removed
# Custom Toolset Database Methods
async def upsert_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: list,
    description: str = None,
    owner_id: str = None,
):
    """Insert or update one user-defined toolset.

    Waits for ``ready_event`` and forwards every argument by keyword to the
    custom-toolset DAO.
    """
    await self.ready_event.wait()
    arguments = dict(
        toolset_id=toolset_id,
        name=name,
        tools=tools,
        description=description,
        owner_id=owner_id,
    )
    return await self._custom_toolset_database.upsert(**arguments)
async def get_custom_toolset(self, toolset_id: str):
    """Read one user-defined toolset by id.

    Waits for ``ready_event`` and delegates to the custom-toolset DAO.
    """
    await self.ready_event.wait()
    toolset_dao = self._custom_toolset_database
    found = await toolset_dao.get(toolset_id)
    return found
async def list_custom_toolsets(self):
    """Read every user-defined toolset.

    Waits for ``ready_event`` and delegates to the custom-toolset DAO.
    """
    await self.ready_event.wait()
    toolset_dao = self._custom_toolset_database
    everything = await toolset_dao.list_all()
    return everything
async def delete_custom_toolset(self, toolset_id: str):
    """Delete one user-defined toolset.

    Waits for ``ready_event`` and delegates to the custom-toolset DAO.
    """
    await self.ready_event.wait()
    toolset_dao = self._custom_toolset_database
    removed = await toolset_dao.delete(toolset_id)
    return removed
@@ -0,0 +1,191 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Postgres 后端的 ``BaseStatePersistence`` 实现,让 graph 跨进程 resume。
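Design notes:
- **Reuse the in-memory semantics of ``FullStatePersistence``**: its
  snapshot / record_run / load_next logic (NodeSnapshot state machine,
  deep copy, type adapter) is already complete, so none of it is rewritten
  here. After each hook that can change the history, the in-memory history
  is serialized to JSON and written to the ``workflow_graph_state`` table.
- **Best-effort writes**: every hook awaits the DB write, but a failed write
  is only logged as a warning so the graph can keep advancing. The exception
  is ``snapshot_end``, which re-raises a failed write so the caller knows the
  final history was not persisted.
- **Resume entry point**: read the history JSON from the DB → restore it with
  ``load_json`` → run the remaining nodes via ``Graph.iter_from_persistence``.
This layer never holds a SQLAlchemy session; IO is injected through two
awaitables, which keeps it easy to test.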
"""
from __future__ import annotations
import asyncio
from contextlib import asynccontextmanager
from dataclasses import dataclass, field
from typing import Any, AsyncIterator, Awaitable, Callable, Optional
from pydantic_graph import BaseNode, End
from pydantic_graph.persistence import (
BaseStatePersistence,
NodeSnapshot,
Snapshot,
)
from pydantic_graph.persistence.in_mem import FullStatePersistence
from kilostar.utils.logger import get_logger
_logger = get_logger("graph_persistence")
# IO 注入签名:写一份 history JSON / 读一份 history JSONNone=没有)
WriteHistory = Callable[[str, Any], Awaitable[None]]
ReadHistory = Callable[[str], Awaitable[Optional[Any]]]
@dataclass
class PostgresStatePersistence(BaseStatePersistence):
    """``BaseStatePersistence`` that mirrors an in-memory history into Postgres.

    All snapshot bookkeeping is delegated to an inner ``FullStatePersistence``.
    After each hook the whole history is serialized and handed to
    ``write_history``. A failed write is logged as a warning and swallowed so
    the graph keeps running; only ``snapshot_end`` propagates the failure.

    Attributes:
        trace_id: Key under which the history is stored and looked up.
        write_history: Awaitable ``(trace_id, history)`` that persists the history.
        read_history: Awaitable ``(trace_id)`` returning the stored history or ``None``.
    """

    trace_id: str
    write_history: WriteHistory
    read_history: ReadHistory
    _inner: FullStatePersistence = field(default_factory=FullStatePersistence)

    # ─── BaseStatePersistence interface ─────────────────────────────

    async def snapshot_node(self, state, next_node):
        """Record a node snapshot in memory, then flush the history."""
        await self._inner.snapshot_node(state, next_node)
        await self._flush()

    async def snapshot_node_if_new(self, snapshot_id, state, next_node):
        """Delegate ``snapshot_node_if_new`` to the inner persistence, then flush."""
        await self._inner.snapshot_node_if_new(snapshot_id, state, next_node)
        await self._flush()

    async def snapshot_end(self, state, end):
        """Record the end snapshot and require the final flush to succeed.

        Raises:
            Exception: whatever serialization or ``write_history`` raised.
        """
        await self._inner.snapshot_end(state, end)
        # The graph is finished: the final history must be persisted before returning.
        await self._flush(must_succeed=True)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Wrap the inner ``record_run`` and flush once the run has finished.

        The flush sits in ``finally`` so that the snapshot's final status is
        also persisted when the node body raises; the original exception
        still propagates because a non-mandatory flush never raises.
        """
        try:
            async with self._inner.record_run(snapshot_id):
                yield
        finally:
            await self._flush()

    async def load_next(self) -> Optional[NodeSnapshot]:
        """Return the next snapshot to run, as reported by the inner persistence."""
        return await self._inner.load_next()

    async def load_all(self) -> list[Snapshot]:
        """Return every snapshot held by the inner persistence."""
        return await self._inner.load_all()

    def should_set_types(self) -> bool:
        """Report whether the inner persistence still needs ``set_types``."""
        return self._inner.should_set_types()

    def set_types(self, state_type, run_end_type):
        """Forward the state / end types to the inner persistence."""
        self._inner.set_types(state_type, run_end_type)

    # ─── Serialization / deserialization ────────────────────────────

    async def hydrate(self) -> bool:
        """Load the stored history for ``trace_id`` into memory.

        ``set_types`` must have been called beforehand; otherwise
        ``load_json`` asserts and that ``AssertionError`` is re-raised.

        Returns:
            ``True`` when a non-empty history was read and loaded. ``False``
            when nothing is stored, or when reading / parsing failed (the
            failure is logged as a warning).
        """
        try:
            raw = await self.read_history(self.trace_id)
        except Exception as e:  # pragma: no cover - defensive
            _logger.warning(f"hydrate read failed: {e}")
            return False
        if not raw:
            return False
        try:
            self._inner.load_json(_to_json_bytes(raw))
        except AssertionError:
            # load_json asserts when the types were never set; surface that to the caller.
            raise
        except Exception as e:
            _logger.warning(f"hydrate load_json failed: {e}")
            return False
        return True

    async def _flush(self, *, must_succeed: bool = False) -> None:
        """Serialize the in-memory history and write it through ``write_history``.

        Args:
            must_succeed: When ``False`` a serialization or write failure is
                only logged. When ``True`` the failure is logged and re-raised
                so the caller learns that the history was not persisted.
        """
        if self._inner._snapshots_type_adapter is None:
            # Types are not registered yet, so the history cannot be dumped.
            return
        try:
            blob = self._inner.dump_json()
        except Exception as e:  # pragma: no cover
            _logger.warning(f"dump history failed: {e}")
            if must_succeed:
                raise
            return
        try:
            await self.write_history(self.trace_id, _from_json_bytes(blob))
        except Exception as e:
            _logger.warning(f"persist history failed: {e}")
            if must_succeed:
                raise
def _to_json_bytes(value: Any) -> bytes:
    """Normalize a stored history value into bytes suitable for ``load_json``.

    ``str`` is UTF-8 encoded, ``bytes`` / ``bytearray`` are returned as
    ``bytes``, and anything else (assumed to be a list/dict coming from
    JSONB) is serialized with ``json.dumps`` first.
    """
    if isinstance(value, str):
        return value.encode("utf-8")
    if not isinstance(value, (bytes, bytearray)):
        import json as _json

        value = _json.dumps(value).encode("utf-8")
    return bytes(value)
def _from_json_bytes(blob: bytes) -> Any:
    """Decode the UTF-8 bytes from ``dump_json`` and parse them into list/dict form for JSONB."""
    import json as _json

    text = blob.decode("utf-8")
    parsed = _json.loads(text)
    return parsed
def build_postgres_persistence(trace_id: str) -> PostgresStatePersistence:
    """Build a ``PostgresStatePersistence`` wired to the ``postgres_database`` actor.

    The actor handle is obtained through ``ray_actor_hook``; reads and writes
    go through its ``get_workflow_graph_state`` / ``upsert_workflow_graph_state``
    remote methods.

    Args:
        trace_id: Workflow trace id used as the storage key.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    async def _write(tid: str, history: Any) -> None:
        await postgres_database.upsert_workflow_graph_state.remote(tid, history)

    async def _read(tid: str) -> Optional[Any]:
        record = await postgres_database.get_workflow_graph_state.remote(tid)
        if record is None:
            return None
        # An ORM row exposes ``history``; a raw list/dict payload is returned as-is.
        # The attribute is used even when it is empty or None, so an empty
        # history is never replaced by the row object itself.
        return getattr(record, "history", record)

    return PostgresStatePersistence(
        trace_id=trace_id,
        write_history=_write,
        read_history=_read,
    )
+13 -1
View File
@@ -13,7 +13,7 @@
# limitations under the License.
from pydantic import BaseModel, Field, model_validator
from typing import Optional, Union, List, Dict, Any
from typing import Literal, Optional, Union, List, Dict, Any
from .model import LogicGate, WorkflowMetadata, WorkStepStatus, WorkflowStatus
from ulid import ULID
from datetime import datetime
@@ -61,10 +61,22 @@ class WorkflowStep(BaseModel):
default=None, description="前置依赖输出"
)
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
node: Literal["skill_individual", "consciousness_node"] = Field(
default="skill_individual",
description=(
"执行此步的节点类别:\n"
"- skill_individualtask 内现起一个专家子个体执行(一次性)\n"
"- consciousness_node:远程调用全局 ConsciousnessNode actor"
),
)
agent_id: Optional[str] = Field(
default=None,
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
)
require_approval: bool = Field(
default=False,
description="该步执行前是否需要人工审批;启用时会暂停工作流并通过 SSE 等待用户回执",
)
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
+478 -123
View File
@@ -12,166 +12,521 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Workflow 引擎:基于 ``pydantic_graph`` 的状态机驱动。
调度路径只剩两类节点:
- ``skill_individual``:在当前 ray task 进程内现起一个 ``SkillIndividual``
执行;用后即焚,不消耗 actor。
- ``consciousness_node``:远程调用全局 ConsciousnessNode actor 的
``working`` 方法(中途指导 / 审查类工作)。
每一步执行前可设置 ``require_approval=True`` 触发 ``HumanApproval`` 节点:
推送 SSE → ``await gwm.get_received`` 阻塞等用户回执 → 决策 continue/abort。
Graph 还接了 pydantic_graph 的 ``FullStatePersistence``,目前主要用于
节点边界自动 snapshotpostgres 持久化保留旧 ``upsert_workflow_context``
路径,跨进程 resume 留到后续)。
"""
from __future__ import annotations
import asyncio
import datetime
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Dict, List, Optional
import ray
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
from typing import Dict, Any, List
from pydantic import BaseModel, Field
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
from pydantic_graph.persistence import BaseStatePersistence
from pydantic_graph.persistence.in_mem import FullStatePersistence
from kilostar.core.work.workflow.model import WorkflowStatus
@ray.remote
def run_workflow_task(workflow_data: dict, trace_id: str):
from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.core.work.workflow.model import WorkflowStatus
import datetime
from pydantic import BaseModel
# ─── State / Deps ─────────────────────────────────────────────────────────
# State passed through graph nodes
class WorkflowGraphState(BaseModel):
trace_id: str
blackboard: Dict[str, Any]
work_link: List[Dict[str, Any]]
current_step_index: int = 0
status: str = "running"
logs: List[Dict[str, Any]] = []
async def save_context(state: WorkflowGraphState):
postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.upsert_workflow_context.remote(
state.trace_id,
workflow_pointer=state.current_step_index,
blackboard=state.blackboard,
work_link=state.work_link,
workflow_status={str(datetime.datetime.now()): state.status},
workflow_log=state.logs,
)
await postgres_database.update_workflow_status.remote(
state.trace_id, state.status
)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
await global_workflow_manager.put_received.remote(
state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
class WorkflowGraphState(BaseModel):
    """State shared by every node for the duration of one graph run."""

    # Identifier of the workflow run; passed to every deps call.
    trace_id: str
    # Step outputs, written by the step executor skeleton under each step's
    # ``outputs`` key (or ``step_<index>_result`` when the key is absent).
    blackboard: Dict[str, Any] = Field(default_factory=dict)
    # Ordered list of step dicts; ``current_step_index`` points into it.
    work_link: List[Dict[str, Any]] = Field(default_factory=list)
    current_step_index: int = 0
    # Set by the Finalize node to the status it was constructed with.
    final_status: str = WorkflowStatus.RUNNING.value
    # One ``{step_index: [timestamp, state, message]}`` entry per logged event.
    logs: List[Dict[str, Any]] = Field(default_factory=list)
    # Forwarded to the executors alongside the step.
    original_command: str = ""
    # Indexes of HumanApproval steps for which ``put_pending`` was already sent,
    # so a resumed run does not push the same prompt twice.
    # A list (not a set) keeps the pydantic_graph history JSON-friendly.
    approvals_notified: List[int] = Field(default_factory=list)
# Business-side execution entry point: receives the step dict plus the shared
# graph state and resolves to an ``(output_text, success)`` pair.
StepExecutor = Callable[
    [Dict[str, Any], "WorkflowGraphState"], Awaitable[tuple[str, bool]]
]
@dataclass
class WorkflowDeps:
    """Runtime dependencies of the graph nodes.

    Every piece of external IO the nodes perform goes through one of these
    callables, which keeps the nodes easy to exercise in tests. Each field is
    an awaitable callable. ``build_default_deps`` assembles the production
    versions by wrapping ray actor handles; unit tests may pass any async
    stand-in (for example ``AsyncMock``).

    ``run_skill`` / ``run_consciousness`` are the entry points that dispatch a
    step to a concrete executor, so the graph nodes stay pure logic and tests
    do not need a real SkillIndividual.
    """

    # Called as ``(trace_id, **snapshot_fields)`` to store the current context.
    upsert_workflow_context: Callable[..., Awaitable[Any]]
    # Called as ``(trace_id, status)``.
    update_workflow_status: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id, message)``; used for progress and final messages.
    put_pending: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id)``; awaited by the approval node for the user reply.
    get_received: Callable[[str], Awaitable[str]]
    # Executor used for ``skill_individual`` steps.
    run_skill: StepExecutor
    # Executor used for ``consciousness_node`` steps.
    run_consciousness: StepExecutor
# ─── 节点 ─────────────────────────────────────────────────────────────────
@dataclass
class Initialize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Graph entry node.

    Marks the workflow as RUNNING through ``update_workflow_status``, stores
    the initial context, and hands control to ``Dispatch``.
    """

    async def run(
        self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
    ) -> "Dispatch":
        running = WorkflowStatus.RUNNING.value
        trace_id = ctx.state.trace_id
        await ctx.deps.update_workflow_status(trace_id, running)
        await _persist_context(ctx, status=running)
        return Dispatch()
async def execute_step(state: WorkflowGraphState):
"""执行单一工作流节点逻辑"""
@dataclass
class Dispatch(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""读取当前 step,按 ``node`` / ``require_approval`` 字段选择下一节点。"""
async def run(
self,
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
) -> "HumanApproval | SkillStep | ConsciousnessStep | Finalize":
state = ctx.state
if state.current_step_index >= len(state.work_link):
state.status = WorkflowStatus.COMPLETED
return state
return Finalize(status=WorkflowStatus.COMPLETED.value)
step = state.work_link[state.current_step_index]
step.get("node", "")
action = step.get("action", "")
step_data = state.work_link[state.current_step_index]
if step_data.get("require_approval"):
return HumanApproval()
# 记录开始状态
node_type = step_data.get("node") or "skill_individual"
if node_type == "consciousness_node":
return ConsciousnessStep()
if node_type == "skill_individual":
return SkillStep()
# 未识别的 node 类型按失败处理(保守:不静默吞)
state.logs.append(
{
str(state.current_step_index): [
str(datetime.datetime.now()),
"working",
f"开始执行: {step.get('name', '未命名步骤')}",
"failed",
f"未知节点类型: {node_type}",
]
}
)
await save_context(state)
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
return Finalize(status=WorkflowStatus.FAILED.value)
try:
# TODO: 实际对接不同节点执行逻辑 (例如: control_node, agent 技能)
# 这里是简化版,向控制节点或指定 skill 发送指令
# ... 模拟执行逻辑 ...
await asyncio.sleep(2)
@dataclass
class HumanApproval(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""人工审批节点:暂停 graph,等用户通过 SSE 回执决策。
# 记录结果
state.blackboard[
step.get("outputs", f"step_{state.current_step_index}_result")
] = "Success execution of " + action
state.logs[-1][str(state.current_step_index)] = [
str(datetime.datetime.now()),
"completed",
f"成功: {action}",
]
回执约定(轻量协议,沿用现有 SSE 通道):
# 判断逻辑跳转
logic_gate = step.get("logic_gate")
if logic_gate and logic_gate.get("if_pass") == "exit":
state.status = WorkflowStatus.COMPLETED
else:
state.current_step_index += 1
- 含 ``approve`` / ``yes`` / ``ok`` 视为通过,回到 Dispatch 继续执行
- 其它(包括 ``reject``/``no``/``abort``)视为拒绝,工作流终止为 FAILED
"""
except Exception as e:
state.logs[-1][str(state.current_step_index)] = [
str(datetime.datetime.now()),
"failed",
str(e),
]
state.status = WorkflowStatus.FAILED
logic_gate = step.get("logic_gate")
if logic_gate and logic_gate.get("if_fail"):
fail_target = logic_gate.get("if_fail")
if "jump_to_step_" in fail_target:
target_step = int(fail_target.split("_")[-1]) - 1
state.current_step_index = target_step
state.status = WorkflowStatus.RUNNING
async def run(
self,
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
) -> "Dispatch | Finalize":
state = ctx.state
step_data = state.work_link[state.current_step_index]
# idempotent 推送:仅当本 step 还没通知过时才发 put_pending。
# 这样 resume 场景(HumanApproval 节点被重新进入)不会给前端发重复消息。
if state.current_step_index not in state.approvals_notified:
await ctx.deps.put_pending(
state.trace_id,
f"步骤 {state.current_step_index + 1} ({step_data.get('name', '')}) "
f"需要人工审批,请回复 approve / reject。",
)
state.approvals_notified.append(state.current_step_index)
await _persist_context(ctx, status=WorkflowStatus.HANGUP.value)
await save_context(state)
return state
reply = (await ctx.deps.get_received(state.trace_id) or "").strip().lower()
if any(token in reply for token in ("approve", "yes", "ok")):
# 把 require_approval 置否避免无限循环重新进 HumanApproval
step_data["require_approval"] = False
await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
return Dispatch()
async def _run():
postgres_database = ray_actor_hook("postgres_database").postgres_database
await postgres_database.update_workflow_status.remote(
trace_id, WorkflowStatus.RUNNING
state.logs.append(
{
str(state.current_step_index): [
str(datetime.datetime.now()),
"rejected",
f"用户拒绝执行该步骤: {reply!r}",
]
}
)
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
return Finalize(status=WorkflowStatus.FAILED.value)
state = WorkflowGraphState(
trace_id=trace_id,
blackboard={},
work_link=workflow_data.get("work_link", []),
)
await save_context(state)
# 简单的图执行驱动 (模拟 pydantic-ai.graph.run 行为,直至 Graph 库正式稳定)
while state.status == WorkflowStatus.RUNNING and state.current_step_index < len(
state.work_link
):
state = await execute_step(state)
@dataclass
class SkillStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""skill_individual 路径:当前进程内拉一个专家子个体执行该步。"""
await postgres_database.update_workflow_status.remote(trace_id, state.status)
global_workflow_manager = ray_actor_hook(
"global_workflow_manager"
).global_workflow_manager
async def run(
self,
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
) -> "Dispatch | Finalize":
return await _execute_step(ctx, executor=ctx.deps.run_skill)
@dataclass
class ConsciousnessStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Node for ``consciousness_node`` steps.

    Runs the shared step skeleton with ``deps.run_consciousness`` as the
    executor and returns whichever node that skeleton selects next.
    """

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        consciousness_executor = ctx.deps.run_consciousness
        next_node = await _execute_step(ctx, executor=consciousness_executor)
        return next_node
@dataclass
class Finalize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
"""收尾节点:写最终状态、推送最终 SSE,把 workflow 终态作为 graph 输出。"""
status: str
async def run(
self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
) -> End[str]:
ctx.state.final_status = self.status
await ctx.deps.update_workflow_status(ctx.state.trace_id, self.status)
msg = (
"工作流执行完成!"
if state.status == WorkflowStatus.COMPLETED
if self.status == WorkflowStatus.COMPLETED.value
else "工作流执行失败。"
)
await global_workflow_manager.put_received.remote(trace_id, msg)
await ctx.deps.put_pending(ctx.state.trace_id, msg)
return End(self.status)
asyncio.run(_run())
# ─── 内部 helper ─────────────────────────────────────────────────────────
async def _persist_context(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps], *, status: str
) -> None:
    """Hand the current graph state to ``deps.upsert_workflow_context``.

    Args:
        ctx: Graph run context carrying the shared state and the deps.
        status: Status string stored under a ``str(datetime.now())`` key in
            the ``workflow_status`` mapping.
    """
    snapshot = ctx.state
    stamped_status = {str(datetime.datetime.now()): status}
    await ctx.deps.upsert_workflow_context(
        snapshot.trace_id,
        workflow_pointer=snapshot.current_step_index,
        blackboard=snapshot.blackboard,
        work_link=snapshot.work_link,
        workflow_status=stamped_status,
        workflow_log=snapshot.logs,
    )
def _resolve_jump_index(fail_target: Any, step_count: int) -> Optional[int]:
    """Translate an ``if_fail`` target into a valid 0-based step index.

    Returns ``None`` when the target is not a ``jump_to_step_<n>`` string, when
    ``<n>`` is not an integer, or when step ``<n>`` (1-based) does not exist.
    """
    if not isinstance(fail_target, str) or "jump_to_step_" not in fail_target:
        return None
    try:
        step_number = int(fail_target.split("_")[-1])
    except ValueError:
        return None
    index = step_number - 1
    if 0 <= index < step_count:
        return index
    return None


async def _execute_step(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    *,
    executor: StepExecutor,
) -> "Dispatch | Finalize":
    """Shared execution skeleton for SkillStep / ConsciousnessStep.

    Handles what both node kinds have in common: logging, the progress
    message, the blackboard update and the logic gate. How the step itself is
    run is left to ``executor``.

    Args:
        ctx: Graph run context carrying the shared state and the deps.
        executor: Awaitable callable ``(step_data, state) -> (output, success)``.
            An exception raised by it is treated as a failed step whose output
            is the exception text.

    Returns:
        ``Finalize`` with COMPLETED when the step succeeds and either its
        ``logic_gate.if_pass`` is ``"exit"`` or it was the last step;
        ``Dispatch`` when the step succeeds otherwise, or when it fails and
        ``logic_gate.if_fail`` names an existing step to jump to;
        ``Finalize`` with FAILED for every other failure.
    """
    state = ctx.state
    step_data = state.work_link[state.current_step_index]

    state.logs.append(
        {
            str(state.current_step_index): [
                str(datetime.datetime.now()),
                "working",
                f"开始执行: {step_data.get('name', '未命名步骤')}",
            ]
        }
    )
    await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
    await ctx.deps.put_pending(
        state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
    )

    try:
        output_text, success = await executor(step_data, state)
    except Exception as e:  # executor raised -> take the failure branch
        output_text, success = str(e), False

    if success:
        output_key = step_data.get(
            "outputs", f"step_{state.current_step_index}_result"
        )
        state.blackboard[output_key] = output_text
        # Overwrite the "working" entry appended above with the final outcome.
        state.logs[-1][str(state.current_step_index)] = [
            str(datetime.datetime.now()),
            "completed",
            f"成功: {step_data.get('action', '')}",
        ]
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)

        logic_gate = step_data.get("logic_gate") or {}
        if logic_gate.get("if_pass") == "exit":
            return Finalize(status=WorkflowStatus.COMPLETED.value)

        state.current_step_index += 1
        if state.current_step_index >= len(state.work_link):
            return Finalize(status=WorkflowStatus.COMPLETED.value)
        return Dispatch()

    # Failure: an if_fail jump takes precedence over finishing as FAILED.
    state.logs[-1][str(state.current_step_index)] = [
        str(datetime.datetime.now()),
        "failed",
        output_text,
    ]
    logic_gate = step_data.get("logic_gate") or {}
    # Only honour targets that resolve to an existing step. A malformed or
    # out-of-range target used to raise here, or to make Dispatch run the wrong
    # step / report COMPLETED for a failed workflow.
    # NOTE(review): a target at or before the failing step can loop for as long
    # as that step keeps failing; no retry cap is enforced here.
    jump_index = _resolve_jump_index(
        logic_gate.get("if_fail"), len(state.work_link)
    )
    if jump_index is not None:
        state.current_step_index = jump_index
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()

    await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
    return Finalize(status=WorkflowStatus.FAILED.value)
# ─── Graph definition ─────────────────────────────────────────────────────
# Registers every node class of the workflow graph. The graph's end value is
# a ``str`` (``run_end_type=str``).
workflow_graph: Graph[WorkflowGraphState, WorkflowDeps, str] = Graph(
    nodes=[Initialize, Dispatch, HumanApproval, SkillStep, ConsciousnessStep, Finalize],
    state_type=WorkflowGraphState,
    run_end_type=str,
)
# ─── 默认执行器:把 step 派发到 SkillIndividual / ConsciousnessNode ────
async def _default_skill_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``skill_individual`` steps.

    A fresh ``SkillIndividual`` is built from the agent configuration found in
    the GSM snapshot for every call; nothing is kept after the call returns.

    Args:
        step_data: The step dict; must carry a non-empty ``agent_id``.
        state: Shared graph state; its ``blackboard`` and ``original_command``
            are forwarded to the individual.

    Returns:
        ``(output_text, True)`` on success (``"(empty)"`` when the individual
        produced no output), or ``(reason, False)`` when ``agent_id`` is
        missing or unknown.
    """
    # Validate before importing the worker stack so a misconfigured step is
    # rejected immediately and without any import cost.
    agent_id = step_data.get("agent_id")
    if not agent_id:
        return "skill_individual 步骤缺少 agent_id", False

    from kilostar.worker_individual.skill_individual import SkillIndividual
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    snapshot = await fetch_snapshot()
    agent_config = snapshot.individuals.get(agent_id)
    if not agent_config:
        return f"未找到 agent_id={agent_id} 的专家子个体", False

    individual = SkillIndividual(dict(agent_config))
    task_event = {
        "step": step_data,
        "blackboard": state.blackboard,
        "original_command": state.original_command,
    }
    result = await individual.run(task_event)
    output = (
        result.get("output", "") if isinstance(result, dict) else str(result)
    )
    return output or "(empty)", True
async def _default_consciousness_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``consciousness_node`` steps.

    Validates the step into a ``WorkflowStep``, sends it together with the
    original command to ``consciousness_node.working`` and maps the reply to
    an ``(output_text, success)`` pair. Only a ``ForWorkflow`` reply counts as
    success.
    """
    from kilostar.utils.ray_hook import ray_actor_hook
    from kilostar.core.individual.consciousness_node.template import (
        ForWorkflow,
        ForWorkflowInput,
    )
    from kilostar.core.work.workflow.workflow import WorkflowStep

    node_handle = ray_actor_hook("consciousness_node").consciousness_node
    validated_step = WorkflowStep.model_validate(step_data)
    request = ForWorkflowInput(
        workflow_step=validated_step,
        original_command=state.original_command,
    )
    reply = await node_handle.working.remote(request)

    if reply is None:
        return "ConsciousnessNode 返回 None", False
    if not isinstance(reply, ForWorkflow):
        return f"ConsciousnessNode 返回未知类型: {type(reply).__name__}", False
    return reply.output, True
def build_default_deps() -> WorkflowDeps:
    """Build the production ``WorkflowDeps``.

    Looks up the ``postgres_database`` and ``global_workflow_manager`` actors
    and wraps their ``.remote()`` methods in plain coroutines. The two step
    executors are the module's default skill / consciousness executors.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    db = ray_actor_hook("postgres_database").postgres_database
    manager = ray_actor_hook("global_workflow_manager").global_workflow_manager

    async def upsert_context(trace_id: str, **kwargs: Any) -> Any:
        outcome = await db.upsert_workflow_context.remote(trace_id, **kwargs)
        return outcome

    async def update_status(trace_id: str, status: str) -> Any:
        outcome = await db.update_workflow_status.remote(trace_id, status)
        return outcome

    async def push_pending(trace_id: str, message: str) -> Any:
        outcome = await manager.put_pending.remote(trace_id, message)
        return outcome

    async def fetch_received(trace_id: str) -> str:
        reply = await manager.get_received.remote(trace_id)
        return reply

    return WorkflowDeps(
        upsert_workflow_context=upsert_context,
        update_workflow_status=update_status,
        put_pending=push_pending,
        get_received=fetch_received,
        run_skill=_default_skill_executor,
        run_consciousness=_default_consciousness_executor,
    )
async def run_workflow_graph(
    workflow_data: Dict[str, Any],
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Run the workflow graph once in the current event loop.

    Args:
        workflow_data: Workflow dict; ``work_link`` supplies the steps and
            ``workflow_metadata.command`` the original command.
        trace_id: Workflow trace id.
        deps: Node dependencies; built with ``build_default_deps`` when omitted.
        persistence: Graph state persistence; a new ``FullStatePersistence``
            is used when omitted.

    Returns:
        The graph output, i.e. the final workflow status string.
    """
    active_deps = build_default_deps() if deps is None else deps
    store = FullStatePersistence() if persistence is None else persistence

    steps = list(workflow_data.get("work_link", []) or [])
    metadata = workflow_data.get("workflow_metadata") or {}
    command = metadata.get("command", "") or ""

    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=steps,
        original_command=command,
    )
    outcome = await workflow_graph.run(
        Initialize(),
        state=initial_state,
        deps=active_deps,
        persistence=store,
    )
    return outcome.output
async def resume_workflow_graph(
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Resume a workflow from persistence and run the remaining nodes to End.

    The caller is expected to have loaded ``persistence`` beforehand; this
    function only drives ``Graph.iter_from_persistence``.

    Args:
        trace_id: Workflow trace id (not read by this function).
        deps: Node dependencies; built with ``build_default_deps`` when omitted.
        persistence: Required state persistence to resume from.

    Returns:
        The data of the ``End`` node, or the RUNNING status value when the
        iteration finishes without yielding one.

    Raises:
        ValueError: If ``persistence`` is ``None``.
    """
    if persistence is None:
        raise ValueError("resume 必须显式传入 persistence")
    active_deps = build_default_deps() if deps is None else deps

    outcome: str = WorkflowStatus.RUNNING.value
    async with workflow_graph.iter_from_persistence(
        persistence, deps=active_deps
    ) as graph_run:
        async for visited in graph_run:
            if not isinstance(visited, End):
                continue
            outcome = visited.data
            break
    return outcome
@ray.remote
class WorkflowRunningEngine:
def __init__(
self, consciousness_node=None, control_node=None, regulatory_node=None
):
self.consciousness_node = consciousness_node
self.control_node = control_node
self.regulatory_node = regulatory_node
self.events_queue = asyncio.Queue()
def run_workflow_task(workflow_data: dict, trace_id: str):
"""workflow 的 ray task 入口:一次性执行,跑完即销毁。
async def put_event(self, event):
await self.events_queue.put(event)
生产路径下持久化交给 ``PostgresStatePersistence`` —— 即便进程崩溃,再 fire
一次相同 ``trace_id`` 的任务(或调 ``/workflow/{trace_id}/resume``)即可
续跑。同时为了支持 fresh start,先尝试 ``hydrate``
- hydrate 拿到内容 → 走 resume 路径
- hydrate 没拿到 → 走全新路径
async def run(self):
"""引擎循环提取事件"""
while True:
await self.events_queue.get()
await asyncio.sleep(1)
ray task 是新进程,contextvars 不会从 caller 传过来,所以入口先 bind 一次
``trace_id``,让节点内的日志自动带上它。
"""
from kilostar.utils.request_context import trace_id_scope
from kilostar.core.work.workflow.graph_persistence import (
build_postgres_persistence,
)
async def execute_workflow(self, workflow: KiloStarWorkflow):
# 这个方法可以由意识节点调用来提交一个完整的运行任务
workflow_dict = workflow.model_dump()
trace_id = workflow.trace_id
run_workflow_task.remote(workflow_dict, trace_id)
async def _entry() -> None:
with trace_id_scope(trace_id):
persistence = build_postgres_persistence(trace_id)
persistence.set_graph_types(workflow_graph)
recovered = False
try:
recovered = await persistence.hydrate()
except Exception: # pragma: no cover - 防御
recovered = False
if recovered:
await resume_workflow_graph(trace_id, persistence=persistence)
else:
await run_workflow_graph(
workflow_data, trace_id, persistence=persistence
)
asyncio.run(_entry())