feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -0,0 +1,191 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Postgres 后端的 ``BaseStatePersistence`` 实现,让 graph 跨进程 resume。
|
||||
|
||||
设计思路:
|
||||
|
||||
- **复用 ``FullStatePersistence`` 的内存语义**:snapshot/record_run/load_next 等
|
||||
的实现已经做得很完备(NodeSnapshot 状态机、deep_copy、type adapter),不重写
|
||||
这些细节,只是在每次"历史会发生变更"的钩子触发后,把内存 history 序列化为 JSON
|
||||
写到 ``workflow_graph_state`` 表里。
|
||||
- **异步 IO 解耦**:DB 写入是 fire-and-forget 模式,不在 graph 节点路径上阻塞——
|
||||
graph 跑得快,持久化追得上即可。但 ``snapshot_end`` 一定会 await(确保关机
|
||||
之前最终 history 落盘)。
|
||||
- **resume 入口**:从 DB 读出 history JSON → ``load_json`` 还原 → ``Graph.iter_from_persistence``
|
||||
跑剩余节点。
|
||||
|
||||
这一层不直接持有 SQLAlchemy session;通过两个 awaitable 注入 IO,便于测试。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from contextlib import asynccontextmanager
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, AsyncIterator, Awaitable, Callable, Optional
|
||||
|
||||
from pydantic_graph import BaseNode, End
|
||||
from pydantic_graph.persistence import (
|
||||
BaseStatePersistence,
|
||||
NodeSnapshot,
|
||||
Snapshot,
|
||||
)
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
_logger = get_logger("graph_persistence")
|
||||
|
||||
|
||||
# IO 注入签名:写一份 history JSON / 读一份 history JSON(None=没有)
|
||||
WriteHistory = Callable[[str, Any], Awaitable[None]]
|
||||
ReadHistory = Callable[[str], Awaitable[Optional[Any]]]
|
||||
|
||||
|
||||
@dataclass
class PostgresStatePersistence(BaseStatePersistence):
    """State persistence that mirrors an in-memory history into Postgres.

    All snapshot bookkeeping is delegated to an inner ``FullStatePersistence``.
    After every hook that can change the history, the complete history is
    serialized and handed to ``write_history``. Each write is awaited inline
    (it is not a detached background task); a failed write is logged as a
    warning and swallowed so the graph keeps running, except after
    ``snapshot_end`` where the failure is re-raised.
    """

    # Key under which this workflow's history is written and read.
    trace_id: str
    # Awaitable sink, called as ``write_history(trace_id, history)``.
    write_history: WriteHistory
    # Awaitable source, called as ``read_history(trace_id)``; a falsy result
    # is treated as "nothing stored".
    read_history: ReadHistory
    # In-memory persistence that owns the actual snapshot list.
    _inner: FullStatePersistence = field(default_factory=FullStatePersistence)

    # ─── BaseStatePersistence interface ─────────────────────────────

    async def snapshot_node(self, state, next_node):
        """Record a node snapshot in memory, then flush best-effort."""
        await self._inner.snapshot_node(state, next_node)
        await self._flush()

    async def snapshot_node_if_new(self, snapshot_id, state, next_node):
        """Delegate to the inner ``snapshot_node_if_new``, then flush best-effort."""
        await self._inner.snapshot_node_if_new(snapshot_id, state, next_node)
        await self._flush()

    async def snapshot_end(self, state, end):
        """Record the end snapshot and require the final flush to succeed.

        Raises:
            Exception: Whatever ``write_history`` raised, re-raised unchanged.
        """
        await self._inner.snapshot_end(state, end)
        # The graph is finished: the final history must reach the database
        # before this returns, so a write failure is propagated.
        await self._flush(must_succeed=True)

    @asynccontextmanager
    async def record_run(self, snapshot_id: str) -> AsyncIterator[None]:
        """Wrap a node run and flush the resulting status change.

        The flush sits in ``finally`` so that the snapshot status is written
        both when the node finishes normally and when it raises. Without it,
        an exception from the node body would skip the flush and the stored
        history would keep the stale pre-exit status.
        """
        try:
            async with self._inner.record_run(snapshot_id):
                yield
        finally:
            await self._flush()

    async def load_next(self) -> Optional[NodeSnapshot]:
        """Return the inner persistence's next snapshot to run, if any."""
        return await self._inner.load_next()

    async def load_all(self) -> list[Snapshot]:
        """Return every snapshot held by the inner persistence."""
        return await self._inner.load_all()

    def should_set_types(self) -> bool:
        """Report whether the inner persistence still needs its types set."""
        return self._inner.should_set_types()

    def set_types(self, state_type, run_end_type):
        """Forward the state / run-end types to the inner persistence."""
        self._inner.set_types(state_type, run_end_type)

    # ─── Serialization / deserialization ────────────────────────────

    async def hydrate(self) -> bool:
        """Load the stored history for ``trace_id`` into memory.

        The types must already be registered on the inner persistence (the
        task entry point calls ``set_graph_types`` before this); otherwise
        ``load_json`` asserts.

        Returns:
            ``True`` when a stored history was read and loaded; ``False`` when
            nothing was stored or the read / load failed (logged as warning).

        Raises:
            AssertionError: Propagated from ``load_json`` when the types were
                not registered, so the caller's ordering bug is not hidden.
        """
        try:
            raw = await self.read_history(self.trace_id)
        except Exception as e:  # pragma: no cover - defensive
            _logger.warning(f"hydrate read failed: {e}")
            return False
        if not raw:
            return False
        try:
            self._inner.load_json(_to_json_bytes(raw))
        except AssertionError:
            # Raised when set_types has not run yet; surface it to the caller.
            raise
        except Exception as e:
            _logger.warning(f"hydrate load_json failed: {e}")
            return False
        return True

    async def _flush(self, *, must_succeed: bool = False) -> None:
        """Serialize the in-memory history and write it through ``write_history``.

        Args:
            must_succeed: When ``False`` a write failure is only logged; when
                ``True`` it is logged and re-raised. A missing type adapter or
                a serialization failure returns silently in both modes.
        """
        if self._inner._snapshots_type_adapter is None:
            # Types not registered yet, so the history cannot be dumped.
            return
        try:
            blob = self._inner.dump_json()
        except Exception as e:  # pragma: no cover
            _logger.warning(f"dump history failed: {e}")
            return
        try:
            await self.write_history(self.trace_id, _from_json_bytes(blob))
        except Exception as e:
            _logger.warning(f"persist history failed: {e}")
            if must_succeed:
                raise
|
||||
|
||||
|
||||
def _to_json_bytes(value: Any) -> bytes:
    """Normalize a stored history payload to UTF-8 JSON bytes for ``load_json``.

    Args:
        value: The payload as read from storage: raw bytes (``bytes``,
            ``bytearray`` or ``memoryview``), a JSON string, or an already
            decoded ``list`` / ``dict`` (e.g. from a JSONB column).

    Returns:
        The payload as ``bytes`` containing JSON text.

    Raises:
        TypeError: If ``value`` is a decoded object ``json.dumps`` cannot
            serialize.
    """
    # memoryview is accepted as raw bytes too; otherwise it would fall through
    # to json.dumps and fail. NOTE(review): whether the configured driver ever
    # returns buffer views for this column is not visible here — confirm.
    if isinstance(value, (bytes, bytearray, memoryview)):
        return bytes(value)
    if isinstance(value, str):
        return value.encode("utf-8")
    # Anything else is assumed to be decoded JSON; re-encode it as text.
    import json as _json

    return _json.dumps(value).encode("utf-8")
|
||||
|
||||
|
||||
def _from_json_bytes(blob: bytes) -> Any:
    """Decode ``dump_json`` output into plain list/dict objects for JSONB storage.

    Args:
        blob: UTF-8 encoded JSON text.

    Returns:
        The parsed JSON value.
    """
    import json as _json

    text = str(blob, "utf-8")
    return _json.loads(text)
|
||||
|
||||
|
||||
def build_postgres_persistence(trace_id: str) -> PostgresStatePersistence:
    """Build the production ``PostgresStatePersistence`` for one workflow.

    The read/write callables wrap the ``postgres_database`` actor handle
    obtained through ``ray_actor_hook``.

    Args:
        trace_id: Workflow identifier used as the storage key.

    Returns:
        A persistence object wired to the Postgres actor.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    async def _write(tid: str, history: Any) -> None:
        await postgres_database.upsert_workflow_graph_state.remote(tid, history)

    async def _read(tid: str) -> Optional[Any]:
        record = await postgres_database.get_workflow_graph_state.remote(tid)
        if record is None:
            return None
        # Row-like objects expose the payload as ``.history``. Return that
        # value even when it is empty/None, so the caller sees "nothing
        # stored" instead of receiving the row object itself (which is not a
        # history and would fail to deserialize).
        if hasattr(record, "history"):
            return record.history
        # Otherwise the actor already returned the bare payload.
        return record

    return PostgresStatePersistence(
        trace_id=trace_id,
        write_history=_write,
        read_history=_read,
    )
|
||||
@@ -13,7 +13,7 @@
|
||||
# limitations under the License.
|
||||
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing import Optional, Union, List, Dict, Any
|
||||
from typing import Literal, Optional, Union, List, Dict, Any
|
||||
from .model import LogicGate, WorkflowMetadata, WorkStepStatus, WorkflowStatus
|
||||
from ulid import ULID
|
||||
from datetime import datetime
|
||||
@@ -61,10 +61,22 @@ class WorkflowStep(BaseModel):
|
||||
default=None, description="前置依赖输出"
|
||||
)
|
||||
outputs: Optional[str] = Field(default=None, description="当前步骤产出物变量名")
|
||||
node: Literal["skill_individual", "consciousness_node"] = Field(
|
||||
default="skill_individual",
|
||||
description=(
|
||||
"执行此步的节点类别:\n"
|
||||
"- skill_individual:task 内现起一个专家子个体执行(一次性)\n"
|
||||
"- consciousness_node:远程调用全局 ConsciousnessNode actor"
|
||||
),
|
||||
)
|
||||
agent_id: Optional[str] = Field(
|
||||
default=None,
|
||||
description="分配给 skill_individual 的 Skill Individual 真实 agent_id,不可用名称代替",
|
||||
)
|
||||
require_approval: bool = Field(
|
||||
default=False,
|
||||
description="该步执行前是否需要人工审批;启用时会暂停工作流并通过 SSE 等待用户回执",
|
||||
)
|
||||
logic_gate: Optional[LogicGate] = Field(default=None, description="逻辑跳转控制")
|
||||
|
||||
|
||||
|
||||
@@ -12,166 +12,521 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Workflow 引擎:基于 ``pydantic_graph`` 的状态机驱动。
|
||||
|
||||
调度路径只剩两类节点:
|
||||
|
||||
- ``skill_individual``:在当前 ray task 进程内现起一个 ``SkillIndividual``
|
||||
执行;用后即焚,不消耗 actor。
|
||||
- ``consciousness_node``:远程调用全局 ConsciousnessNode actor 的
|
||||
``working`` 方法(中途指导 / 审查类工作)。
|
||||
|
||||
每一步执行前可设置 ``require_approval=True`` 触发 ``HumanApproval`` 节点:
|
||||
推送 SSE → ``await gwm.get_received`` 阻塞等用户回执 → 决策 continue/abort。
|
||||
|
||||
Graph 还接了 pydantic_graph 的 ``FullStatePersistence``,目前主要用于
|
||||
节点边界自动 snapshot(postgres 持久化保留旧 ``upsert_workflow_context``
|
||||
路径,跨进程 resume 留到后续)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import datetime
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Awaitable, Callable, Dict, List, Optional
|
||||
|
||||
import ray
|
||||
from kilostar.core.work.workflow.workflow import KiloStarWorkflow
|
||||
from typing import Dict, Any, List
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic_graph import BaseNode, End, Graph, GraphRunContext
|
||||
from pydantic_graph.persistence import BaseStatePersistence
|
||||
from pydantic_graph.persistence.in_mem import FullStatePersistence
|
||||
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
|
||||
|
||||
@ray.remote
|
||||
def run_workflow_task(workflow_data: dict, trace_id: str):
|
||||
from kilostar.utils.ray_hook import ray_actor_hook
|
||||
from kilostar.core.work.workflow.model import WorkflowStatus
|
||||
import datetime
|
||||
from pydantic import BaseModel
|
||||
# ─── State / Deps ─────────────────────────────────────────────────────────
|
||||
|
||||
# State passed through graph nodes
|
||||
class WorkflowGraphState(BaseModel):
|
||||
trace_id: str
|
||||
blackboard: Dict[str, Any]
|
||||
work_link: List[Dict[str, Any]]
|
||||
current_step_index: int = 0
|
||||
status: str = "running"
|
||||
logs: List[Dict[str, Any]] = []
|
||||
|
||||
async def save_context(state: WorkflowGraphState):
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.upsert_workflow_context.remote(
|
||||
state.trace_id,
|
||||
workflow_pointer=state.current_step_index,
|
||||
blackboard=state.blackboard,
|
||||
work_link=state.work_link,
|
||||
workflow_status={str(datetime.datetime.now()): state.status},
|
||||
workflow_log=state.logs,
|
||||
)
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
state.trace_id, state.status
|
||||
)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
await global_workflow_manager.put_received.remote(
|
||||
state.trace_id, f"执行步骤 {state.current_step_index + 1}..."
|
||||
class WorkflowGraphState(BaseModel):
    """State shared across nodes for the duration of one graph run."""

    # Workflow identifier passed to every persistence / messaging call.
    trace_id: str
    # Step outputs, written by the step-execution helper under each step's output key.
    blackboard: Dict[str, Any] = Field(default_factory=dict)
    # Ordered list of step dicts; indexed by ``current_step_index``.
    work_link: List[Dict[str, Any]] = Field(default_factory=list)
    # Zero-based pointer to the step being dispatched / executed.
    current_step_index: int = 0
    # Terminal status string; overwritten by the Finalize node.
    final_status: str = WorkflowStatus.RUNNING.value
    # Per-step log entries of the form ``{step_index: [timestamp, status, message]}``.
    logs: List[Dict[str, Any]] = Field(default_factory=list)
    # Original user command, forwarded to the step executors.
    original_command: str = ""
    # Step indexes for which HumanApproval already sent put_pending, so a
    # resumed run does not push the same prompt twice.
    # A list (not a set) keeps the state JSON-friendly when pydantic_graph
    # serializes the history.
    approvals_notified: List[int] = Field(default_factory=list)
|
||||
|
||||
|
||||
# 业务侧执行入口:把 step + state 喂进去,拿到 (output_text, success_bool)
|
||||
StepExecutor = Callable[
|
||||
[Dict[str, Any], "WorkflowGraphState"], Awaitable[tuple[str, bool]]
|
||||
]
|
||||
|
||||
|
||||
@dataclass
class WorkflowDeps:
    """Runtime dependencies of the graph nodes.

    All external IO goes through these callables so tests can substitute
    their own. Every field is an async callable; in production
    ``build_default_deps`` assembles them as thin wrappers around actor
    ``.remote()`` calls.

    ``run_skill`` / ``run_consciousness`` dispatch one step to a concrete
    executor, which keeps the node classes free of executor details.
    """

    # Called as ``(trace_id, **fields)`` to overwrite the stored workflow context.
    upsert_workflow_context: Callable[..., Awaitable[Any]]
    # Called as ``(trace_id, status)``.
    update_workflow_status: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id, message)`` to push a message for the workflow.
    put_pending: Callable[[str, str], Awaitable[Any]]
    # Called as ``(trace_id)``; awaited by HumanApproval to obtain the user's reply.
    get_received: Callable[[str], Awaitable[str]]
    # Executor used by SkillStep.
    run_skill: StepExecutor
    # Executor used by ConsciousnessStep.
    run_consciousness: StepExecutor
|
||||
|
||||
|
||||
# ─── 节点 ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
class Initialize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """Graph entry node: mark the workflow RUNNING, persist context, go to Dispatch."""

    async def run(
        self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
    ) -> "Dispatch":
        running = WorkflowStatus.RUNNING.value
        # Status row first, then the full context snapshot.
        await ctx.deps.update_workflow_status(ctx.state.trace_id, running)
        await _persist_context(ctx, status=running)
        return Dispatch()
|
||||
|
||||
async def execute_step(state: WorkflowGraphState):
|
||||
"""执行单一工作流节点逻辑"""
|
||||
|
||||
@dataclass
|
||||
class Dispatch(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""读取当前 step,按 ``node`` / ``require_approval`` 字段选择下一节点。"""
|
||||
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "HumanApproval | SkillStep | ConsciousnessStep | Finalize":
|
||||
state = ctx.state
|
||||
if state.current_step_index >= len(state.work_link):
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
return state
|
||||
return Finalize(status=WorkflowStatus.COMPLETED.value)
|
||||
|
||||
step = state.work_link[state.current_step_index]
|
||||
step.get("node", "")
|
||||
action = step.get("action", "")
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
if step_data.get("require_approval"):
|
||||
return HumanApproval()
|
||||
|
||||
# 记录开始状态
|
||||
node_type = step_data.get("node") or "skill_individual"
|
||||
if node_type == "consciousness_node":
|
||||
return ConsciousnessStep()
|
||||
if node_type == "skill_individual":
|
||||
return SkillStep()
|
||||
|
||||
# 未识别的 node 类型按失败处理(保守:不静默吞)
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"working",
|
||||
f"开始执行: {step.get('name', '未命名步骤')}",
|
||||
"failed",
|
||||
f"未知节点类型: {node_type}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await save_context(state)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
try:
|
||||
# TODO: 实际对接不同节点执行逻辑 (例如: control_node, agent 技能)
|
||||
# 这里是简化版,向控制节点或指定 skill 发送指令
|
||||
|
||||
# ... 模拟执行逻辑 ...
|
||||
await asyncio.sleep(2)
|
||||
@dataclass
|
||||
class HumanApproval(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""人工审批节点:暂停 graph,等用户通过 SSE 回执决策。
|
||||
|
||||
# 记录结果
|
||||
state.blackboard[
|
||||
step.get("outputs", f"step_{state.current_step_index}_result")
|
||||
] = "Success execution of " + action
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"completed",
|
||||
f"成功: {action}",
|
||||
]
|
||||
回执约定(轻量协议,沿用现有 SSE 通道):
|
||||
|
||||
# 判断逻辑跳转
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_pass") == "exit":
|
||||
state.status = WorkflowStatus.COMPLETED
|
||||
else:
|
||||
state.current_step_index += 1
|
||||
- 含 ``approve`` / ``yes`` / ``ok`` 视为通过,回到 Dispatch 继续执行
|
||||
- 其它(包括 ``reject``/``no``/``abort``)视为拒绝,工作流终止为 FAILED
|
||||
"""
|
||||
|
||||
except Exception as e:
|
||||
state.logs[-1][str(state.current_step_index)] = [
|
||||
str(datetime.datetime.now()),
|
||||
"failed",
|
||||
str(e),
|
||||
]
|
||||
state.status = WorkflowStatus.FAILED
|
||||
logic_gate = step.get("logic_gate")
|
||||
if logic_gate and logic_gate.get("if_fail"):
|
||||
fail_target = logic_gate.get("if_fail")
|
||||
if "jump_to_step_" in fail_target:
|
||||
target_step = int(fail_target.split("_")[-1]) - 1
|
||||
state.current_step_index = target_step
|
||||
state.status = WorkflowStatus.RUNNING
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
state = ctx.state
|
||||
step_data = state.work_link[state.current_step_index]
|
||||
# idempotent 推送:仅当本 step 还没通知过时才发 put_pending。
|
||||
# 这样 resume 场景(HumanApproval 节点被重新进入)不会给前端发重复消息。
|
||||
if state.current_step_index not in state.approvals_notified:
|
||||
await ctx.deps.put_pending(
|
||||
state.trace_id,
|
||||
f"步骤 {state.current_step_index + 1} ({step_data.get('name', '')}) "
|
||||
f"需要人工审批,请回复 approve / reject。",
|
||||
)
|
||||
state.approvals_notified.append(state.current_step_index)
|
||||
await _persist_context(ctx, status=WorkflowStatus.HANGUP.value)
|
||||
|
||||
await save_context(state)
|
||||
return state
|
||||
reply = (await ctx.deps.get_received(state.trace_id) or "").strip().lower()
|
||||
if any(token in reply for token in ("approve", "yes", "ok")):
|
||||
# 把 require_approval 置否避免无限循环重新进 HumanApproval
|
||||
step_data["require_approval"] = False
|
||||
await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
|
||||
return Dispatch()
|
||||
|
||||
async def _run():
|
||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||
await postgres_database.update_workflow_status.remote(
|
||||
trace_id, WorkflowStatus.RUNNING
|
||||
state.logs.append(
|
||||
{
|
||||
str(state.current_step_index): [
|
||||
str(datetime.datetime.now()),
|
||||
"rejected",
|
||||
f"用户拒绝执行该步骤: {reply!r}",
|
||||
]
|
||||
}
|
||||
)
|
||||
await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
|
||||
return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
state = WorkflowGraphState(
|
||||
trace_id=trace_id,
|
||||
blackboard={},
|
||||
work_link=workflow_data.get("work_link", []),
|
||||
)
|
||||
await save_context(state)
|
||||
|
||||
# 简单的图执行驱动 (模拟 pydantic-ai.graph.run 行为,直至 Graph 库正式稳定)
|
||||
while state.status == WorkflowStatus.RUNNING and state.current_step_index < len(
|
||||
state.work_link
|
||||
):
|
||||
state = await execute_step(state)
|
||||
@dataclass
|
||||
class SkillStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""skill_individual 路径:当前进程内拉一个专家子个体执行该步。"""
|
||||
|
||||
await postgres_database.update_workflow_status.remote(trace_id, state.status)
|
||||
global_workflow_manager = ray_actor_hook(
|
||||
"global_workflow_manager"
|
||||
).global_workflow_manager
|
||||
async def run(
|
||||
self,
|
||||
ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
|
||||
) -> "Dispatch | Finalize":
|
||||
return await _execute_step(ctx, executor=ctx.deps.run_skill)
|
||||
|
||||
|
||||
@dataclass
class ConsciousnessStep(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
    """``consciousness_node`` path: run the current step via ``deps.run_consciousness``."""

    async def run(
        self,
        ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    ) -> "Dispatch | Finalize":
        # All bookkeeping lives in the shared helper; only the executor differs.
        next_node = await _execute_step(ctx, executor=ctx.deps.run_consciousness)
        return next_node
|
||||
|
||||
|
||||
@dataclass
|
||||
class Finalize(BaseNode[WorkflowGraphState, WorkflowDeps, str]):
|
||||
"""收尾节点:写最终状态、推送最终 SSE,把 workflow 终态作为 graph 输出。"""
|
||||
|
||||
status: str
|
||||
|
||||
async def run(
|
||||
self, ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps]
|
||||
) -> End[str]:
|
||||
ctx.state.final_status = self.status
|
||||
await ctx.deps.update_workflow_status(ctx.state.trace_id, self.status)
|
||||
msg = (
|
||||
"工作流执行完成!"
|
||||
if state.status == WorkflowStatus.COMPLETED
|
||||
if self.status == WorkflowStatus.COMPLETED.value
|
||||
else "工作流执行失败。"
|
||||
)
|
||||
await global_workflow_manager.put_received.remote(trace_id, msg)
|
||||
await ctx.deps.put_pending(ctx.state.trace_id, msg)
|
||||
return End(self.status)
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
# ─── 内部 helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _persist_context(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps], *, status: str
) -> None:
    """Write the current graph state through ``deps.upsert_workflow_context``.

    Args:
        ctx: Run context providing ``state`` and ``deps``.
        status: Status string stored under the current timestamp.
    """
    state = ctx.state
    # The status is stored as a single-entry mapping {timestamp: status}.
    stamped_status = {str(datetime.datetime.now()): status}
    await ctx.deps.upsert_workflow_context(
        state.trace_id,
        workflow_pointer=state.current_step_index,
        blackboard=state.blackboard,
        work_link=state.work_link,
        workflow_status=stamped_status,
        workflow_log=state.logs,
    )
|
||||
|
||||
|
||||
def _parse_jump_target(fail_target: Any, step_count: int) -> Optional[int]:
    """Return the zero-based index named by ``jump_to_step_<N>``, or ``None`` if unusable."""
    if not isinstance(fail_target, str) or "jump_to_step_" not in fail_target:
        return None
    try:
        target = int(fail_target.split("_")[-1]) - 1
    except ValueError:
        # Malformed suffix: treat as "no jump" instead of crashing the node.
        return None
    # Reject indexes outside the step list: a negative one would silently wrap
    # around, and one past the end would make Dispatch finalize as COMPLETED.
    return target if 0 <= target < step_count else None


async def _execute_step(
    ctx: GraphRunContext[WorkflowGraphState, WorkflowDeps],
    *,
    executor: StepExecutor,
) -> "Dispatch | Finalize":
    """Shared execution skeleton for SkillStep and ConsciousnessStep.

    Handles logging, context persistence, the progress message, blackboard
    update and logic-gate routing; how the step itself runs is delegated to
    ``executor``.

    Args:
        ctx: Run context providing ``state`` and ``deps``.
        executor: Awaitable called as ``executor(step_data, state)`` returning
            ``(output_text, success)``. An exception from it is treated as a
            failed step whose output is the exception text.

    Returns:
        ``Dispatch`` to continue with the next (or jump-target) step, or
        ``Finalize`` with COMPLETED / FAILED status.
    """
    state = ctx.state
    step_index = state.current_step_index
    step_data = state.work_link[step_index]
    log_key = str(step_index)

    state.logs.append(
        {
            log_key: [
                str(datetime.datetime.now()),
                "working",
                f"开始执行: {step_data.get('name', '未命名步骤')}",
            ]
        }
    )
    await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
    await ctx.deps.put_pending(state.trace_id, f"执行步骤 {step_index + 1}...")

    try:
        output_text, success = await executor(step_data, state)
    except Exception as e:  # executor raised -> take the failure branch
        output_text, success = str(e), False

    logic_gate = step_data.get("logic_gate") or {}

    if success:
        # ``outputs`` is usually present with value None (optional model field
        # dumped to a dict), so ``.get(key, default)`` would return None and
        # every step would overwrite blackboard[None]. Fall back on any falsy
        # value instead.
        output_key = step_data.get("outputs") or f"step_{step_index}_result"
        state.blackboard[output_key] = output_text
        state.logs[-1][log_key] = [
            str(datetime.datetime.now()),
            "completed",
            f"成功: {step_data.get('action', '')}",
        ]
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)

        if logic_gate.get("if_pass") == "exit":
            return Finalize(status=WorkflowStatus.COMPLETED.value)

        state.current_step_index = step_index + 1
        if state.current_step_index >= len(state.work_link):
            return Finalize(status=WorkflowStatus.COMPLETED.value)
        return Dispatch()

    # Failure: an ``if_fail`` jump takes precedence over finalizing.
    state.logs[-1][log_key] = [
        str(datetime.datetime.now()),
        "failed",
        output_text,
    ]
    target_step = _parse_jump_target(
        logic_gate.get("if_fail"), len(state.work_link)
    )
    if target_step is not None:
        state.current_step_index = target_step
        await _persist_context(ctx, status=WorkflowStatus.RUNNING.value)
        return Dispatch()

    await _persist_context(ctx, status=WorkflowStatus.FAILED.value)
    return Finalize(status=WorkflowStatus.FAILED.value)
|
||||
|
||||
|
||||
# ─── 图定义 ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
workflow_graph: Graph[WorkflowGraphState, WorkflowDeps, str] = Graph(
|
||||
nodes=[Initialize, Dispatch, HumanApproval, SkillStep, ConsciousnessStep, Finalize],
|
||||
state_type=WorkflowGraphState,
|
||||
run_end_type=str,
|
||||
)
|
||||
|
||||
|
||||
# ─── 默认执行器:把 step 派发到 SkillIndividual / ConsciousnessNode ────
|
||||
|
||||
|
||||
async def _default_skill_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production ``skill_individual`` executor.

    Looks up the step's ``agent_id`` in the snapshot returned by
    ``fetch_snapshot``, builds a fresh ``SkillIndividual`` from that config
    and runs it on the step.

    Args:
        step_data: The step dict; must carry a non-empty ``agent_id``.
        state: Graph state; ``blackboard`` and ``original_command`` are
            forwarded to the individual.

    Returns:
        ``(output, True)`` on success, or ``(reason, False)`` when the
        ``agent_id`` is missing or unknown.
    """
    from kilostar.worker_individual.skill_individual import SkillIndividual
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    agent_id = step_data.get("agent_id")
    if not agent_id:
        return "skill_individual 步骤缺少 agent_id", False

    agent_config = (await fetch_snapshot()).individuals.get(agent_id)
    if not agent_config:
        return f"未找到 agent_id={agent_id} 的专家子个体", False

    worker = SkillIndividual(dict(agent_config))
    result = await worker.run(
        {
            "step": step_data,
            "blackboard": state.blackboard,
            "original_command": state.original_command,
        }
    )
    if isinstance(result, dict):
        output = result.get("output", "")
    else:
        output = str(result)
    # An empty/None output is replaced by a visible placeholder.
    return output or "(empty)", True
|
||||
|
||||
|
||||
async def _default_consciousness_executor(
    step_data: Dict[str, Any], state: WorkflowGraphState
) -> tuple[str, bool]:
    """Production ``consciousness_node`` executor: remote call to ``working``.

    Args:
        step_data: The step dict, validated into a ``WorkflowStep``.
        state: Graph state; ``original_command`` is forwarded in the payload.

    Returns:
        ``(result.output, True)`` when the actor returns a ``ForWorkflow``;
        otherwise ``(reason, False)``.
    """
    from kilostar.utils.ray_hook import ray_actor_hook
    from kilostar.core.individual.consciousness_node.template import (
        ForWorkflow,
        ForWorkflowInput,
    )
    from kilostar.core.work.workflow.workflow import WorkflowStep

    actor = ray_actor_hook("consciousness_node").consciousness_node
    step = WorkflowStep.model_validate(step_data)
    payload = ForWorkflowInput(
        workflow_step=step,
        original_command=state.original_command,
    )
    result = await actor.working.remote(payload)

    if result is None:
        return "ConsciousnessNode 返回 None", False
    if isinstance(result, ForWorkflow):
        return result.output, True
    return f"ConsciousnessNode 返回未知类型: {type(result).__name__}", False
|
||||
|
||||
|
||||
def build_default_deps() -> WorkflowDeps:
    """Build the production ``WorkflowDeps`` from ray actor handles.

    Each dependency is an async wrapper that awaits the corresponding actor
    ``.remote()`` call. The step executors are the module's default skill /
    consciousness executors.
    """
    from kilostar.utils.ray_hook import ray_actor_hook

    db = ray_actor_hook("postgres_database").postgres_database
    gwm = ray_actor_hook("global_workflow_manager").global_workflow_manager

    async def _upsert_workflow_context(trace_id: str, **kwargs: Any) -> Any:
        pending = db.upsert_workflow_context.remote(trace_id, **kwargs)
        return await pending

    async def _update_workflow_status(trace_id: str, status: str) -> Any:
        pending = db.update_workflow_status.remote(trace_id, status)
        return await pending

    async def _put_pending(trace_id: str, message: str) -> Any:
        pending = gwm.put_pending.remote(trace_id, message)
        return await pending

    async def _get_received(trace_id: str) -> str:
        pending = gwm.get_received.remote(trace_id)
        return await pending

    return WorkflowDeps(
        upsert_workflow_context=_upsert_workflow_context,
        update_workflow_status=_update_workflow_status,
        put_pending=_put_pending,
        get_received=_get_received,
        run_skill=_default_skill_executor,
        run_consciousness=_default_consciousness_executor,
    )
|
||||
|
||||
|
||||
async def run_workflow_graph(
    workflow_data: Dict[str, Any],
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Run the workflow graph from ``Initialize`` in the current event loop.

    Args:
        workflow_data: Workflow dict; ``work_link`` supplies the steps and
            ``workflow_metadata.command`` the original command.
        trace_id: Workflow identifier.
        deps: Node dependencies; built with ``build_default_deps`` when omitted.
        persistence: Graph persistence; a fresh ``FullStatePersistence`` is
            used when omitted.

    Returns:
        The graph's output, i.e. the terminal workflow status string.
    """
    deps = build_default_deps() if deps is None else deps
    persistence = FullStatePersistence() if persistence is None else persistence

    # Missing or None entries collapse to an empty step list / empty command.
    steps = workflow_data.get("work_link", []) or []
    metadata = workflow_data.get("workflow_metadata") or {}
    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=list(steps),
        original_command=metadata.get("command", "") or "",
    )

    outcome = await workflow_graph.run(
        Initialize(),
        state=initial_state,
        deps=deps,
        persistence=persistence,
    )
    return outcome.output
|
||||
|
||||
|
||||
async def resume_workflow_graph(
    trace_id: str,
    *,
    deps: Optional[WorkflowDeps] = None,
    persistence: Optional[BaseStatePersistence] = None,
) -> str:
    """Resume a workflow from persistence and run the remaining nodes to ``End``.

    ``persistence`` must already hold the restored history (the task entry
    point calls ``hydrate`` first); this function only drives
    ``Graph.iter_from_persistence``.

    Args:
        trace_id: Workflow identifier (not read by this function's body).
        deps: Node dependencies; built with ``build_default_deps`` when omitted.
        persistence: The hydrated persistence; required.

    Returns:
        The data of the ``End`` node, or the RUNNING status value if the
        iteration finished without yielding one.

    Raises:
        ValueError: If ``persistence`` is ``None``.
    """
    if persistence is None:
        raise ValueError("resume 必须显式传入 persistence")
    if deps is None:
        deps = build_default_deps()

    final_output: str = WorkflowStatus.RUNNING.value
    async with workflow_graph.iter_from_persistence(
        persistence, deps=deps
    ) as run:
        async for node in run:
            if not isinstance(node, End):
                continue
            final_output = node.data
            break
    return final_output
|
||||
|
||||
|
||||
@ray.remote
|
||||
class WorkflowRunningEngine:
|
||||
def __init__(
|
||||
self, consciousness_node=None, control_node=None, regulatory_node=None
|
||||
):
|
||||
self.consciousness_node = consciousness_node
|
||||
self.control_node = control_node
|
||||
self.regulatory_node = regulatory_node
|
||||
self.events_queue = asyncio.Queue()
|
||||
def run_workflow_task(workflow_data: dict, trace_id: str):
|
||||
"""workflow 的 ray task 入口:一次性执行,跑完即销毁。
|
||||
|
||||
async def put_event(self, event):
|
||||
await self.events_queue.put(event)
|
||||
生产路径下持久化交给 ``PostgresStatePersistence`` —— 即便进程崩溃,再 fire
|
||||
一次相同 ``trace_id`` 的任务(或调 ``/workflow/{trace_id}/resume``)即可
|
||||
续跑。同时为了支持 fresh start,先尝试 ``hydrate``:
|
||||
- hydrate 拿到内容 → 走 resume 路径
|
||||
- hydrate 没拿到 → 走全新路径
|
||||
|
||||
async def run(self):
|
||||
"""引擎循环提取事件"""
|
||||
while True:
|
||||
await self.events_queue.get()
|
||||
await asyncio.sleep(1)
|
||||
ray task 是新进程,contextvars 不会从 caller 传过来,所以入口先 bind 一次
|
||||
``trace_id``,让节点内的日志自动带上它。
|
||||
"""
|
||||
from kilostar.utils.request_context import trace_id_scope
|
||||
from kilostar.core.work.workflow.graph_persistence import (
|
||||
build_postgres_persistence,
|
||||
)
|
||||
|
||||
async def execute_workflow(self, workflow: KiloStarWorkflow):
|
||||
# 这个方法可以由意识节点调用来提交一个完整的运行任务
|
||||
workflow_dict = workflow.model_dump()
|
||||
trace_id = workflow.trace_id
|
||||
run_workflow_task.remote(workflow_dict, trace_id)
|
||||
async def _entry() -> None:
|
||||
with trace_id_scope(trace_id):
|
||||
persistence = build_postgres_persistence(trace_id)
|
||||
persistence.set_graph_types(workflow_graph)
|
||||
recovered = False
|
||||
try:
|
||||
recovered = await persistence.hydrate()
|
||||
except Exception: # pragma: no cover - 防御
|
||||
recovered = False
|
||||
|
||||
if recovered:
|
||||
await resume_workflow_graph(trace_id, persistence=persistence)
|
||||
else:
|
||||
await run_workflow_graph(
|
||||
workflow_data, trace_id, persistence=persistence
|
||||
)
|
||||
|
||||
asyncio.run(_entry())
|
||||
|
||||
Reference in New Issue
Block a user