feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -13,47 +13,129 @@
|
||||
# limitations under the License.
|
||||
|
||||
import ray
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from kilostar.core.postgres_database import PostgresDatabase
|
||||
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
from kilostar.core.global_state_machine.individual_manager import (
|
||||
GlobalIndividualManager,
|
||||
)
|
||||
from kilostar.core.global_state_machine.provider_manager import ProviderManager
|
||||
from kilostar.core.global_state_machine.skill_manager import GlobalSkillManager
|
||||
from kilostar.core.global_state_machine.tool_manager import GlobalToolManager
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import GSMSnapshot
|
||||
from kilostar.core.postgres_database import PostgresDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
class GlobalStateMachine:
|
||||
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual 四个注册表。
|
||||
"""全局状态机 Actor,统一持有 Provider/Tool/Skill/Individual/MCP/CustomToolset 注册表。
|
||||
|
||||
其它 Actor 通过 ``ray.get_actor("global_state_machine")`` 拿到本实例,
|
||||
再调用本类暴露的方法来读写各注册表,避免每个 Actor 各自维护一份状态。
|
||||
所有持久化都走 PostgresDatabase;启动时由 ``init_state_machine`` 一次性把
|
||||
Provider / MCP / Tool config / Custom toolset 拉到内存,保证后续读操作零等待。
|
||||
"""
|
||||
|
||||
def __init__(self, postgres_database: PostgresDatabase):
    """Create the per-domain managers and the empty in-memory registries.

    Args:
        postgres_database: Database handle; handed to ``ProviderManager``
            and kept on ``self.postgres_database`` for later calls.
    """
    import sys

    def _trace(message):
        # Progress marker written to stderr and flushed immediately.
        print(message, file=sys.stderr, flush=True)

    _trace("GSM __init__ START")
    # NOTE(review): nothing visible here precedes this marker — confirm it is still wanted.
    _trace(" event_dict done")
    self._global_provider_manager = ProviderManager(postgres_database)
    _trace(" provider_manager done")
    self._global_tool_manager = GlobalToolManager()
    _trace(" tool_manager done")
    self._global_skill_manager = GlobalSkillManager()
    _trace(" skill_manager done")
    self._global_individual_manager = GlobalIndividualManager()
    _trace(" individual_manager done")

    # In-memory registries; populated from the database by init_state_machine.
    self._mcp_servers: Dict[str, Dict[str, Any]] = {}
    self._tool_configs: Dict[str, Dict[str, Any]] = {}
    self._custom_toolsets: Dict[str, Dict[str, Any]] = {}

    # Snapshot bookkeeping: _publish_snapshot bumps the version and stores
    # a fresh object-store reference here.
    self._config_version: int = 0
    self._current_ref: Optional[ray.ObjectRef] = None

    self.postgres_database = postgres_database
    _trace("GSM __init__ DONE")
|
||||
|
||||
async def init_state_machine(self):
    """Load Provider/Individual/MCP/tool-config/custom-toolset state into memory.

    Reads every registry from ``self.postgres_database``, rebuilds the
    custom toolsets inside the tool manager, and publishes the first
    snapshot so readers have something to fetch.
    """
    await self._global_provider_manager.init_provider_register(
        self.postgres_database
    )
    await self._global_individual_manager.init_individual_register(
        self.postgres_database
    )

    # MCP servers, keyed by each row's "server_id".
    rows = await self.postgres_database.list_mcp_servers_db.remote()
    self._mcp_servers = {row["server_id"]: row for row in rows}

    # Tool configs: only the "config" field of each row is kept.
    cfg_rows = await self.postgres_database.list_tool_configs_db.remote()
    self._tool_configs = {row["tool_name"]: row["config"] for row in cfg_rows}

    # Custom toolsets, keyed by each row's "toolset_id".
    ts_rows = await self.postgres_database.list_custom_toolsets.remote()
    self._custom_toolsets = {row["toolset_id"]: row for row in ts_rows}

    # Let the tool manager assemble the custom toolsets right away.
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)

    # Publish the initial snapshot.
    self._publish_snapshot()
|
||||
|
||||
# ─── Snapshot 发布(Object Store 读路径) ────────────────────
|
||||
|
||||
def _build_snapshot(self) -> GSMSnapshot:
    """Pack the current in-memory state into a ``GSMSnapshot``.

    ``tool_funcs`` is flattened to ``{tool_name: callable}``; when the same
    name appears under several scopes, the scope iterated last wins.
    ``system_tools_by_scope`` keeps the per-scope tool-name lists so a
    client can regroup the flattened functions by scope.

    Returns:
        A snapshot stamped with the current ``_config_version``. Every
        top-level container is a shallow copy of the live registry.
    """
    tool_manager = self._global_tool_manager
    funcs_by_scope = tool_manager._tool_funcs

    scope_tool_names: Dict[str, List[str]] = {
        scope: list(funcs) for scope, funcs in funcs_by_scope.items()
    }
    merged_funcs: Dict[str, Any] = {}
    for funcs in funcs_by_scope.values():
        merged_funcs.update(funcs)

    mapper_copy = {
        scope: dict(classes)
        for scope, classes in tool_manager.tool_mapper.items()
    }

    return GSMSnapshot(
        version=self._config_version,
        providers=dict(self._global_provider_manager.provider_register),
        individuals=dict(self._global_individual_manager._individuals),
        mcp_servers=dict(self._mcp_servers),
        tool_configs=dict(self._tool_configs),
        custom_toolsets=dict(self._custom_toolsets),
        skills=dict(self._global_skill_manager.skill_mapper),
        tool_metadata=dict(tool_manager.tool_metadata),
        tool_funcs=merged_funcs,
        third_party_funcs=dict(tool_manager._third_party_funcs),
        tool_mapper=mapper_copy,
        system_tools_by_scope=scope_tool_names,
    )
|
||||
|
||||
def _publish_snapshot(self) -> None:
    """Bump the config version and store a fresh snapshot via ``ray.put``.

    The version is incremented first, so the snapshot that gets built
    carries the new version number. The previous reference held in
    ``_current_ref`` is simply replaced.
    """
    self._config_version = self._config_version + 1
    snapshot = self._build_snapshot()
    self._current_ref = ray.put(snapshot)
|
||||
|
||||
async def current_config_ref(self) -> Tuple[int, ray.ObjectRef]:
    """Return ``(version, ObjectRef)`` for the current snapshot.

    The reference is returned rather than the snapshot itself, so the
    caller resolves it on its own side. If nothing has been published
    yet, a snapshot is published first.
    """
    ref = self._current_ref
    if ref is None:
        self._publish_snapshot()
        ref = self._current_ref
    return self._config_version, ref
|
||||
|
||||
async def current_version(self) -> int:
    """Return only the current config version number.

    Readers use this to decide whether their cached snapshot is still
    current.
    """
    version = self._config_version
    return version
|
||||
|
||||
# ─── Provider ──────────────────────────────────────────────
|
||||
|
||||
async def add_provider_wrap(
|
||||
self,
|
||||
@@ -64,7 +146,7 @@ class GlobalStateMachine:
|
||||
provider_owner,
|
||||
):
|
||||
"""新增一个模型 Provider:内存注册 + 数据库持久化一并完成。"""
|
||||
return await self._global_provider_manager.add_provider(
|
||||
result = await self._global_provider_manager.add_provider(
|
||||
provider_type=provider_type,
|
||||
provider_title=provider_title,
|
||||
provider_url=provider_url,
|
||||
@@ -72,8 +154,9 @@ class GlobalStateMachine:
|
||||
provider_owner=provider_owner,
|
||||
postgres_database=self.postgres_database,
|
||||
)
|
||||
self._publish_snapshot()
|
||||
return result
|
||||
|
||||
# Provider Manager Methods
|
||||
def get_provider_list(self):
    """Return whatever the provider manager reports as its provider list."""
    manager = self._global_provider_manager
    return manager.get_provider_list()
|
||||
@@ -84,11 +167,14 @@ class GlobalStateMachine:
|
||||
|
||||
async def delete_provider(self, provider_title: str):
    """Delete a provider through the provider manager, then republish.

    Args:
        provider_title: Title of the provider to remove.

    Returns:
        Whatever ``ProviderManager.delete_provider`` returns.
    """
    result = await self._global_provider_manager.delete_provider(
        provider_title, self.postgres_database
    )
    # Republish so snapshot readers stop seeing the removed provider.
    self._publish_snapshot()
    return result
|
||||
|
||||
# ─── Tool / Toolset ────────────────────────────────────────
|
||||
|
||||
# Tool Manager Methods
|
||||
def get_tool_mapper(self):
    """Return the tool manager's full ``tool_mapper`` (the live object, not a copy)."""
    tool_manager = self._global_tool_manager
    return tool_manager.tool_mapper
|
||||
@@ -96,37 +182,152 @@ class GlobalStateMachine:
|
||||
def get_tool_list(self, agent_name: str):
    """Return the tools registered under ``agent_name`` merged over ``default``.

    Args:
        agent_name: Key into the tool manager's ``tool_mapper``.

    Returns:
        A new dict with the ``default`` entries first and the
        ``agent_name`` entries applied on top, so the latter win on
        name clashes. Missing keys contribute nothing.
    """
    mapper = self._global_tool_manager.tool_mapper
    tools = mapper.get(agent_name, {})
    default_tools = mapper.get("default", {})
    return {**default_tools, **tools}
|
||||
|
||||
def get_tool_categories(self):
    """Return tool metadata grouped into several views.

    Returns:
        A dict with keys ``system``, ``third_party``, ``by_category``
        (covering the categories ``system``, ``search``, ``mcp`` and
        ``other``) and ``all``.
    """
    tool_manager = self._global_tool_manager
    system_tools = tool_manager.get_system_tools()
    third_party_tools = tool_manager.get_third_party_tools()
    grouped = {
        category: tool_manager.get_tools_by_category(category)
        for category in ("system", "search", "mcp", "other")
    }
    return {
        "system": system_tools,
        "third_party": third_party_tools,
        "by_category": grouped,
        "all": tool_manager.get_all_tools(),
    }
|
||||
|
||||
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the tool manager's toolset list for ``scope`` (system plus custom)."""
    tool_manager = self._global_tool_manager
    return tool_manager.get_toolsets_for_scope(scope)
|
||||
|
||||
# ─── MCP Server Registry ───────────────────────────────────
|
||||
|
||||
async def add_mcp_server(self, server_id: str, config: Dict[str, Any]) -> bool:
    """Upsert an MCP server config in the database, then cache and republish it.

    Args:
        server_id: Key under which the server is stored.
        config: Server configuration passed to the database upsert.

    Returns:
        Always ``True``.
    """
    upsert = self.postgres_database.upsert_mcp_server
    stored = await upsert.remote(server_id, config)
    # The cached value is the record returned by the upsert, not the input.
    self._mcp_servers[server_id] = stored
    self._publish_snapshot()
    return True
|
||||
|
||||
def get_mcp_server(self, server_id: str) -> Optional[Dict[str, Any]]:
    """Return the cached MCP server record for ``server_id``, or ``None``."""
    registry = self._mcp_servers
    return registry.get(server_id)
|
||||
|
||||
def list_mcp_servers(self) -> List[Dict[str, Any]]:
    """Return every cached MCP server as a new dict carrying ``server_id``.

    The stored record is merged on top of ``{"server_id": <key>}``, so a
    ``server_id`` field inside the record takes precedence over the key.
    """
    servers: List[Dict[str, Any]] = []
    for server_id, record in self._mcp_servers.items():
        entry = {"server_id": server_id}
        entry.update(record)
        servers.append(entry)
    return servers
|
||||
|
||||
async def delete_mcp_server(self, server_id: str) -> bool:
    """Delete an MCP server from the database and from the in-memory registry.

    The cached entry is dropped and a snapshot is published regardless
    of what the database call returned.

    Returns:
        The database result coerced to ``bool``.
    """
    deleted = await self.postgres_database.delete_mcp_server_db.remote(server_id)
    if server_id in self._mcp_servers:
        del self._mcp_servers[server_id]
    self._publish_snapshot()
    return True if deleted else False
|
||||
|
||||
def get_mcp_server_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of the cached MCP server registry."""
    registry = self._mcp_servers
    return dict(registry)
|
||||
|
||||
# ─── Tool Config(Tavily API key 等)─────────────────────
|
||||
|
||||
async def set_tool_config(self, tool_name: str, config: Dict[str, Any]) -> bool:
    """Replace a tool's runtime config in the database and in memory.

    Args:
        tool_name: Tool whose config is replaced.
        config: New config passed to the database upsert.

    Returns:
        Always ``True``.
    """
    upsert = self.postgres_database.upsert_tool_config
    stored = await upsert.remote(tool_name, config)
    # Only the "config" field of the returned record is cached.
    self._tool_configs[tool_name] = stored["config"]
    self._publish_snapshot()
    return True
|
||||
|
||||
def get_tool_config(self, tool_name: str) -> Dict[str, Any]:
    """Return a shallow copy of a tool's config, or an empty dict if unset."""
    stored = self._tool_configs.get(tool_name, {})
    return dict(stored)
|
||||
|
||||
def list_tool_configs(self) -> Dict[str, Dict[str, Any]]:
    """Return a shallow copy of all tool configs.

    Values are returned as stored; no masking is applied here.
    """
    configs = self._tool_configs
    return dict(configs)
|
||||
|
||||
async def delete_tool_config(self, tool_name: str) -> bool:
    """Delete a tool's config from the database and from memory, then republish.

    The cached entry is dropped regardless of what the database call
    returned.

    Returns:
        The database result coerced to ``bool``.
    """
    deleted = await self.postgres_database.delete_tool_config_db.remote(tool_name)
    if tool_name in self._tool_configs:
        del self._tool_configs[tool_name]
    self._publish_snapshot()
    return True if deleted else False
|
||||
|
||||
# ─── Custom Toolset(用户自定义工具组)──────────────────
|
||||
|
||||
async def add_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: List[str],
    description: Optional[str] = None,
    owner_id: Optional[str] = None,
) -> Dict[str, Any]:
    """Create or update a custom toolset made only of third-party tools.

    Args:
        toolset_id: Identifier of the toolset.
        name: Name passed to the database upsert.
        tools: Tool names; each must be accepted by the tool manager's
            ``is_third_party_tool``.
        description: Optional description.
        owner_id: Optional owner identifier.

    Returns:
        The record returned by the database upsert.

    Raises:
        ValueError: If any name in ``tools`` is not a third-party tool.
            Raised before anything is written.
    """
    tool_manager = self._global_tool_manager
    rejected = []
    for tool_name in tools:
        if not tool_manager.is_third_party_tool(tool_name):
            rejected.append(tool_name)
    if rejected:
        raise ValueError(
            f"自定义工具组只允许包含第三方工具,以下不合法:{rejected}"
        )

    stored = await self.postgres_database.upsert_custom_toolset.remote(
        toolset_id=toolset_id,
        name=name,
        tools=list(tools),
        description=description,
        owner_id=owner_id,
    )
    self._custom_toolsets[toolset_id] = stored
    # Rebuild the manager's toolsets so the change is usable immediately.
    tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return stored
|
||||
|
||||
def list_custom_toolsets(self) -> List[Dict[str, Any]]:
    """Return all cached custom toolset records as a list."""
    return [record for record in self._custom_toolsets.values()]
|
||||
|
||||
def get_custom_toolset(self, toolset_id: str) -> Optional[Dict[str, Any]]:
    """Return the cached custom toolset record for ``toolset_id``, or ``None``."""
    registry = self._custom_toolsets
    return registry.get(toolset_id)
|
||||
|
||||
async def delete_custom_toolset(self, toolset_id: str) -> bool:
    """Delete a custom toolset from the database and from memory.

    The cached entry is dropped, the tool manager's custom toolsets are
    rebuilt, and a snapshot is published regardless of what the database
    call returned.

    Returns:
        The database result coerced to ``bool``.
    """
    deleted = await self.postgres_database.delete_custom_toolset.remote(toolset_id)
    if toolset_id in self._custom_toolsets:
        del self._custom_toolsets[toolset_id]
    self._global_tool_manager.rebuild_custom_toolsets(self._custom_toolsets)
    self._publish_snapshot()
    return True if deleted else False
|
||||
|
||||
# ─── Skill ────────────────────────────────────────────────
|
||||
|
||||
# Skill Manager Methods
|
||||
def add_skill(self, skill_name: str):
    """Register a skill with the skill manager, then republish the snapshot.

    Args:
        skill_name: Name of the skill to register.

    Returns:
        Whatever ``GlobalSkillManager.add_skill`` returns.
    """
    result = self._global_skill_manager.add_skill(skill_name)
    # Republish so snapshot readers see the new skill.
    self._publish_snapshot()
    return result
|
||||
|
||||
def get_skill_list(self):
    """Return whatever the skill manager reports as its skill list."""
    manager = self._global_skill_manager
    return manager.get_skill_list()
|
||||
|
||||
def remove_skill(self, skill_name: str):
    """Remove a skill through the skill manager, then republish the snapshot.

    Args:
        skill_name: Name of the skill to remove.

    Returns:
        Whatever ``GlobalSkillManager.remove_skill`` returns.
    """
    result = self._global_skill_manager.remove_skill(skill_name)
    # Republish so snapshot readers stop seeing the removed skill.
    self._publish_snapshot()
    return result
|
||||
|
||||
# ─── Individual ───────────────────────────────────────────
|
||||
|
||||
# Individual Manager Methods
|
||||
def add_individual(self, agent_id: str, config):
    """Register an individual's config, then republish the snapshot.

    Args:
        agent_id: Identifier of the individual.
        config: Config object handed to the individual manager as-is.

    Returns:
        Whatever ``GlobalIndividualManager.add_individual`` returns.
    """
    result = self._global_individual_manager.add_individual(agent_id, config)
    # Republish so snapshot readers see the new individual.
    self._publish_snapshot()
    return result
|
||||
|
||||
def get_individual(self, agent_id: str):
    """Return the individual manager's entry for ``agent_id``."""
    manager = self._global_individual_manager
    return manager.get_individual(agent_id)
|
||||
|
||||
def remove_individual(self, agent_id: str):
    """Remove an individual through its manager, then republish the snapshot.

    Args:
        agent_id: Identifier of the individual to remove.

    Returns:
        Whatever ``GlobalIndividualManager.remove_individual`` returns.
    """
    result = self._global_individual_manager.remove_individual(agent_id)
    # Republish so snapshot readers stop seeing the removed individual.
    self._publish_snapshot()
    return result
|
||||
|
||||
def list_individuals(self):
    """Return whatever the individual manager reports as its list of individuals."""
    manager = self._global_individual_manager
    return manager.list_individuals()
|
||||
|
||||
@@ -0,0 +1,184 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""GSM 快照对象与客户端拉取工具。
|
||||
|
||||
设计动机:把 GSM actor 内存里的"读路径配置"打包成可放进 Ray Object Store
|
||||
的不可变快照。读端不再走 actor RPC,而是 ``fetch_snapshot()`` 一次拿到全量
|
||||
当前配置,亚毫秒级共享内存读,绕开单 actor 的吞吐瓶颈。
|
||||
|
||||
GSM 仍然是 source of truth + 写入串行化器,但读路径解耦:
|
||||
|
||||
- 写入 → GSM 内存更新 → ``ray.put(snapshot)`` 拿到新 ObjectRef → 版本号 +1
|
||||
- 读取 → ``current_config_ref()`` 拿 (version, ref) → ``ray.get(ref)`` 直读
|
||||
|
||||
旧的 ``get_provider / get_individual / ...`` 接口保留不动,是低频路径的兜底;
|
||||
新代码(特别是 skill task 这种高并发热路径)应优先走 ``fetch_snapshot``。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, Callable, Dict, List, Optional, Tuple
|
||||
|
||||
import ray
|
||||
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("gsm_snapshot")
|
||||
|
||||
|
||||
@dataclass
class GSMSnapshot:
    """Snapshot of the GSM configuration, intended to be treated as read-only.

    Every field is expected to be picklable because the snapshot is stored
    with ``ray.put``. No ``FunctionToolset`` instances are stored;
    consumers rebuild toolsets on their side from ``tool_funcs``,
    ``third_party_funcs`` and the name lists below.

    NOTE(review): the dataclass is not declared ``frozen``, so immutability
    is by convention only.
    """

    # Config version this snapshot was built at.
    version: int = 0
    # Provider registry copied from the provider manager.
    providers: Dict[str, Provider] = field(default_factory=dict)
    # Individual registry copied from the individual manager.
    individuals: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # MCP server records keyed by server id.
    mcp_servers: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Tool runtime configs keyed by tool name.
    tool_configs: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Custom toolset records keyed by toolset id.
    custom_toolsets: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # Skill registry copied from the skill manager.
    skills: Dict[str, Tuple[str, str]] = field(default_factory=dict)
    # Per-tool metadata keyed by tool name.
    tool_metadata: Dict[str, Dict[str, Any]] = field(default_factory=dict)
    # System tool callables flattened to {tool_name: callable}.
    tool_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    # Third-party tool callables keyed by tool name.
    third_party_funcs: Dict[str, Callable[..., Any]] = field(default_factory=dict)
    # {scope: {tool_name: tool data class}}.
    tool_mapper: Dict[str, Dict[str, type]] = field(default_factory=dict)
    # {scope: [tool_name, ...]}: system tool names per scope. Consumers
    # combine these names with ``tool_funcs`` to rebuild per-scope toolsets
    # in their own process.
    system_tools_by_scope: Dict[str, List[str]] = field(default_factory=dict)
|
||||
|
||||
|
||||
# Process-local single-slot cache: the version and snapshot last stored by
# fetch_snapshot. Version -1 with snapshot None means nothing is cached.
_local_cache: Dict[str, Any] = {"version": -1, "snapshot": None}
# Held by fetch_snapshot while it checks and refreshes _local_cache.
_cache_lock = asyncio.Lock()
|
||||
|
||||
|
||||
async def fetch_snapshot(
    *,
    use_cache: bool = True,
    gsm_actor: Optional[Any] = None,
) -> GSMSnapshot:
    """Fetch the current GSM snapshot.

    With caching enabled:

    1. Ask the actor for its current version.
    2. If the single cached slot holds that version, return the cached
       snapshot.
    3. Otherwise get ``(version, ref)`` from the actor, resolve the ref,
       and store the result in the cache.

    Object references are resolved with ``await`` rather than ``ray.get``
    so the event loop is not blocked while the lock is held.

    Args:
        use_cache: When ``False``, skip the cache and always refetch; the
            cache is neither read nor updated.
        gsm_actor: Optional GSM actor handle. When omitted it is looked up
            via ``ray_actor_hook("global_state_machine")``.

    Returns:
        The snapshot object stored behind the actor's current reference.
    """
    if gsm_actor is None:
        from kilostar.utils.ray_hook import ray_actor_hook

        gsm_actor = ray_actor_hook("global_state_machine").global_state_machine

    if not use_cache:
        _, ref = await gsm_actor.current_config_ref.remote()
        return await ref

    async with _cache_lock:
        try:
            latest_version = await gsm_actor.current_version.remote()
        except Exception as exc:
            # Best effort: a failed version probe only forces a refetch,
            # but leave a trace instead of swallowing it silently.
            _logger.warning(
                f"current_version probe failed; refetching snapshot: {exc}"
            )
            latest_version = None

        cached = _local_cache.get("snapshot")
        if (
            latest_version is not None
            and cached is not None
            and _local_cache.get("version") == latest_version
        ):
            return cached

        version, ref = await gsm_actor.current_config_ref.remote()
        snapshot = await ref
        _local_cache["version"] = version
        _local_cache["snapshot"] = snapshot
        return snapshot
|
||||
|
||||
|
||||
def reset_local_cache() -> None:
    """Reset this process's snapshot cache to its empty state (version -1, no snapshot)."""
    _local_cache.update(version=-1, snapshot=None)
|
||||
|
||||
|
||||
# ─── 客户端 helper:从快照重建本地视图 ─────────────────────────────
|
||||
|
||||
|
||||
def build_toolsets_for_scope(
    snapshot: GSMSnapshot, scope: str
) -> List[Any]:
    """Assemble ``FunctionToolset`` objects for ``scope`` from a snapshot.

    - System toolsets: one per bucket among ``default`` and ``scope``,
      built from ``system_tools_by_scope`` names found in ``tool_funcs``.
    - Custom toolsets: one per ``custom_toolsets`` entry that resolves to
      at least one function in ``third_party_funcs``.

    Buckets and definitions that resolve to no functions are skipped, and
    a toolset that fails to construct is logged and skipped.

    Args:
        snapshot: Snapshot to read tool names and callables from.
        scope: Scope whose system bucket is added on top of ``default``.

    Returns:
        The list of toolsets, or an empty list when
        ``pydantic_ai.toolsets`` cannot be imported.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        _logger.warning("pydantic_ai.toolsets unavailable; cannot build toolsets")
        return []

    result: List[Any] = []

    # dict.fromkeys de-duplicates while keeping order, so scope == "default"
    # does not emit the default system toolset twice under the same id.
    for bucket in dict.fromkeys(("default", scope)):
        names = snapshot.system_tools_by_scope.get(bucket) or []
        funcs = [snapshot.tool_funcs[n] for n in names if n in snapshot.tool_funcs]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"system::{bucket}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build system toolset {bucket} failed: {e}")

    for toolset_id, defn in snapshot.custom_toolsets.items():
        names = defn.get("tools") or []
        funcs = [
            snapshot.third_party_funcs[n]
            for n in names
            if n in snapshot.third_party_funcs
        ]
        if not funcs:
            continue
        try:
            result.append(
                FunctionToolset(tools=funcs, id=f"custom::{toolset_id}")
            )
        except Exception as e:  # pragma: no cover - defensive
            _logger.error(f"build custom toolset {toolset_id} failed: {e}")

    return result
|
||||
@@ -1,35 +1,42 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pathlib
|
||||
import importlib
|
||||
import inspect
|
||||
from collections import defaultdict
|
||||
from typing import Any, Callable, Dict, List, Type
|
||||
|
||||
from kilostar.plugin.tool_plugin.base_tool import BaseToolData
|
||||
from typing import Dict, Type
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("tool_manager")
|
||||
|
||||
_SYSTEM_BUCKET = "system"
|
||||
|
||||
|
||||
class GlobalToolManager:
|
||||
"""工具注册表:扫描 ``kilostar/plugin/tool_plugin/`` 下所有 BaseToolData 子类,
|
||||
按 ``action_scope`` 分桶到 ``tool_mapper[scope][plugin_name]``;无 scope 的归入 ``default``。"""
|
||||
按 ``action_scope`` 打包成 ``FunctionToolset``。
|
||||
|
||||
三类 toolset:
|
||||
- **system**:``is_system=True`` 的工具,按 scope 分组
|
||||
- **custom**:用户自定义工具组(由 ``rebuild_custom_toolsets`` 动态构建)
|
||||
- **mcp**:由 ``mcp_helper`` 独立管理,不经过本类
|
||||
|
||||
``category="mcp"`` 的工具不会被本类管理。
|
||||
"""
|
||||
|
||||
tool_metadata: Dict[str, Dict[str, Any]]
|
||||
_tool_funcs: Dict[str, Dict[str, Callable]]
|
||||
_system_toolsets: Dict[str, Any]
|
||||
_custom_toolsets: Dict[str, Any]
|
||||
_third_party_funcs: Dict[str, Callable]
|
||||
tool_mapper: Dict[str, Dict[str, Type[BaseToolData]]]
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
self.tool_metadata = {}
|
||||
self._tool_funcs = defaultdict(dict)
|
||||
self._system_toolsets = {}
|
||||
self._custom_toolsets = {}
|
||||
self._third_party_funcs = {}
|
||||
self.tool_mapper = defaultdict(dict)
|
||||
|
||||
tool_plugin_dir = (
|
||||
@@ -39,21 +46,154 @@ class GlobalToolManager:
|
||||
return
|
||||
|
||||
for item in tool_plugin_dir.iterdir():
|
||||
if item.is_dir() and not item.name.startswith("__"):
|
||||
plugin_name = item.name
|
||||
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
|
||||
if not (item.is_dir() and not item.name.startswith("__")):
|
||||
continue
|
||||
plugin_name = item.name
|
||||
module_name = f"kilostar.plugin.tool_plugin.{plugin_name}"
|
||||
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
for name, obj in inspect.getmembers(module, inspect.isclass):
|
||||
if issubclass(obj, BaseToolData) and obj is not BaseToolData:
|
||||
# It's a valid tool class
|
||||
action_scopes = obj.model_fields.get("action_scope").default
|
||||
try:
|
||||
module = importlib.import_module(module_name)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to import tool plugin {plugin_name}: {e}")
|
||||
continue
|
||||
|
||||
if not action_scopes:
|
||||
self.tool_mapper["default"][plugin_name] = obj
|
||||
else:
|
||||
for scope in action_scopes:
|
||||
self.tool_mapper[scope][plugin_name] = obj
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load tool plugin {plugin_name}: {e}")
|
||||
tool_data_cls = self._find_tool_data_class(module)
|
||||
if tool_data_cls is None:
|
||||
continue
|
||||
|
||||
tool_func = getattr(module, plugin_name, None)
|
||||
if not callable(tool_func):
|
||||
logger.warning(
|
||||
f"Tool plugin '{plugin_name}' has no callable named "
|
||||
f"'{plugin_name}' in its module; skipped."
|
||||
)
|
||||
continue
|
||||
|
||||
action_scopes = (
|
||||
tool_data_cls.model_fields.get("action_scope").default or []
|
||||
)
|
||||
is_system = bool(tool_data_cls.model_fields.get("is_system").default)
|
||||
category_field = tool_data_cls.model_fields.get("category")
|
||||
category = (category_field.default if category_field else "other") or "other"
|
||||
|
||||
self.tool_metadata[plugin_name] = {
|
||||
"name": plugin_name,
|
||||
"is_system": is_system,
|
||||
"category": category,
|
||||
"action_scope": list(action_scopes),
|
||||
}
|
||||
|
||||
if category == "mcp":
|
||||
continue
|
||||
|
||||
scopes = [s for s in action_scopes if s] or ["default"]
|
||||
|
||||
if is_system:
|
||||
for scope in scopes:
|
||||
self._tool_funcs[scope][plugin_name] = tool_func
|
||||
self.tool_mapper[scope][plugin_name] = tool_data_cls
|
||||
else:
|
||||
self._third_party_funcs[plugin_name] = tool_func
|
||||
for scope in scopes:
|
||||
self.tool_mapper[scope][plugin_name] = tool_data_cls
|
||||
|
||||
self._build_system_toolsets()
|
||||
|
||||
def _build_system_toolsets(self) -> None:
|
||||
FunctionToolset = self._import_function_toolset()
|
||||
if FunctionToolset is None:
|
||||
return
|
||||
for scope, name_to_func in self._tool_funcs.items():
|
||||
if not name_to_func:
|
||||
continue
|
||||
try:
|
||||
self._system_toolsets[scope] = FunctionToolset(
|
||||
tools=list(name_to_func.values()),
|
||||
id=f"system::{scope}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to build system toolset {scope}: {e}")
|
||||
|
||||
def rebuild_custom_toolsets(self, custom_defs: Dict[str, Dict[str, Any]]) -> None:
    """Rebuild the custom ``FunctionToolset`` map from toolset definitions.

    Args:
        custom_defs: ``{toolset_id: definition}``; each definition's
            ``"tools"`` entry lists tool names. Names not present in
            ``_third_party_funcs`` are ignored.

    The previous map is replaced wholesale. Definitions that resolve to
    no functions are skipped, and a toolset that fails to construct is
    logged and skipped. If ``FunctionToolset`` cannot be imported the
    map is reset to empty.
    """
    toolset_cls = self._import_function_toolset()
    if toolset_cls is None:
        self._custom_toolsets = {}
        return

    rebuilt: Dict[str, Any] = {}
    for toolset_id, definition in custom_defs.items():
        requested = definition.get("tools") or []
        resolved = []
        for tool_name in requested:
            if tool_name in self._third_party_funcs:
                resolved.append(self._third_party_funcs[tool_name])
        if not resolved:
            continue
        try:
            rebuilt[toolset_id] = toolset_cls(
                tools=resolved,
                id=f"custom::{toolset_id}",
            )
        except Exception as e:
            logger.error(f"Failed to build custom toolset {toolset_id}: {e}")
    self._custom_toolsets = rebuilt
|
||||
|
||||
@staticmethod
def _import_function_toolset():
    """Return ``pydantic_ai.toolsets.FunctionToolset``, or ``None`` if it cannot be imported.

    A warning is logged when the import fails.
    """
    try:
        from pydantic_ai.toolsets import FunctionToolset
    except ImportError:
        logger.warning("pydantic_ai.toolsets unavailable")
        return None
    else:
        return FunctionToolset
|
||||
|
||||
@staticmethod
def _find_tool_data_class(module) -> Type[BaseToolData] | None:
    """Return the first strict ``BaseToolData`` subclass found in ``module``.

    Classes are examined in the order ``inspect.getmembers`` yields
    them; ``None`` is returned when no subclass is present.
    """
    candidates = (
        member
        for _name, member in inspect.getmembers(module, inspect.isclass)
        if issubclass(member, BaseToolData) and member is not BaseToolData
    )
    return next(candidates, None)
|
||||
|
||||
# ─── Toolset accessors ───
|
||||
|
||||
def get_system_toolset(self, scope: str) -> Any | None:
|
||||
return self._system_toolsets.get(scope)
|
||||
|
||||
def get_toolsets_for_scope(self, scope: str) -> List[Any]:
    """Return the system toolsets for ``default`` and ``scope`` plus every custom toolset.

    Args:
        scope: Scope whose system toolset is added after ``default``.

    Returns:
        A new list: the ``default`` system toolset (if built), the
        ``scope`` system toolset (if built and different from
        ``default``), then all custom toolsets.
    """
    result: List[Any] = []
    # dict.fromkeys de-duplicates while keeping order, so scope == "default"
    # does not list the default toolset twice.
    for s in dict.fromkeys(("default", scope)):
        ts = self._system_toolsets.get(s)
        if ts is not None:
            result.append(ts)
    result.extend(self._custom_toolsets.values())
    return result
|
||||
|
||||
# ─── Metadata accessors ───
|
||||
|
||||
def is_third_party_tool(self, tool_name: str) -> bool:
    """Return whether ``tool_name`` is registered in ``_third_party_funcs``."""
    registered = self._third_party_funcs
    return tool_name in registered
|
||||
|
||||
def get_tools_by_category(self, category: str) -> List[Dict[str, Any]]:
    """Return the metadata entries whose ``"category"`` equals ``category``."""
    matches: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("category") == category:
            matches.append(meta)
    return matches
|
||||
|
||||
def get_system_tools(self) -> List[Dict[str, Any]]:
    """Return the metadata entries whose ``"is_system"`` is exactly ``True``."""
    system_entries: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            system_entries.append(meta)
    return system_entries
|
||||
|
||||
def get_third_party_tools(self) -> List[Dict[str, Any]]:
    """Return metadata entries that are neither system tools nor in the ``mcp`` category."""
    third_party: List[Dict[str, Any]] = []
    for meta in self.tool_metadata.values():
        if meta.get("is_system") is True:
            continue
        if meta.get("category") == "mcp":
            continue
        third_party.append(meta)
    return third_party
|
||||
|
||||
def get_all_tools(self) -> List[Dict[str, Any]]:
    """Return every tool metadata entry as a list."""
    return [meta for meta in self.tool_metadata.values()]
|
||||
|
||||
# 兼容旧接口
|
||||
def get_non_system_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias: return the same list as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
|
||||
|
||||
def get_personal_tools(self) -> List[Dict[str, Any]]:
    """Legacy alias: return the same list as ``get_third_party_tools``."""
    third_party = self.get_third_party_tools()
    return third_party
|
||||
|
||||
@@ -29,6 +29,7 @@ from kilostar.core.global_state_machine.global_state_machine import GlobalStateM
|
||||
from kilostar.core.global_state_machine.model_provider.base_provider import Provider
|
||||
from kilostar.adapter.model_adapter.agent_factory import AgentFactory
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -45,22 +46,19 @@ class ConsciousnessNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【意识节点 (Consciousness Node)】。\n"
|
||||
"你是系统的'高级规划师'和'架构师',负责处理监控节点分配过来的复杂任务。\n"
|
||||
"你的主要工作场景包括:\n"
|
||||
"1. 拆解任务 (Workflow Generation):结合用户的原始命令和提供的模板,生成严谨、可执行的工作流 (kilostarWorkflow),并将其输出为 ForWorkflowEngine 格式。拆解时步骤应清晰连贯。\n"
|
||||
"2. 中途指导 (Workflow Execution):在工作流执行中,如果某一步骤指派给你,你需要对控制节点的结果进行分析或提供下一步的指导,输出 ForWorkflow 格式。\n"
|
||||
"3. 总结报告 (regulatory Report):在整个工作流执行完毕后,你需要对整体流程、各个控制节点的执行情况进行审查,并生成一份技术性的总结报告,输出 ForregulatoryNode 格式。\n"
|
||||
"请确保所有的思考和生成过程符合逻辑,严密且高质量。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("consciousness_node", locale=locale)
|
||||
output_type = Union[ForregulatoryNode, ForWorkflow, ForWorkflowEngine]
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -72,6 +70,7 @@ class ConsciousnessNode:
|
||||
deps_type=ConsciousnessNodeDeps,
|
||||
agent_name="consciousness_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
@@ -95,6 +94,13 @@ class ConsciousnessNode:
|
||||
开始进行工作流设计的交互过程(与用户通过 SSE 进行确认或直接生成)
|
||||
目前简化为:直接根据 command 拆解并构建工作流,然后提交执行。
|
||||
"""
|
||||
from kilostar.utils.request_context import trace_id_scope
|
||||
|
||||
# 进入工作流域:把 trace_id 绑到 contextvars,本协程所有日志自动带上它
|
||||
with trace_id_scope(trace_id):
|
||||
await self._do_start_workflow_design(trace_id, command)
|
||||
|
||||
async def _do_start_workflow_design(self, trace_id: str, command: str):
|
||||
self.logger.info(
|
||||
f"ConsciousnessNode: 开始为 trace_id {trace_id} 设计工作流。原始命令:{command}"
|
||||
)
|
||||
@@ -116,11 +122,11 @@ class ConsciousnessNode:
|
||||
original_command=command, available_skills=available_skills
|
||||
)
|
||||
|
||||
# 通知 SSE 正在生成图结构
|
||||
# 通知 SSE 正在生成图结构(pending 队列:节点端写入 → API SSE 读取,单向下行)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "正在为您构建并规划工作流任务节点,请稍候..."
|
||||
)
|
||||
|
||||
@@ -131,17 +137,17 @@ class ConsciousnessNode:
|
||||
workflow = result.workflow
|
||||
workflow.trace_id = trace_id
|
||||
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "工作流构建完成,即将开始执行!"
|
||||
)
|
||||
|
||||
# 将生成的完整工作流提交执行
|
||||
workflow_engine = ray_actor_hook(
|
||||
"workflow_running_engine"
|
||||
).workflow_running_engine
|
||||
await workflow_engine.execute_workflow.remote(workflow)
|
||||
# 直接以 ray task 形式 fire workflow,不再经过 WorkflowRunningEngine 这层中转:
|
||||
# workflow 是一次性、有头有尾的执行,task 语义比常驻 actor 更贴。
|
||||
from kilostar.core.work.workflow.workflow_engine import run_workflow_task
|
||||
|
||||
run_workflow_task.remote(workflow.model_dump(), trace_id)
|
||||
else:
|
||||
await global_workflow_manager.put_received.remote(
|
||||
await global_workflow_manager.put_pending.remote(
|
||||
trace_id, "很抱歉,工作流生成失败。"
|
||||
)
|
||||
await postgres_database.update_workflow_status.remote(trace_id, "failed")
|
||||
|
||||
@@ -22,6 +22,7 @@ from kilostar.core.individual.control_node.template import (
|
||||
ForWorkflowInput,
|
||||
ControlNodeDeps,
|
||||
)
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -44,6 +45,8 @@ class ControlNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到Control的属性内
|
||||
@@ -54,25 +57,21 @@ class ControlNode:
|
||||
global_state_machine: 全局状态机
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
locale: 语言代码(zh/en),控制system prompt语言
|
||||
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
|
||||
"你是系统的'执行者'和'车间主任',专门负责执行工作流中分配给你的具体子任务。\n"
|
||||
"你的工作职责是:\n"
|
||||
"1. 仔细分析分配给你的工作流步骤 (workflow_step) 的目标和要求。\n"
|
||||
"2. 运用你被分配的工具 (如有) 或者依靠自身的知识和推理能力,精准、高效地完成该任务。\n"
|
||||
"3. 将执行的结果、产生的数据或者具体的输出,严格按照 ForWorkflow 格式返回。\n"
|
||||
"请注意:你的输出应当具体、实用,直接提供任务所要求的结果,不要做过多无关的寒暄。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("control_node", locale=locale)
|
||||
output_type = ForWorkflow
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -84,6 +83,7 @@ class ControlNode:
|
||||
deps_type=ControlNodeDeps,
|
||||
agent_name="control_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
|
||||
@@ -24,6 +24,7 @@ from kilostar.core.individual.regulatory_node.template import (
|
||||
MessageResponse
|
||||
)
|
||||
from pydantic_ai import RunContext, Agent
|
||||
from kilostar.utils.i18n import agent_prompt
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -46,6 +47,8 @@ class RegulatoryNode:
|
||||
provider_title: str,
|
||||
model_id: str,
|
||||
tools_list: list[str] = None,
|
||||
toolsets=None,
|
||||
locale: str | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
create_agent方法,将agent对象装配到regulatoryNode的属性内
|
||||
@@ -56,24 +59,21 @@ class RegulatoryNode:
|
||||
provider_title: 供应商名
|
||||
model_id: 模型id
|
||||
tools_list: 工具列表
|
||||
locale: 语言代码(zh/en),控制system prompt语言
|
||||
Returns:
|
||||
无返回
|
||||
"""
|
||||
system_prompt: str = (
|
||||
"你叫kilostar,是一个多智能体AI助手系统中的【监控节点 (regulatory Node)】。\n"
|
||||
"你是系统的'前台接待'和'大脑皮层',负责接收用户的初始请求或工作流的最终报告。\n"
|
||||
"你的核心职责是进行【意图识别与路由】。请仔细阅读用户的请求:\n"
|
||||
"1. 如果用户只是进行简单的问候、闲聊或查询非常基础的信息,请直接生成友好的回复,使用 ForUser 格式。\n"
|
||||
"2. 如果用户提出的是复杂任务(如需要编写代码、多步骤规划、数据处理等),请务必将其判定为需要工作流处理的任务,"
|
||||
" 并使用 ForConsciousnessNode 格式将其移交意识节点处理。\n"
|
||||
"3. 如果你收到的是 TerminationMessage(代表工作流已完成并生成了报告),请将报告内容转化为友好的面向用户的回复,使用 ForUser 格式。\n"
|
||||
"请保持冷静、专业,并严格遵循上述路由规则。"
|
||||
)
|
||||
system_prompt: str = agent_prompt("regulatory_node", locale=locale)
|
||||
output_type = Union[MessageResponse]
|
||||
from kilostar.utils.get_tool import load_tools_from_list
|
||||
provider: Provider = await global_state_machine.get_provider.remote(
|
||||
provider_title
|
||||
)
|
||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
||||
|
||||
# 走 Object Store 快照而不是 actor RPC:高频读路径不再受单 actor 串行限制
|
||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
||||
provider: Provider = snapshot.providers.get(provider_title)
|
||||
if provider is None:
|
||||
from kilostar.utils.i18n import t
|
||||
raise ValueError(t("provider_not_registered", locale=locale, provider_title=provider_title))
|
||||
agent_factory = AgentFactory()
|
||||
|
||||
callables = load_tools_from_list(tools_list)
|
||||
@@ -85,6 +85,7 @@ class RegulatoryNode:
|
||||
deps_type=RegulatoryNodeDeps,
|
||||
agent_name="regulatory_node",
|
||||
tools=callables,
|
||||
toolsets=toolsets,
|
||||
)
|
||||
|
||||
@self.agent.system_prompt
|
||||
@@ -112,16 +113,15 @@ class RegulatoryNode:
|
||||
)
|
||||
return prompt
|
||||
|
||||
async def working(self, payload: MessageRequest) -> str:
|
||||
"""working方法,是节点唯一的调用方法,对于_run函数的结果进行判断并实现最终回复
|
||||
async def working(self, payload: MessageRequest) -> Union[MessageResponse, None]:
|
||||
"""working方法,是节点唯一的调用方法,对_run函数的结果进行判断并返回最终回复
|
||||
Args:
|
||||
payload: 消息载荷,包含所有信息
|
||||
|
||||
Returns:
|
||||
str,监控节点对于用户的回复
|
||||
MessageResponse 或 None,监控节点对用户的结构化回复
|
||||
"""
|
||||
await self._run(payload)
|
||||
return ""
|
||||
return await self._run(payload)
|
||||
|
||||
async def _run(
|
||||
self, payload: MessageRequest
|
||||
@@ -140,7 +140,8 @@ class RegulatoryNode:
|
||||
deps=deps,)
|
||||
response: MessageResponse = agent_response.output
|
||||
response.platform = platform
|
||||
response.platform_id = MessageRequest.platform_id
|
||||
response.platform_id = payload.platform_id
|
||||
return response
|
||||
except:
|
||||
pass
|
||||
except Exception as e:
|
||||
self.logger.exception(f"RegulatoryNode._run failed: {e}")
|
||||
return None
|
||||
@@ -49,7 +49,7 @@ class MessageRequest(RequestModel):
|
||||
MessageRequest类
|
||||
任何消息渠道向regulatory_node发送消息请求的模型
|
||||
"""
|
||||
platform: Literal["client"]
|
||||
platform: Literal["client", "onebot"]
|
||||
user_name: str
|
||||
platform_id: Optional[str]
|
||||
message: str
|
||||
@@ -59,6 +59,6 @@ class MessageResponse(RegulatoryNodeResponse):
|
||||
MessageResponse类
|
||||
regulatory_node回复的模型
|
||||
"""
|
||||
platform: Optional[Literal["client"]] = Field(description="系统自动填入的platform")
|
||||
platform: Optional[Literal["client", "onebot"]] = Field(description="系统自动填入的platform")
|
||||
platform_id: Optional[str] = Field(description="系统自动填入的platform_id")
|
||||
reply_message: str = Field(...,description="模型回复的消息")
|
||||
|
||||
@@ -23,12 +23,16 @@ from kilostar.core.postgres_database.model.individual import (
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryRegister,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
# 兼容旧代码的别名
|
||||
Provider = ProviderModel
|
||||
@@ -49,9 +53,13 @@ __all__ = [
|
||||
"SpecialIndividualModel",
|
||||
"Workflow",
|
||||
"WorkflowContextModel",
|
||||
"WorkflowGraphStateModel",
|
||||
"ChatHistoryRegister",
|
||||
"ChatHistoryMessage",
|
||||
"SystemNodeConfigModel",
|
||||
"SystemNodeConfig",
|
||||
"MCPServerModel",
|
||||
"ToolConfigModel",
|
||||
"CustomToolsetModel",
|
||||
"AgentType",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class CustomToolsetModel(BaseDataModel):
    """ORM row describing a user-defined toolset (a named bundle of tool plugins).

    ``tools`` stores a list of tool names. Per the original note these are
    directory names under ``plugin/tool_plugin/`` and should only reference
    tools that are neither system nor MCP tools; this model does not enforce
    that restriction itself.
    """

    __tablename__ = "custom_toolset"

    # Identity and descriptive fields.
    toolset_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False)
    description: Mapped[Optional[str]] = mapped_column(Text)
    owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)

    # JSONB list of tool names; ``default=list`` gives every new row its own list.
    tools: Mapped[List[str]] = mapped_column(
        JSONB,
        default=list,
        comment="工具名列表,仅允许非 system/非 mcp 的工具",
    )

    # Timestamps are filled in by the database (``now()``); ``updated_at`` is
    # also refreshed on UPDATE statements emitted through SQLAlchemy.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class MCPServerModel(BaseDataModel):
    """ORM row for one registered MCP server and its connection settings.

    The original note lists stdio / sse / http as the supported transports;
    the ``transport`` column itself is a free-form short string.
    """

    __tablename__ = "mcp_server"

    # Identity.
    server_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    name: Mapped[str] = mapped_column(String(100), nullable=False)

    # Connection settings; ``command``/``args`` and ``url`` are all optional
    # at the schema level.
    transport: Mapped[str] = mapped_column(String(16), nullable=False)
    command: Mapped[Optional[str]] = mapped_column(String(255))
    args: Mapped[list] = mapped_column(JSONB, default=list)
    url: Mapped[Optional[str]] = mapped_column(String(500))
    tool_prefix: Mapped[Optional[str]] = mapped_column(String(64))

    # Environment variables for the server; the column comment states that
    # sensitive entries are stored Fernet-encrypted.
    env: Mapped[dict] = mapped_column(
        JSONB,
        default=dict,
        comment="敏感字段已 Fernet 加密",
    )

    # Database-maintained timestamps.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class ToolConfigModel(BaseDataModel):
    """ORM row holding the runtime configuration of a single tool.

    The original note gives a Tavily API key as an example and states that
    sensitive entries inside ``config`` are stored Fernet-encrypted.
    """

    __tablename__ = "tool_config"

    # One row per tool, keyed by the tool's name.
    tool_name: Mapped[str] = mapped_column(String(100), primary_key=True)

    # Arbitrary JSON configuration; defaults to a fresh empty dict.
    config: Mapped[dict] = mapped_column(JSONB, default=dict)

    # Database-maintained timestamps.
    created_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
@@ -79,3 +79,28 @@ class WorkflowContextModel(BaseDataModel):
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class WorkflowGraphStateModel(BaseDataModel):
    """Storage table for the pydantic_graph persistence blob of a workflow.

    Kept separate from ``workflow_context``: per the original note, that table
    serves user-readable business display, while this one serves the graph
    engine's own state recovery. There is one row per ``trace_id`` and the
    whole history is stored as JSONB.
    """

    __tablename__ = "workflow_graph_state"

    trace_id: Mapped[str] = mapped_column(
        String(64),
        primary_key=True,
        comment="对应的工作流 Trace ID",
    )

    # Serialized history list; defaults to a fresh empty list.
    history: Mapped[list] = mapped_column(
        JSONB,
        default=list,
        comment="pydantic_graph FullStatePersistence.history 的 JSON 序列化",
    )

    # NOTE(review): annotated as ``str`` although the column type is DateTime;
    # left unchanged to match the sibling model in this file.
    created_at: Mapped[str] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )
    updated_at: Mapped[str] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
|
||||
class CustomToolsetDatabase:
    """DAO for user-defined toolsets.

    ``tools`` is a list of tool names; per the original note the business
    layer is responsible for only putting non-system / non-MCP tools in it.
    """

    def __init__(self, async_session_maker):
        # Factory producing async sessions; a new session is opened per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict (``tools`` is copied)."""
        return dict(
            toolset_id=row.toolset_id,
            name=row.name,
            description=row.description,
            owner_id=row.owner_id,
            tools=list(row.tools or []),
        )

    @database_exception
    async def upsert(
        self,
        toolset_id: str,
        name: str,
        tools: List[str],
        description: Optional[str] = None,
        owner_id: Optional[str] = None,
    ) -> Dict[str, Any]:
        """Create the toolset, or overwrite every field of an existing one.

        Returns the stored row as a dict after commit and refresh.
        """
        # Fields written identically whether the row is new or already exists.
        values = {
            "name": name,
            "description": description,
            "owner_id": owner_id,
            "tools": list(tools),
        }
        async with self.async_session_maker() as session:
            query = select(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                record = CustomToolsetModel(toolset_id=toolset_id, **values)
                session.add(record)
            else:
                for attr, value in values.items():
                    setattr(record, attr, value)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
        """Return the toolset with this id as a dict, or ``None`` if absent."""
        async with self.async_session_maker() as session:
            query = select(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                return None
            return self._to_dict(record)

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored toolset as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(CustomToolsetModel))
            records = outcome.scalars().all()
            return list(map(self._to_dict, records))

    @database_exception
    async def delete(self, toolset_id: str) -> bool:
        """Delete the toolset; return ``True`` if at least one row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(CustomToolsetModel).where(
                CustomToolsetModel.toolset_id == toolset_id
            )
            outcome = await session.execute(removal)
            await session.commit()
            affected = outcome.rowcount
            return affected > 0
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class MCPServerDatabase:
    """DAO for MCP server configurations.

    ``env`` is passed through ``encrypt_dict_secrets`` before it is written
    and through ``decrypt_dict_secrets`` when a row is converted to a dict.
    """

    def __init__(self, async_session_maker):
        # Factory producing async sessions; a new session is opened per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``env`` decrypted."""
        return dict(
            server_id=row.server_id,
            name=row.name,
            transport=row.transport,
            command=row.command,
            args=row.args or [],
            url=row.url,
            tool_prefix=row.tool_prefix,
            env=decrypt_dict_secrets(row.env or {}),
        )

    @database_exception
    async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Insert or update the server identified by ``server_id``.

        On update, ``name`` and ``transport`` keep their stored value when the
        key is missing from ``config``; every other field is overwritten (a
        missing key resets it to ``None`` / empty). On insert, ``name``
        defaults to ``server_id`` and ``transport`` to ``"stdio"``.
        """
        env = encrypt_dict_secrets(config.get("env") or {})
        # Fields that are replaced wholesale on both insert and update.
        replaced = {
            "command": config.get("command"),
            "args": config.get("args") or [],
            "url": config.get("url"),
            "tool_prefix": config.get("tool_prefix"),
            "env": env,
        }
        async with self.async_session_maker() as session:
            query = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                record = MCPServerModel(
                    server_id=server_id,
                    name=config.get("name", server_id),
                    transport=config.get("transport", "stdio"),
                    **replaced,
                )
                session.add(record)
            else:
                record.name = config.get("name", record.name)
                record.transport = config.get("transport", record.transport)
                for attr, value in replaced.items():
                    setattr(record, attr, value)
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
        """Return the server config as a dict, or ``None`` if it does not exist."""
        async with self.async_session_maker() as session:
            query = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                return None
            return self._to_dict(record)

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return every stored server config as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(MCPServerModel))
            records = outcome.scalars().all()
            return list(map(self._to_dict, records))

    @database_exception
    async def delete(self, server_id: str) -> bool:
        """Delete the server config; return ``True`` if a row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
            outcome = await session.execute(removal)
            await session.commit()
            affected = outcome.rowcount
            return affected > 0
|
||||
@@ -17,10 +17,37 @@ from typing import List
|
||||
from kilostar.core.postgres_database.model.provider import ProviderModel
|
||||
from sqlalchemy import select
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.utils.crypto import (
|
||||
CryptoError,
|
||||
decrypt_secret,
|
||||
encrypt_secret,
|
||||
is_encrypted,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("provider_dao")
|
||||
|
||||
|
||||
def _decrypt_apikey(value):
    """Return the decrypted form of a stored provider API key.

    Falsy values and values that ``is_encrypted`` does not recognise are
    returned unchanged. If decryption raises ``CryptoError`` the failure is
    logged and the original (still encrypted) value is returned.
    """
    if value and is_encrypted(value):
        try:
            return decrypt_secret(value)
        except CryptoError as e:
            # Best effort: report the failure and fall through to the raw value.
            logger.error(f"Provider apikey 解密失败: {e}")
    return value
|
||||
|
||||
|
||||
def _encrypt_apikey(value):
    """Encrypt a provider API key unless it is empty or already encrypted.

    Falsy values and values that ``is_encrypted`` already recognises are
    returned unchanged, so calling this twice does not double-encrypt.
    """
    # ``is_encrypted`` is only consulted for truthy values (short-circuit).
    needs_encryption = bool(value) and not is_encrypted(value)
    return encrypt_secret(value) if needs_encryption else value
|
||||
|
||||
|
||||
class ProviderDatabase:
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改;``provider_apikey`` 透明 Fernet 加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
@@ -37,11 +64,10 @@ class ProviderDatabase:
|
||||
provider_id=provider.provider_id,
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_apikey=_decrypt_apikey(provider.provider_apikey),
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner,
|
||||
provider_status=provider.provider_status,
|
||||
is_active=provider.is_active,
|
||||
)
|
||||
for provider in results
|
||||
@@ -50,7 +76,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel。"""
|
||||
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = ProviderModel(**kwargs)
|
||||
session.add(provider)
|
||||
@@ -67,7 +95,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def update_provider(self, provider_id: str, **kwargs) -> None:
|
||||
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象。"""
|
||||
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(ProviderModel, provider_id)
|
||||
if provider is not None:
|
||||
@@ -76,5 +106,7 @@ class ProviderDatabase:
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
await session.refresh(provider)
|
||||
# 解密返回,方便上游使用
|
||||
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
|
||||
return provider
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class ToolConfigDatabase:
    """DAO for per-tool runtime configuration.

    ``config`` goes through ``encrypt_dict_secrets`` on write and
    ``decrypt_dict_secrets`` on read; which keys count as sensitive is decided
    by those helpers (the original note mentions key/token/secret/password).
    """

    def __init__(self, async_session_maker):
        # Factory producing async sessions; a new session is opened per call.
        self.async_session_maker = async_session_maker

    @staticmethod
    def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
        """Convert an ORM row into a plain dict with ``config`` decrypted."""
        return dict(
            tool_name=row.tool_name,
            config=decrypt_dict_secrets(row.config or {}),
        )

    @database_exception
    async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
        """Store ``config`` for ``tool_name``, replacing any existing config."""
        sealed = encrypt_dict_secrets(config or {})
        async with self.async_session_maker() as session:
            query = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                record = ToolConfigModel(tool_name=tool_name, config=sealed)
                session.add(record)
            else:
                record.config = sealed
            await session.commit()
            await session.refresh(record)
            return self._to_dict(record)

    @database_exception
    async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
        """Return the tool's config as a dict, or ``None`` if none is stored."""
        async with self.async_session_maker() as session:
            query = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            record = (await session.execute(query)).scalar_one_or_none()
            if record is None:
                return None
            return self._to_dict(record)

    @database_exception
    async def list_all(self) -> List[Dict[str, Any]]:
        """Return the config of every tool as a list of dicts."""
        async with self.async_session_maker() as session:
            outcome = await session.execute(select(ToolConfigModel))
            records = outcome.scalars().all()
            return list(map(self._to_dict, records))

    @database_exception
    async def delete(self, tool_name: str) -> bool:
        """Delete the tool's config; return ``True`` if a row was removed."""
        async with self.async_session_maker() as session:
            removal = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
            outcome = await session.execute(removal)
            await session.commit()
            affected = outcome.rowcount
            return affected > 0
|
||||
@@ -17,6 +17,7 @@ from typing import List, Optional
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
@@ -101,3 +102,58 @@ class WorkflowDatabase:
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
|
||||
|
||||
@database_exception
|
||||
async def upsert_workflow_graph_state(
    self, trace_id: str, history: list
) -> WorkflowGraphStateModel:
    """Persist the JSON view of a pydantic_graph history for ``trace_id``.

    The write is an overwrite: an existing row gets its ``history`` replaced,
    otherwise a new row is inserted. Per the original note the engine calls
    this at every node boundary, and rolling back to an earlier point is the
    graph engine's job; this layer only keeps the latest version.

    Returns the refreshed ORM record.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        state = (await session.execute(lookup)).scalar_one_or_none()
        if state is None:
            state = WorkflowGraphStateModel(trace_id=trace_id, history=history)
            session.add(state)
        else:
            state.history = history
        await session.commit()
        await session.refresh(state)
        return state
|
||||
|
||||
@database_exception
|
||||
async def get_workflow_graph_state(
    self, trace_id: str
) -> Optional[WorkflowGraphStateModel]:
    """Fetch the persisted graph history row for ``trace_id``.

    Returns ``None`` when no row exists.
    """
    lookup = select(WorkflowGraphStateModel).where(
        WorkflowGraphStateModel.trace_id == trace_id
    )
    async with self.async_session_maker() as session:
        outcome = await session.execute(lookup)
        return outcome.scalar_one_or_none()
|
||||
|
||||
@database_exception
|
||||
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
    """Remove the persisted graph state of one workflow (explicit cleanup).

    Returns ``True`` if a row existed and was deleted, ``False`` otherwise.
    """
    async with self.async_session_maker() as session:
        lookup = select(WorkflowGraphStateModel).where(
            WorkflowGraphStateModel.trace_id == trace_id
        )
        state = (await session.execute(lookup)).scalar_one_or_none()
        found = state is not None
        if found:
            # Only commit when there is actually something to delete.
            await session.delete(state)
            await session.commit()
        return found
|
||||
|
||||
@@ -38,6 +38,9 @@ from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
from .module.individual import IndividualDatabase
|
||||
from .module.user import AuthDatabase
|
||||
@@ -45,6 +48,9 @@ from .module.provider import ProviderDatabase
|
||||
from .module.system_node import SystemNodeDatabase
|
||||
from .module.workflow import WorkflowDatabase
|
||||
from .module.chat_history import ChatHistoryDatabase
|
||||
from .module.mcp_server import MCPServerDatabase
|
||||
from .module.tool_config import ToolConfigDatabase
|
||||
from .module.custom_toolset import CustomToolsetDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -76,6 +82,9 @@ class PostgresDatabase:
|
||||
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
|
||||
self._workflow_database = WorkflowDatabase(self.async_session_maker)
|
||||
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
|
||||
self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
|
||||
self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
|
||||
self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
|
||||
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
@@ -91,6 +100,15 @@ class PostgresDatabase:
|
||||
finally:
|
||||
self.ready_event.set()
|
||||
|
||||
async def ping(self) -> bool:
    """Lightweight liveness probe: wait until ready, then run ``SELECT 1``.

    Returns ``True`` when the statement executes; any connection or execution
    error propagates to the caller.
    """
    from sqlalchemy import text

    await self.ready_event.wait()
    probe = text("SELECT 1")
    async with self.async_engine.connect() as connection:
        await connection.execute(probe)
    return True
|
||||
|
||||
# Auth Database Methods
|
||||
async def add_user(self, user_name: str, hashed_password: str):
|
||||
"""新建一名用户。"""
|
||||
@@ -242,6 +260,24 @@ class PostgresDatabase:
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow_context(trace_id)
|
||||
|
||||
# Workflow Graph State (pydantic_graph 持久化)
|
||||
async def upsert_workflow_graph_state(self, trace_id: str, history: list):
    """Overwrite the persisted graph history for ``trace_id`` once the DB is ready.

    Delegates to the workflow DAO; per the original note this is invoked at
    pydantic_graph node boundaries.
    """
    await self.ready_event.wait()
    dao = self._workflow_database
    return await dao.upsert_workflow_graph_state(trace_id, history)
|
||||
|
||||
async def get_workflow_graph_state(self, trace_id: str):
    """Read the persisted graph state row for ``trace_id`` once the DB is ready.

    Per the original note this is used to resume a workflow across processes.
    """
    await self.ready_event.wait()
    dao = self._workflow_database
    return await dao.get_workflow_graph_state(trace_id)
|
||||
|
||||
async def delete_workflow_graph_state(self, trace_id: str):
    """Explicitly delete the persisted graph state for ``trace_id``.

    Waits for the database to be ready and delegates to the workflow DAO;
    per the original note this frees space for finished or failed workflows.
    """
    await self.ready_event.wait()
    dao = self._workflow_database
    return await dao.delete_workflow_graph_state(trace_id)
|
||||
|
||||
# Chat History Database Methods
|
||||
async def create_chat_session(self, user_id: str, title: str = "新对话"):
|
||||
"""新建一个聊天会话。"""
|
||||
@@ -264,3 +300,79 @@ class PostgresDatabase:
|
||||
"""返回某个聊天会话的全部消息。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.list_chat_messages(chat_id)
|
||||
|
||||
# MCP Server Database Methods
|
||||
async def upsert_mcp_server(self, server_id: str, config: dict):
    """Insert or update one MCP server config once the DB is ready.

    Delegates to the MCP server DAO, which encrypts ``env`` before writing.
    """
    await self.ready_event.wait()
    dao = self._mcp_server_database
    return await dao.upsert(server_id, config)
|
||||
|
||||
async def get_mcp_server_db(self, server_id: str):
    """Read one MCP server configuration by id.

    The original note says ``env`` is decrypted on read; that is done by
    the backing store, not here.

    Returns:
        Whatever the store's ``get`` returns.
    """
    await self.ready_event.wait()
    store = self._mcp_server_database
    return await store.get(server_id)
|
||||
|
||||
async def list_mcp_servers_db(self):
    """Read every stored MCP server configuration.

    Returns:
        Whatever the store's ``list_all`` returns.
    """
    await self.ready_event.wait()
    store = self._mcp_server_database
    return await store.list_all()
|
||||
|
||||
async def delete_mcp_server_db(self, server_id: str):
    """Delete one MCP server configuration by id.

    Returns:
        Whatever the store's ``delete`` returns.
    """
    await self.ready_event.wait()
    store = self._mcp_server_database
    return await store.delete(server_id)
|
||||
|
||||
# Tool Config Database Methods
|
||||
async def upsert_tool_config(self, tool_name: str, config: dict):
    """Insert or update the runtime configuration of one tool.

    The original note says sensitive fields are encrypted on write; that is
    the backing store's job, not this wrapper's.

    Returns:
        Whatever the store's ``upsert`` returns.
    """
    await self.ready_event.wait()
    store = self._tool_config_database
    return await store.upsert(tool_name, config)
|
||||
|
||||
async def get_tool_config_db(self, tool_name: str):
    """Read the runtime configuration of one tool.

    The original note says sensitive fields are decrypted on read; that is
    done by the backing store, not here.

    Returns:
        Whatever the store's ``get`` returns.
    """
    await self.ready_event.wait()
    store = self._tool_config_database
    return await store.get(tool_name)
|
||||
|
||||
async def list_tool_configs_db(self):
    """Read the runtime configuration of every tool.

    Returns:
        Whatever the store's ``list_all`` returns.
    """
    await self.ready_event.wait()
    store = self._tool_config_database
    return await store.list_all()
|
||||
|
||||
async def delete_tool_config_db(self, tool_name: str):
    """Delete the runtime configuration of one tool.

    Returns:
        Whatever the store's ``delete`` returns.
    """
    await self.ready_event.wait()
    store = self._tool_config_database
    return await store.delete(tool_name)
|
||||
|
||||
# Custom Toolset Database Methods
|
||||
async def upsert_custom_toolset(
    self,
    toolset_id: str,
    name: str,
    tools: list,
    description: str = None,
    owner_id: str = None,
):
    """Insert or update one user-defined toolset.

    Waits for readiness, then forwards every argument by keyword to the
    custom toolset store.

    Returns:
        Whatever the store's ``upsert`` returns.
    """
    await self.ready_event.wait()
    fields = {
        "toolset_id": toolset_id,
        "name": name,
        "tools": tools,
        "description": description,
        "owner_id": owner_id,
    }
    return await self._custom_toolset_database.upsert(**fields)
|
||||
|
||||
async def get_custom_toolset(self, toolset_id: str):
    """Read one user-defined toolset by id.

    Returns:
        Whatever the store's ``get`` returns.
    """
    await self.ready_event.wait()
    store = self._custom_toolset_database
    return await store.get(toolset_id)
|
||||
|
||||
async def list_custom_toolsets(self):
    """Read every user-defined toolset.

    Returns:
        Whatever the store's ``list_all`` returns.
    """
    await self.ready_event.wait()
    store = self._custom_toolset_database
    return await store.list_all()
|
||||
|
||||
async def delete_custom_toolset(self, toolset_id: str):
    """Delete one user-defined toolset by id.

    Returns:
        Whatever the store's ``delete`` returns.
    """
    await self.ready_event.wait()
    store = self._custom_toolset_database
    return await store.delete(toolset_id)
|
||||
|
||||
@@ -0,0 +1,191 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Postgres 后端的 ``BaseStatePersistence`` 实现,让 graph 跨进程 resume。
|
||||
|
||||
设计思路:
|
||||
|
||||
- **复用 ``FullStatePersistence`` 的内存语义**:snapshot/record_run/load_next 等
|
||||
的实现已经做得很完备(NodeSnapshot 状态机、deep_copy、type adapter),不重写
|
||||
这些细节,只是在每次"历史会发生变更"的钩子触发后,把内存 history 序列化为 JSON
|
||||
写到 ``workflow_graph_state`` 表里。
|
||||
- **异步 IO 解耦**:DB 写入是 best-effort 模式——每个钩子都会 await 一次写库,但中途
写入失败只记 warning、不打断 graph;只有 ``snapshot_end`` 的写入失败会向上抛出
(确保关机之前最终 history 落盘)。
|
||||
- **resume 入口**:从 DB 读出 history JSON → ``load_json`` 还原 → ``Graph.iter_from_persistence``
|
||||
跑剩余节点。
|
||||
|
||||
这一层不直接持有 SQLAlchemy session;通过两个 awaitable 注入 IO,便于测试。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Optional
|
||||
|
||||
from pydantic_graph import BaseNode, End
|
||||
from pydantic_graph.persistence import (
|
||||
BaseStatePersistence,
|
||||
NodeSnapshot,
|
||||
Snapshot,
|
||||
)
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("graph_persistence")
|
||||
|
||||
|
||||
# IO 注入签名:写一份 history JSON / 读一份 history JSON(None=没有)
|
||||
WriteHistory = Callable[[str, Any], Awaitable[None]]
|
||||
ReadHistory = Callable[[str], Awaitable[Optional[Any]]]
|
||||
|
||||
|
||||
@dataclass
class PostgresStatePersistence(BaseStatePersistence):
    """``BaseStatePersistence`` that mirrors an in-memory history to storage.

    All snapshot bookkeeping is delegated to an inner
    ``FullStatePersistence``. After every hook that can change the history,
    the whole history is dumped to JSON and handed to ``write_history``.
    Intermediate write failures are logged as warnings and do not interrupt
    the graph; only the final ``snapshot_end`` flush re-raises.

    Fields:
        trace_id: Key under which the history is read and written.
        write_history: Awaitable ``(trace_id, history)`` writer.
        read_history: Awaitable ``(trace_id)`` reader; falsy means "nothing
            stored".
        _inner: In-memory persistence that owns the actual history.
    """

    trace_id: str
    write_history: WriteHistory
    read_history: ReadHistory
    _inner: FullStatePersistence = field(default_factory=FullStatePersistence)

    # ─── BaseStatePersistence interface ─────────────────────────────

    async def snapshot_node(self, state, next_node):
        """Record a node snapshot in memory, then flush best-effort."""
        await self._inner.snapshot_node(state, next_node)
        await self._flush()

    async def snapshot_node_if_new(self, snapshot_id, state, next_node):
        """Record a node snapshot unless one already exists, then flush."""
        await self._inner.snapshot_node_if_new(snapshot_id, state, next_node)
        await self._flush()

    async def snapshot_end(self, state, end):
        """Record the end snapshot; the flush here must succeed or raise."""
        await self._inner.snapshot_end(state, end)
        # The graph is finished: the final snapshot has to reach storage
        # before we return, so write failures propagate.
        await self._flush(must_succeed=True)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Wrap the inner ``record_run`` and flush once the run has exited.

        The flush sits in ``finally`` so that a node which raises still gets
        its updated snapshot status written out; without it only successful
        runs would ever be persisted. The flush is best-effort and does not
        mask the node's own exception.
        """
        try:
            async with self._inner.record_run(snapshot_id):
                yield
        finally:
            await self._flush()

    async def load_next(self) -> Optional[NodeSnapshot]:
        """Return the next snapshot to run, as chosen by the inner store."""
        return await self._inner.load_next()

    async def load_all(self) -> list[Snapshot]:
        """Return every snapshot held by the inner store."""
        return await self._inner.load_all()

    def should_set_types(self) -> bool:
        """Report whether the inner store still needs its types registered."""
        return self._inner.should_set_types()

    def set_types(self, state_type, run_end_type):
        """Register state / run-end types on the inner store."""
        self._inner.set_types(state_type, run_end_type)

    # ─── Serialization / deserialization ────────────────────────────

    async def hydrate(self) -> bool:
        """Load the stored history for ``trace_id`` into memory.

        Types must already be registered (``set_types`` /
        ``set_graph_types``) before calling this, otherwise the inner
        ``load_json`` asserts.

        Returns:
            ``True`` when a non-empty history was read and loaded; ``False``
            when nothing was stored or reading / parsing failed.

        Raises:
            AssertionError: Types were not registered before hydrating.
        """
        try:
            raw = await self.read_history(self.trace_id)
        except Exception as e:  # pragma: no cover - defensive
            _logger.warning(f"hydrate read failed: {e}")
            return False
        if not raw:
            return False
        try:
            self._inner.load_json(_to_json_bytes(raw))
        except AssertionError:
            # Missing type registration is a caller bug, not a data problem:
            # surface it instead of reporting "nothing to resume".
            raise
        except Exception as e:
            _logger.warning(f"hydrate load_json failed: {e}")
            return False
        return True

    async def _flush(self, *, must_succeed: bool = False) -> None:
        """Serialize the in-memory history and await ``write_history``.

        Args:
            must_succeed: When ``False`` a serialization or write failure is
                only logged. When ``True`` the failure is logged and then
                re-raised so callers such as ``snapshot_end`` can detect
                that the final history was not stored.
        """
        if self._inner.should_set_types():
            # Types are not registered yet, so the history cannot be dumped.
            return
        try:
            blob = self._inner.dump_json()
        except Exception as e:  # pragma: no cover
            _logger.warning(f"dump history failed: {e}")
            if must_succeed:
                # Nothing was written; a "must persist" caller has to know.
                raise
            return
        try:
            await self.write_history(self.trace_id, _from_json_bytes(blob))
        except Exception as e:
            _logger.warning(f"persist history failed: {e}")
            if must_succeed:
                raise
|
||||
|
||||
|
||||
def _to_json_bytes(value: Any) -> bytes:
    """Normalize a stored history value into UTF-8 JSON bytes.

    ``str`` is encoded, ``bytes`` / ``bytearray`` are copied into ``bytes``,
    and anything else (expected to be a list or dict read back from the
    database) is dumped to JSON first.
    """
    import json as _json

    if isinstance(value, str):
        return value.encode("utf-8")
    if isinstance(value, (bytes, bytearray)):
        return bytes(value)
    serialized = _json.dumps(value)
    return serialized.encode("utf-8")
|
||||
|
||||
|
||||
def _from_json_bytes(blob: bytes) -> Any:
    """Decode UTF-8 JSON bytes into plain Python lists / dicts for storage."""
    import json as _json

    decoded = blob.decode("utf-8")
    return _json.loads(decoded)
|
||||
|
||||
|
||||
def build_postgres_persistence(trace_id: str) -> PostgresStatePersistence:
    """Build a ``PostgresStatePersistence`` backed by the postgres actor.

    The actor handle is obtained through ``ray_actor_hook`` and wrapped in
    two small coroutines that match the ``WriteHistory`` / ``ReadHistory``
    signatures.

    Args:
        trace_id: Workflow trace id used as the storage key.

    Returns:
        A persistence object wired to the remote graph-state methods.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    async def _write(tid: str, history: Any) -> None:
        await postgres_database.upsert_workflow_graph_state.remote(tid, history)

    async def _read(tid: str) -> Optional[Any]:
        record = await postgres_database.get_workflow_graph_state.remote(tid)
        if record is None:
            return None
        # Unwrap explicitly. The previous ``getattr(...) or record`` fell back
        # to the record object itself whenever the stored history was empty
        # or None, handing a non-history value to the caller.
        if hasattr(record, "history"):
            return record.history
        if isinstance(record, dict) and "history" in record:
            return record["history"]
        # Already the raw history (list / str / bytes).
        return record

    return PostgresStatePersistence(
        trace_id=trace_id,
        write_history=_write,
        read_history=_read,
    )
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from typing import Literal, Optional, Union, List, Dict, Any
|
||||
from .model import LogicGate, WorkflowMetadata, WorkStepStatus, WorkflowStatus
|
||||
from ulid import ULID
|
||||
from datetime import datetime
|
||||
@@ -61,10 +61,22 @@ class WorkflowStep(BaseModel):
|
||||
default=None, description="前置依赖输出"
|
||||
)
|
||||
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
|
||||
node: Literal["skill_individual", "consciousness_node"] = Field(
|
||||
default="skill_individual",
|
||||
description=(
|
||||
"执行此步的节点类别:\n"
|
||||
"- skill_individual:task 内现起一个专家子个体执行(一次性)\n"
|
||||
"- consciousness_node:远程调用全局 ConsciousnessNode actor"
|
||||
),
|
||||
)
|
||||
agent_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
|
||||
)
|
||||
require_approval: bool = Field(
|
||||
default=False,
|
||||
description="该步执行前是否需要人工审批;启用时会暂停工作流并通过 SSE 等待用户回执",
|
||||
)
|
||||
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
|
||||
|
||||
|
||||
|
||||
@@ -12,166 +12,521 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Workflow 引擎:基于 ``pydantic_graph`` 的状态机驱动。
|
||||
|
||||
调度路径只剩两类节点:
|
||||
|
||||
- ``skill_individual``:在当前 ray task 进程内现起一个 ``SkillIndividual``
|
||||
执行;用后即焚,不消耗 actor。
|
||||
- ``consciousness_node``:远程调用全局 ConsciousnessNode actor 的
|
||||
``working`` 方法(中途指导 / 审查类工作)。
|
||||
|
||||
每一步执行前可设置 ``require_approval=True`` 触发 ``HumanApproval`` 节点:
|
||||
推送 SSE → ``await gwm.get_received`` 阻塞等用户回执 → 决策 continue/abort。
|
||||
|
||||
Graph 还接了 pydantic_graph 的 ``FullStatePersistence``,目前主要用于
|
||||
节点边界自动 snapshot(postgres 持久化保留旧 ``upsert_workflow_context``
|
||||
路径,跨进程 resume 留到后续)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import ray
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
|
||||
from typing import Dict, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
|
||||
from pydantic_graph.persistence import BaseStatePersistence
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
|
||||
|
||||
@ray.remote
|
||||
def run_workflow_task(workflow_data: dict, trace_id: str):
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
# ─── State / Deps ─────────────────────────────────────────────────────────
|
||||
|
||||
# State passed through graph nodes
|
||||
class WorkflowGraphState(BaseModel):
|
||||
trace_id: str
|
||||
blackboard: Dict[str, Any]
|
||||
work_link: List[Dict[str, Any]]
|
||||
current_step_index: int = 0
|
||||
status: str = "running"
|
||||
logs: List[Dict[str, Any]] = []
|
||||
|
||||
async def save_context(state: WorkflowGraphState):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.upsert_workflow_context.remote(
|
||||
state.trace_id,
|
||||
workflow_pointer=state.current_step_index,
|
||||
blackboard=state.blackboard,
|
||||
work_link=state.work_link,
|
||||
workflow_status={str(datetime.datetime.now()): state.status},
|
||||
workflow_log=state.logs,
|
||||
)
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
state.trace_id, state.status
|
||||
)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.put_received.remote(
|
||||
state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
|
||||
class WorkflowGraphState(BaseModel):
    """State shared by every node while the workflow graph runs."""

    # Trace id of the workflow run this state belongs to.
    trace_id: str
    # Step outputs, keyed by each step's output name.
    blackboard: Dict[str, Any] = Field(default_factory=dict)
    # Ordered list of step dicts that make up the workflow.
    work_link: List[Dict[str, Any]] = Field(default_factory=list)
    # Index into ``work_link`` of the step currently being processed.
    current_step_index: int = 0
    # Terminal status; starts as RUNNING.
    final_status: str = WorkflowStatus.RUNNING.value
    # Per-step log entries appended by the nodes.
    logs: List[Dict[str, Any]] = Field(default_factory=list)
    # Command that triggered the workflow; handed to step executors.
    original_command: str = ""
    # Step indexes whose HumanApproval prompt was already pushed through
    # put_pending, so a resumed run does not send the prompt twice. A list
    # (not a set) keeps the serialized graph history JSON-friendly.
    approvals_notified: List[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
# 业务侧执行入口:把 step + state 喂进去,拿到 (output_text, success_bool)
|
||||
StepExecutor = Callable[
|
||||
[Dict[str, Any], "WorkflowGraphState"], Awaitable[tuple[str, bool]]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class WorkflowDeps:
    """Runtime dependencies of the graph nodes.

    All external IO goes through these awaitable callables, so tests can
    substitute their own. ``build_default_deps`` assembles the production
    versions around real actor handles. ``run_skill`` and
    ``run_consciousness`` dispatch a step to its executor and return
    ``(output_text, success)``.
    """

    # Overwrites the stored workflow context for a trace id.
    upsert_workflow_context: Callable[..., Awaitable[Any]]
    # Sets the workflow status: (trace_id, status).
    update_workflow_status: Callable[[str, str], Awaitable[Any]]
    # Pushes a message for the trace: (trace_id, message).
    put_pending: Callable[[str, str], Awaitable[Any]]
    # Waits for and returns the reply for a trace id.
    get_received: Callable[[str], Awaitable[str]]
    # Executor for ``skill_individual`` steps.
    run_skill: StepExecutor
    # Executor for ``consciousness_node`` steps.
    run_consciousness: StepExecutor
|
||||
|
||||
|
||||
# ─── 节点 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class Initialize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Graph entry node: mark the workflow RUNNING and persist its context."""

    async def run(
        self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
    ) -> "Dispatch":
        """Write the RUNNING status, persist the context, go to ``Dispatch``."""
        running = WorkflowStatus.RUNNING.value
        await ctx.deps.update_workflow_status(ctx.state.trace_id, running)
        await _persist_context(ctx, status=running)
        return Dispatch()
|
||||
|
||||
async def execute_step(state: WorkflowGraphState):
|
||||
"""执行单一工作流节点逻辑"""
|
||||
|
||||
@dataclass
|
||||
class Dispatch(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""读取当前 step,按 ``node`` / ``require_approval`` 字段选择下一节点。"""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "HumanApproval | SkillStep | ConsciousnessStep | Finalize":
|
||||
state = ctx.state
|
||||
if state.current_step_index >= len(state.work_link):
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
return state
|
||||
return Finalize(status=WorkflowStatus.COMPLETED.value)
|
||||
|
||||
step = state.work_link[state.current_step_index]
|
||||
step.get("node", "")
|
||||
action = step.get("action", "")
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
if step_data.get("require_approval"):
|
||||
return HumanApproval()
|
||||
|
||||
# 记录开始状态
|
||||
node_type = step_data.get("node") or "skill_individual"
|
||||
if node_type == "consciousness_node":
|
||||
return ConsciousnessStep()
|
||||
if node_type == "skill_individual":
|
||||
return SkillStep()
|
||||
|
||||
# 未识别的 node 类型按失败处理(保守:不静默吞)
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"working",
|
||||
f"开始执行: {step.get('name', '未命名步骤')}",
|
||||
"failed",
|
||||
f"未知节点类型: {node_type}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await save_context(state)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
try:
|
||||
# TODO: 实际对接不同节点执行逻辑 (例如: control_node, agent 技能)
|
||||
# 这里是简化版,向控制节点或指定 skill 发送指令
|
||||
|
||||
# ... 模拟执行逻辑 ...
|
||||
await asyncio.sleep(2)
|
||||
@dataclass
|
||||
class HumanApproval(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""人工审批节点:暂停 graph,等用户通过 SSE 回执决策。
|
||||
|
||||
# 记录结果
|
||||
state.blackboard[
|
||||
step.get("outputs", f"step_{state.current_step_index}_result")
|
||||
] = "Success execution of " + action
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"completed",
|
||||
f"成功: {action}",
|
||||
]
|
||||
回执约定(轻量协议,沿用现有 SSE 通道):
|
||||
|
||||
# 判断逻辑跳转
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_pass") == "exit":
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
else:
|
||||
state.current_step_index += 1
|
||||
- 含 ``approve`` / ``yes`` / ``ok`` 视为通过,回到 Dispatch 继续执行
|
||||
- 其它(包括 ``reject``/``no``/``abort``)视为拒绝,工作流终止为 FAILED
|
||||
"""
|
||||
|
||||
except Exception as e:
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"failed",
|
||||
str(e),
|
||||
]
|
||||
state.status = WorkflowStatus.FAILED
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_fail"):
|
||||
fail_target = logic_gate.get("if_fail")
|
||||
if "jump_to_step_" in fail_target:
|
||||
target_step = int(fail_target.split("_")[-1]) - 1
|
||||
state.current_step_index = target_step
|
||||
state.status = WorkflowStatus.RUNNING
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
state = ctx.state
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
# idempotent 推送:仅当本 step 还没通知过时才发 put_pending。
|
||||
# 这样 resume 场景(HumanApproval 节点被重新进入)不会给前端发重复消息。
|
||||
if state.current_step_index not in state.approvals_notified:
|
||||
await ctx.deps.put_pending(
|
||||
state.trace_id,
|
||||
f"步骤 {state.current_step_index + 1} ({step_data.get('name', '')}) "
|
||||
f"需要人工审批,请回复 approve / reject。",
|
||||
)
|
||||
state.approvals_notified.append(state.current_step_index)
|
||||
await _persist_context(ctx, status=WorkflowStatus.HANGUP.value)
|
||||
|
||||
await save_context(state)
|
||||
return state
|
||||
reply = (await ctx.deps.get_received(state.trace_id) or "").strip().lower()
|
||||
if any(token in reply for token in ("approve", "yes", "ok")):
|
||||
# 把 require_approval 置否避免无限循环重新进 HumanApproval
|
||||
step_data["require_approval"] = False
|
||||
await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
|
||||
return Dispatch()
|
||||
|
||||
async def _run():
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
trace_id, WorkflowStatus.RUNNING
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"rejected",
|
||||
f"用户拒绝执行该步骤: {reply!r}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
state = WorkflowGraphState(
|
||||
trace_id=trace_id,
|
||||
blackboard={},
|
||||
work_link=workflow_data.get("work_link", []),
|
||||
)
|
||||
await save_context(state)
|
||||
|
||||
# 简单的图执行驱动 (模拟 pydantic-ai.graph.run 行为,直至 Graph 库正式稳定)
|
||||
while state.status == WorkflowStatus.RUNNING and state.current_step_index < len(
|
||||
state.work_link
|
||||
):
|
||||
state = await execute_step(state)
|
||||
@dataclass
|
||||
class SkillStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""skill_individual 路径:当前进程内拉一个专家子个体执行该步。"""
|
||||
|
||||
await postgres_database.update_workflow_status.remote(trace_id, state.status)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
return await _execute_step(ctx, executor=ctx.deps.run_skill)
|
||||
|
||||
|
||||
@dataclass
class ConsciousnessStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """``consciousness_node`` path: run the step via ``run_consciousness``."""

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        """Run the current step through the shared step skeleton."""
        runner = ctx.deps.run_consciousness
        return await _execute_step(ctx, executor=runner)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finalize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""收尾节点:写最终状态、推送最终 SSE,把 workflow 终态作为 graph 输出。"""
|
||||
|
||||
status: str
|
||||
|
||||
async def run(
|
||||
self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
|
||||
) -> End[str]:
|
||||
ctx.state.final_status = self.status
|
||||
await ctx.deps.update_workflow_status(ctx.state.trace_id, self.status)
|
||||
msg = (
|
||||
"工作流执行完成!"
|
||||
if state.status == WorkflowStatus.COMPLETED
|
||||
if self.status == WorkflowStatus.COMPLETED.value
|
||||
else "工作流执行失败。"
|
||||
)
|
||||
await global_workflow_manager.put_received.remote(trace_id, msg)
|
||||
await ctx.deps.put_pending(ctx.state.trace_id, msg)
|
||||
return End(self.status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
# ─── 内部 helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _persist_context(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps], *, status: str
) -> None:
    """Write the current state to the workflow context store (overwrite).

    Args:
        ctx: Graph run context carrying the state and the deps.
        status: Status recorded under the current timestamp.
    """
    state = ctx.state
    stamped_status = {str(datetime.datetime.now()): status}
    await ctx.deps.upsert_workflow_context(
        state.trace_id,
        workflow_pointer=state.current_step_index,
        blackboard=state.blackboard,
        work_link=state.work_link,
        workflow_status=stamped_status,
        workflow_log=state.logs,
    )
|
||||
|
||||
|
||||
def _resolve_jump_target(fail_target: Any, step_count: int) -> Optional[int]:
    """Turn a ``jump_to_step_<N>`` directive into a valid 0-based step index.

    Returns ``None`` when the directive is absent, is not a jump directive,
    has a non-numeric suffix, or points outside ``[0, step_count)``.
    """
    if not isinstance(fail_target, str) or "jump_to_step_" not in fail_target:
        return None
    try:
        target = int(fail_target.rsplit("_", 1)[-1]) - 1
    except ValueError:
        return None
    return target if 0 <= target < step_count else None


async def _execute_step(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    *,
    executor: StepExecutor,
) -> "Dispatch | Finalize":
    """Shared execution skeleton for ``SkillStep`` and ``ConsciousnessStep``.

    Logs the start of the step, persists the context, pushes a progress
    message, runs ``executor``, then updates the blackboard and log and
    applies the step's ``logic_gate``.

    Args:
        ctx: Graph run context carrying the state and the deps.
        executor: Coroutine taking ``(step_data, state)`` and returning
            ``(output_text, success)``. An exception it raises is treated as
            a failed step whose output is the exception text.

    Returns:
        ``Dispatch`` to continue with another step, or ``Finalize`` carrying
        the terminal status.
    """
    state = ctx.state
    step_index = state.current_step_index
    step_data = state.work_link[step_index]
    log_key = str(step_index)

    state.logs.append(
        {
            log_key: [
                str(datetime.datetime.now()),
                "working",
                f"开始执行: {step_data.get('name', '未命名步骤')}",
            ]
        }
    )
    await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
    await ctx.deps.put_pending(state.trace_id, f"执行步骤 {step_index + 1}...")

    try:
        output_text, success = await executor(step_data, state)
    except Exception as e:  # an executor exception takes the failure branch
        output_text, success = str(e), False

    logic_gate = step_data.get("logic_gate") or {}

    if success:
        # ``outputs`` may be present but None (a dumped model keeps unset
        # optional fields), so fall back on falsy values, not only on a
        # missing key; otherwise results would be stored under key ``None``.
        output_key = step_data.get("outputs") or f"step_{step_index}_result"
        state.blackboard[output_key] = output_text
        state.logs[-1][log_key] = [
            str(datetime.datetime.now()),
            "completed",
            f"成功: {step_data.get('action', '')}",
        ]
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)

        if logic_gate.get("if_pass") == "exit":
            return Finalize(status=WorkflowStatus.COMPLETED.value)

        state.current_step_index += 1
        if state.current_step_index >= len(state.work_link):
            return Finalize(status=WorkflowStatus.COMPLETED.value)
        return Dispatch()

    # Failure: a valid if_fail jump takes precedence over finalizing.
    state.logs[-1][log_key] = [
        str(datetime.datetime.now()),
        "failed",
        output_text,
    ]
    # A malformed or out-of-range jump target is treated as "no jump": the
    # workflow fails instead of crashing, wrapping to the last step, or being
    # reported as COMPLETED by Dispatch.
    target_step = _resolve_jump_target(
        logic_gate.get("if_fail"), len(state.work_link)
    )
    if target_step is not None:
        state.current_step_index = target_step
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()

    await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
    return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
|
||||
# ─── 图定义 ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
workflow_graph: Graph[WorkflowGraphState, WorkflowDeps, str] = Graph(
|
||||
nodes=[Initialize, Dispatch, HumanApproval, SkillStep, ConsciousnessStep, Finalize],
|
||||
state_type=WorkflowGraphState,
|
||||
run_end_type=str,
|
||||
)
|
||||
|
||||
|
||||
# ─── 默认执行器:把 step 派发到 SkillIndividual / ConsciousnessNode ────
|
||||
|
||||
|
||||
async def _default_skill_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``skill_individual`` steps.

    Looks up the step's ``agent_id`` in the snapshot returned by
    ``fetch_snapshot``, builds a ``SkillIndividual`` from that configuration
    and runs it once with the step, the blackboard and the original command.

    Returns:
        ``(output_text, True)`` on completion, with ``"(empty)"`` replacing
        an empty output; ``(reason, False)`` when ``agent_id`` is missing or
        unknown.
    """
    from kilostar.worker_individual.skill_individual import SkillIndividual
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    agent_id = step_data.get("agent_id")
    if not agent_id:
        return "skill_individual 步骤缺少 agent_id", False

    agent_config = (await fetch_snapshot()).individuals.get(agent_id)
    if not agent_config:
        return f"未找到 agent_id={agent_id} 的专家子个体", False

    worker = SkillIndividual(dict(agent_config))
    result = await worker.run(
        {
            "step": step_data,
            "blackboard": state.blackboard,
            "original_command": state.original_command,
        }
    )
    if isinstance(result, dict):
        output = result.get("output", "")
    else:
        output = str(result)
    return output or "(empty)", True
|
||||
|
||||
|
||||
async def _default_consciousness_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production executor for ``consciousness_node`` steps.

    Validates the step dict into a ``WorkflowStep``, wraps it with the
    original command in a ``ForWorkflowInput`` and calls the actor's
    ``working`` method remotely.

    Returns:
        ``(result.output, True)`` for a ``ForWorkflow`` reply; otherwise a
        description of the unexpected reply together with ``False``.
    """
    from kilostar.utils.ray_hook import ray_actor_hook
    from kilostar.core.individual.consciousness_node.template import (
        ForWorkflow,
        ForWorkflowInput,
    )
    from kilostar.core.work.workflow.workflow import WorkflowStep

    actor = ray_actor_hook("consciousness_node").consciousness_node
    request = ForWorkflowInput(
        workflow_step=WorkflowStep.model_validate(step_data),
        original_command=state.original_command,
    )
    result = await actor.working.remote(request)

    if result is None:
        return "ConsciousnessNode 返回 None", False
    if isinstance(result, ForWorkflow):
        return result.output, True
    return f"ConsciousnessNode 返回未知类型: {type(result).__name__}", False
|
||||
|
||||
|
||||
def build_default_deps() -> WorkflowDeps:
    """Assemble the production ``WorkflowDeps`` around ray actor handles.

    Each remote actor method is wrapped in a coroutine so graph nodes can
    simply ``await`` it. The step executors are the module's default skill
    and consciousness executors.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    database = ray_actor_hook("postgres_database").postgres_database
    manager = ray_actor_hook("global_workflow_manager").global_workflow_manager

    async def upsert_context(trace_id: str, **kwargs: Any) -> Any:
        return await database.upsert_workflow_context.remote(trace_id, **kwargs)

    async def update_status(trace_id: str, status: str) -> Any:
        return await database.update_workflow_status.remote(trace_id, status)

    async def push_pending(trace_id: str, message: str) -> Any:
        return await manager.put_pending.remote(trace_id, message)

    async def wait_received(trace_id: str) -> str:
        return await manager.get_received.remote(trace_id)

    return WorkflowDeps(
        upsert_workflow_context=upsert_context,
        update_workflow_status=update_status,
        put_pending=push_pending,
        get_received=wait_received,
        run_skill=_default_skill_executor,
        run_consciousness=_default_consciousness_executor,
    )
|
||||
|
||||
|
||||
async def run_workflow_graph(
    workflow_data: Dict[str, Any],
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Run the workflow graph from ``Initialize`` and return its final status.

    Args:
        workflow_data: Workflow dict; ``work_link`` supplies the steps and
            ``workflow_metadata.command`` the original command.
        trace_id: Workflow trace id.
        deps: Node dependencies; built with ``build_default_deps`` when
            omitted.
        persistence: Snapshot persistence; a fresh ``FullStatePersistence``
            when omitted.

    Returns:
        The graph run's output (the terminal workflow status string).
    """
    deps = build_default_deps() if deps is None else deps
    persistence = FullStatePersistence() if persistence is None else persistence

    steps = workflow_data.get("work_link", []) or []
    metadata = workflow_data.get("workflow_metadata") or {}
    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=list(steps),
        original_command=metadata.get("command", "") or "",
    )
    outcome = await workflow_graph.run(
        Initialize(),
        state=initial_state,
        deps=deps,
        persistence=persistence,
    )
    return outcome.output
|
||||
|
||||
|
||||
async def resume_workflow_graph(
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Resume a workflow from ``persistence`` and run it until ``End``.

    The persistence object is expected to hold the restored history already;
    this function only drives ``Graph.iter_from_persistence``.

    Args:
        trace_id: Workflow trace id (not read by this function).
        deps: Node dependencies; built with ``build_default_deps`` when
            omitted.
        persistence: Persistence to resume from; required.

    Returns:
        The ``End`` node's data, or the RUNNING status value when iteration
        finishes without yielding an ``End``.

    Raises:
        ValueError: ``persistence`` was not provided.
    """
    if persistence is None:
        raise ValueError("resume 必须显式传入 persistence")
    deps = build_default_deps() if deps is None else deps

    async with workflow_graph.iter_from_persistence(
        persistence, deps=deps
    ) as graph_run:
        async for node in graph_run:
            if isinstance(node, End):
                return node.data
    return WorkflowStatus.RUNNING.value
|
||||
|
||||
|
||||
@ray.remote
class WorkflowRunningEngine:
    """Ray actor that holds an event queue and submits workflows as ray tasks."""

    def __init__(
        self, consciousness_node=None, control_node=None, regulatory_node=None
    ):
        self.consciousness_node = consciousness_node
        self.control_node = control_node
        self.regulatory_node = regulatory_node
        self.events_queue = asyncio.Queue()

    async def put_event(self, event):
        """Put ``event`` on the engine's event queue."""
        await self.events_queue.put(event)

    async def run(self):
        """Engine loop: take an event off the queue, then sleep one second."""
        while True:
            # The dequeued event is discarded; nothing consumes it here.
            await self.events_queue.get()
            await asyncio.sleep(1)

    async def execute_workflow(self, workflow: KiloStarWorkflow):
        """Submit a complete workflow run (callable by the consciousness node).

        The workflow is dumped to a plain dict and handed to
        ``run_workflow_task`` as a fire-and-forget ray task; the returned
        object ref is not awaited.
        """
        workflow_dict = workflow.model_dump()
        trace_id = workflow.trace_id
        run_workflow_task.remote(workflow_dict, trace_id)


@ray.remote
def run_workflow_task(workflow_data: dict, trace_id: str):
    """Ray task entry point for a workflow: runs once, then is torn down.

    On the production path persistence is delegated to
    ``PostgresStatePersistence``: even if the process crashes, firing another
    task with the same ``trace_id`` (or calling
    ``/workflow/{trace_id}/resume``) continues the run. To also support a
    fresh start, ``hydrate`` is attempted first:

    - hydrate found content -> take the resume path
    - hydrate found nothing -> take the fresh-run path

    A ray task runs in a new process, so contextvars are not carried over
    from the caller; the entry point therefore binds ``trace_id`` once so that
    logging inside the nodes picks it up automatically.

    Args:
        workflow_data: Dict form of the workflow, passed through to
            ``run_workflow_graph`` on the fresh-run path.
        trace_id: Trace id used to build the persistence and bind the scope.
    """
    import logging

    from kilostar.utils.request_context import trace_id_scope
    from kilostar.core.work.workflow.graph_persistence import (
        build_postgres_persistence,
    )

    async def _entry() -> None:
        with trace_id_scope(trace_id):
            persistence = build_postgres_persistence(trace_id)
            persistence.set_graph_types(workflow_graph)
            recovered = False
            try:
                recovered = await persistence.hydrate()
            except Exception:  # pragma: no cover - defensive
                # Best-effort: a failed hydrate falls back to a fresh run, but
                # the failure is logged instead of being swallowed silently.
                logging.getLogger(__name__).exception(
                    "hydrate failed for trace_id=%s; starting a fresh run",
                    trace_id,
                )
                recovered = False

            if recovered:
                await resume_workflow_graph(trace_id, persistence=persistence)
            else:
                await run_workflow_graph(
                    workflow_data, trace_id, persistence=persistence
                )

    asyncio.run(_entry())
|
||||
|
||||
Reference in New Issue
Block a user