Files
KiloStar/tests/unit/test_workflow_persistence.py
zhaoxi 99520c69d7 feat(system):优化后端
1.新增后端测试
2.增加了后端的加密
3.增加了i18n(国际化)
2026-05-31 15:39:34 +00:00

387 lines
14 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""``PostgresStatePersistence`` 与 graph resume 路径单元测试。
No real postgres / ray is required: two in-memory async callables stand in for the read/write functions.
"""
from __future__ import annotations

from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple
from unittest.mock import AsyncMock

import pytest

from kilostar.core.work.workflow.graph_persistence import (
    PostgresStatePersistence,
)
from kilostar.core.work.workflow.model import WorkflowStatus
from kilostar.core.work.workflow.workflow_engine import (
    WorkflowDeps,
    resume_workflow_graph,
    run_workflow_graph,
    workflow_graph,
)
def _make_in_memory_io() -> tuple[
    Dict[str, Any],
    Callable[[str, Any], Awaitable[None]],
    Callable[[str], Awaitable[Optional[Any]]],
]:
    """Build an in-memory stand-in for the postgres history store.

    Returns:
        A ``(db, write, read)`` triple:

        * ``db`` -- the backing dict, keyed by trace id, so tests can
          inspect what was stored.
        * ``write`` -- async callable ``(trace_id, history)`` that
          overwrites the entry for ``trace_id``.
        * ``read`` -- async callable ``(trace_id)`` that returns the stored
          entry, or ``None`` when nothing was stored for that trace id.
    """
    db: Dict[str, Any] = {}

    async def write(trace_id: str, history: Any) -> None:
        # Overwrite semantics: only the latest history per trace id is kept.
        db[trace_id] = history

    async def read(trace_id: str) -> Optional[Any]:
        return db.get(trace_id)

    return db, write, read
def _make_deps(
    *,
    skill_outputs: List[Tuple[str, bool]] | None = None,
    received_replies: List[str] | None = None,
) -> tuple[WorkflowDeps, Dict[str, List[Any]]]:
    """Build a ``WorkflowDeps`` wired to scripted fakes plus a recording sink.

    Args:
        skill_outputs: Results handed out, in order, by the fake skill
            runner. Once exhausted (or when omitted) the runner returns
            ``("(no fixture)", True)``.
        received_replies: Replies handed out, in order, by the fake
            ``get_received``. Once exhausted (or when omitted) it returns
            the empty string.

    Returns:
        ``(deps, sink)`` where ``sink["skill_calls"]`` collects
        ``(step name, state.current_step_index)`` for every skill
        invocation and ``sink["pending"]`` collects ``(tid, msg)`` for every
        ``put_pending`` call.
    """
    recorded: Dict[str, List[Any]] = {"skill_calls": [], "pending": []}
    # Copy the inputs so popping from the queues never mutates caller lists.
    queued_outputs = list(skill_outputs) if skill_outputs else []
    queued_replies = list(received_replies) if received_replies else []

    def _record_pending(tid, msg):
        recorded["pending"].append((tid, msg))

    async def _get_received(tid):
        return queued_replies.pop(0) if queued_replies else ""

    async def _run_skill(step, state):
        recorded["skill_calls"].append(
            (step.get("name"), state.current_step_index)
        )
        if queued_outputs:
            return queued_outputs.pop(0)
        return "(no fixture)", True

    async def _run_consciousness(step, state):
        # Not expected to be exercised by the tests in this file.
        return "(consc)", True

    deps = WorkflowDeps(
        upsert_workflow_context=AsyncMock(),
        update_workflow_status=AsyncMock(),
        put_pending=AsyncMock(side_effect=_record_pending),
        get_received=_get_received,
        run_skill=_run_skill,
        run_consciousness=_run_consciousness,
    )
    return deps, recorded
# ─── PostgresStatePersistence: 写穿 ──────────────────────────────────────
@pytest.mark.asyncio
async def test_postgres_persistence_writes_history_on_each_node():
    """A completed run must leave a non-empty history list in the store.

    Runs a single-step workflow through ``run_workflow_graph`` with a
    ``PostgresStatePersistence`` backed by the in-memory store, then checks
    that the run completed and that the stored history is a list holding at
    least one entry.
    """
    trace_id = "t1"
    db, write, read = _make_in_memory_io()
    persistence = PostgresStatePersistence(
        trace_id=trace_id, write_history=write, read_history=read
    )
    persistence.set_graph_types(workflow_graph)
    deps, _ = _make_deps(skill_outputs=[("ok", True)])
    only_step = {
        "step": 1,
        "name": "s1",
        "action": "do",
        "node": "skill_individual",
        "agent_id": "a1",
    }

    final = await run_workflow_graph(
        {"work_link": [only_step]},
        trace_id,
        deps=deps,
        persistence=persistence,
    )

    assert final == WorkflowStatus.COMPLETED.value
    # The store must hold a history with at least one snapshot entry.
    assert trace_id in db
    history = db[trace_id]
    assert isinstance(history, list)
    assert len(history) >= 1
@pytest.mark.asyncio
async def test_postgres_persistence_swallows_db_errors_during_run():
    """Write failures in the history store must not abort the graph run.

    The injected writer raises ``RuntimeError`` on its first two calls and
    stores the history on every later call. The run must still finish with
    ``COMPLETED``.

    The test also asserts that the writer was called at least once, so it
    cannot pass vacuously when persistence never writes.
    """
    db, _, read = _make_in_memory_io()
    write_calls = {"n": 0}

    async def flaky_write(trace_id, history):
        write_calls["n"] += 1
        # Fail the first two attempts; succeed from the third one onwards.
        if write_calls["n"] < 3:
            raise RuntimeError("transient db error")
        db[trace_id] = history

    persistence = PostgresStatePersistence(
        trace_id="t-flaky", write_history=flaky_write, read_history=read
    )
    persistence.set_graph_types(workflow_graph)
    deps, _ = _make_deps(skill_outputs=[("ok", True)])
    workflow_data = {
        "work_link": [
            {
                "step": 1,
                "name": "s1",
                "action": "do",
                "node": "skill_individual",
                "agent_id": "a1",
            },
        ]
    }

    final = await run_workflow_graph(
        workflow_data, "t-flaky", deps=deps, persistence=persistence
    )

    # The graph finishes even though the early writes raised.
    assert final == WorkflowStatus.COMPLETED.value
    # The first call always raises, so one call proves the failure path ran.
    assert write_calls["n"] >= 1, "flaky_write was never invoked"
# ─── hydrate / resume:从 DB 续跑 ──────────────────────────────────────
@pytest.mark.asyncio
async def test_hydrate_returns_false_when_no_record():
    """``hydrate()`` returns ``False`` when the store has no record.

    A ``False`` result is the signal for the caller to start a fresh run
    instead of resuming.
    """
    _db, write, read = _make_in_memory_io()
    persistence = PostgresStatePersistence(
        trace_id="t-empty", write_history=write, read_history=read
    )
    persistence.set_graph_types(workflow_graph)

    hydrated = await persistence.hydrate()

    assert hydrated is False
@pytest.mark.asyncio
async def test_resume_continues_from_persisted_history(monkeypatch):
    """Resume a half-finished run without re-executing the finished step.

    Phase one drives the graph by hand through ``workflow_graph.iter`` for
    three ``next()`` calls and then stops, leaving the rest of the workflow
    un-run. Phase two hydrates a fresh persistence object from the same
    in-memory store and calls ``resume_workflow_graph``.

    Key assertion: the resumed run invokes the skill runner exactly once,
    for step ``s2``; step ``s1`` is not executed a second time.
    """
    from kilostar.core.work.workflow.workflow_engine import (
        Initialize,
        WorkflowGraphState,
    )

    trace_id = "t-resume"
    _db, write, read = _make_in_memory_io()

    def _new_persistence() -> PostgresStatePersistence:
        # Both phases share one backing store but use separate objects.
        store = PostgresStatePersistence(
            trace_id=trace_id, write_history=write, read_history=read
        )
        store.set_graph_types(workflow_graph)
        return store

    steps = [
        {
            "step": number,
            "name": f"s{number}",
            "action": "do",
            "node": "skill_individual",
            "agent_id": "a1",
        }
        for number in (1, 2)
    ]

    # Phase one: run part of the workflow, then abandon it.
    first_store = _new_persistence()
    first_deps, first_sink = _make_deps(
        skill_outputs=[("step-1-ok", True), ("step-2-ok", True)]
    )
    initial_state = WorkflowGraphState(
        trace_id=trace_id,
        blackboard={},
        work_link=list(steps),
        original_command="",
    )
    async with workflow_graph.iter(
        Initialize(),
        state=initial_state,
        deps=first_deps,
        persistence=first_store,
    ) as run:
        # Expected sequence: Initialize, Dispatch, SkillStep (step 1).
        # The node that would follow is deliberately never run.
        for _ in range(3):
            await run.next()
    assert len(first_sink["skill_calls"]) == 1, "中断时只应跑了 1 步"

    # Phase two: resume from the history written during phase one.
    second_store = _new_persistence()
    assert await second_store.hydrate() is True
    second_deps, second_sink = _make_deps(
        skill_outputs=[("step-2-resumed", True)]
    )
    final = await resume_workflow_graph(
        trace_id, deps=second_deps, persistence=second_store
    )

    assert final == WorkflowStatus.COMPLETED.value
    # Only step 2 ran on resume; step 1 was not repeated.
    resumed_calls = second_sink["skill_calls"]
    assert len(resumed_calls) == 1
    assert resumed_calls[0][0] == "s2"
# ─── HumanApproval idempotent resume ──────────────────────────────────────
@pytest.mark.asyncio
async def test_human_approval_idempotent_on_resume():
    """The approval prompt is pushed once, not again on re-entry.

    The test has two stages:

    1. Run a single ``require_approval`` step with an empty reply queued.
       Exactly one pending message containing the approval marker must be
       recorded, and ``state.approvals_notified`` must equal ``[0]``.
    2. Build a new state that already carries ``approvals_notified=[0]``
       (simulating a resumed run re-entering the approval node), queue an
       ``"approve"`` reply and start from ``Dispatch``. No approval prompt
       may be recorded, and the skill must run exactly once.

    Stage two uses a hand-built state rather than a real resume; the
    original author's notes say a run that ended in failure cannot be
    resumed, so ``approvals_notified`` is checked directly.
    """
    from kilostar.core.work.workflow.workflow_engine import (
        Dispatch,
        Initialize,
        WorkflowGraphState,
    )

    approval_marker = "需要人工审批"

    def _approval_prompts(sink):
        # Keep only the pending messages that are approval prompts.
        return [msg for _, msg in sink["pending"] if approval_marker in msg]

    _db, write, read = _make_in_memory_io()
    approval_step = {
        "step": 1,
        "name": "approve-me",
        "action": "do",
        "node": "skill_individual",
        "agent_id": "a1",
        "require_approval": True,
    }

    # Stage 1: the empty reply means the approval is not granted.
    persistence_a = PostgresStatePersistence(
        trace_id="t-hitl", write_history=write, read_history=read
    )
    persistence_a.set_graph_types(workflow_graph)
    deps_a, sink_a = _make_deps(
        skill_outputs=[("step-1", True)],
        received_replies=[""],
    )
    state = WorkflowGraphState(
        trace_id="t-hitl",
        blackboard={},
        work_link=[dict(approval_step)],
        original_command="",
    )
    async with workflow_graph.iter(
        Initialize(),
        state=state,
        deps=deps_a,
        persistence=persistence_a,
    ) as run:
        async for _ in run:
            pass

    assert len(_approval_prompts(sink_a)) == 1, "第一次跑必须发审批提示一次"
    # The state must have recorded that step index 0 was already notified.
    assert state.approvals_notified == [0]

    # Stage 2: re-enter with the notification already recorded.
    deps_b, sink_b = _make_deps(received_replies=["approve"])
    state_resume = WorkflowGraphState(
        trace_id="t-hitl",
        blackboard={},
        work_link=[dict(approval_step)],
        original_command="",
        approvals_notified=[0],
    )
    persistence_b = PostgresStatePersistence(
        trace_id="t-hitl-resume", write_history=write, read_history=read
    )
    persistence_b.set_graph_types(workflow_graph)
    # Start at Dispatch: per the original notes, this skips Initialize so
    # it does not call update_workflow_status again.
    async with workflow_graph.iter(
        Dispatch(),
        state=state_resume,
        deps=deps_b,
        persistence=persistence_b,
    ) as run:
        async for _ in run:
            pass

    # No second approval prompt on the re-entry path.
    assert len(_approval_prompts(sink_b)) == 0
    # With the approval granted, the skill itself runs exactly once.
    assert len(sink_b["skill_calls"]) == 1
    assert sink_b["skill_calls"][0][0] == "approve-me"
# ─── mermaid 高亮 ─────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_mermaid_highlight_visited_nodes_from_history():
    """Persisted history is enough to derive the visited node class names.

    The test replicates the filtering that, per the original notes, the
    ``/graph`` API applies to the stored history:

    * keep only dict entries with ``kind == "node"`` and
      ``status == "success"``;
    * take the part of ``id`` before the first ``":"`` as the class name
      (ids look like ``"ClassName:hash"``);
    * de-duplicate while preserving first-seen order.

    The resulting names are passed to
    ``workflow_graph.mermaid_code(highlighted_nodes=...)`` and the output
    is checked for ``classDef`` and the matching ``class X`` lines.
    """
    db, write, read = _make_in_memory_io()
    persistence = PostgresStatePersistence(
        trace_id="t-mermaid", write_history=write, read_history=read
    )
    persistence.set_graph_types(workflow_graph)
    deps, _ = _make_deps(skill_outputs=[("ok", True)])
    workflow_data = {
        "work_link": [
            {
                "step": 1,
                "name": "s1",
                "action": "do",
                "node": "skill_individual",
                "agent_id": "a1",
            },
        ]
    }

    final = await run_workflow_graph(
        workflow_data, "t-mermaid", deps=deps, persistence=persistence
    )
    assert final == WorkflowStatus.COMPLETED.value

    history = db["t-mermaid"]
    assert isinstance(history, list)

    # Class-name prefix of every successful node snapshot, in history order.
    class_names = (
        (entry.get("id") or "").split(":", 1)[0]
        for entry in history
        if isinstance(entry, dict)
        and entry.get("kind") == "node"
        and entry.get("status") == "success"
    )
    # dict.fromkeys de-duplicates and keeps first-seen order; drop empties.
    visited = [name for name in dict.fromkeys(class_names) if name]

    # A single-step run must at least pass through these three nodes.
    assert "Initialize" in visited
    assert "Dispatch" in visited
    assert "SkillStep" in visited

    mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited)
    # Highlighting emits a classDef plus one "class <Node>" line per node.
    assert "classDef" in mermaid
    assert "class Initialize" in mermaid
    assert "class SkillStep" in mermaid