"""``PostgresStatePersistence`` 与 graph resume 路径单元测试。 不依赖真实 postgres / ray —— 用两个 lambda 模拟 read/write 即可。 """ from __future__ import annotations from typing import Any, Dict, List, Optional, Tuple from unittest.mock import AsyncMock import pytest from kilostar.core.work.workflow.workflow_engine import ( WorkflowDeps, resume_workflow_graph, run_workflow_graph, workflow_graph, ) from kilostar.core.work.workflow.graph_persistence import ( PostgresStatePersistence, ) from kilostar.core.work.workflow.model import WorkflowStatus def _make_in_memory_io() -> tuple[ Dict[str, Any], "callable", "callable" ]: """构造一对 (db_state, write_fn, read_fn):测试用模拟 postgres。""" db: Dict[str, Any] = {} async def write(trace_id: str, history: Any) -> None: db[trace_id] = history async def read(trace_id: str) -> Optional[Any]: return db.get(trace_id) return db, write, read def _make_deps( *, skill_outputs: List[Tuple[str, bool]] | None = None, received_replies: List[str] | None = None, ) -> tuple[WorkflowDeps, Dict[str, List[Any]]]: sink: Dict[str, List[Any]] = {"skill_calls": [], "pending": []} skill_q = list(skill_outputs or []) reply_q = list(received_replies or []) upsert = AsyncMock() status = AsyncMock() pending = AsyncMock(side_effect=lambda tid, msg: sink["pending"].append((tid, msg))) async def _get_received(tid): if reply_q: return reply_q.pop(0) return "" async def _run_skill(step, state): sink["skill_calls"].append((step.get("name"), state.current_step_index)) if not skill_q: return "(no fixture)", True return skill_q.pop(0) async def _run_consciousness(step, state): # 不会被本测试触发 return "(consc)", True return ( WorkflowDeps( upsert_workflow_context=upsert, update_workflow_status=status, put_pending=pending, get_received=_get_received, run_skill=_run_skill, run_consciousness=_run_consciousness, ), sink, ) # ─── PostgresStatePersistence: 写穿 ────────────────────────────────────── @pytest.mark.asyncio async def test_postgres_persistence_writes_history_on_each_node(): """每经过一个节点边界,DB 都应该被更新一次(覆盖式写最新 history)。""" db, write, read = _make_in_memory_io() persistence = PostgresStatePersistence( trace_id="t1", write_history=write, read_history=read ) persistence.set_graph_types(workflow_graph) deps, _ = _make_deps(skill_outputs=[("ok", True)]) workflow_data = { "work_link": [ {"step": 1, "name": "s1", "action": "do", "node": "skill_individual", "agent_id": "a1"}, ] } final = await run_workflow_graph( workflow_data, "t1", deps=deps, persistence=persistence ) assert final == WorkflowStatus.COMPLETED.value # DB 里应当有 history 且 history 至少包含一个节点 snapshot assert "t1" in db history = db["t1"] assert isinstance(history, list) assert len(history) >= 1 @pytest.mark.asyncio async def test_postgres_persistence_swallows_db_errors_during_run(): """DB 写入失败不应中断 graph 运行(只在 snapshot_end 必须 succeed)。""" db, _, read = _make_in_memory_io() write_calls = {"n": 0} async def flaky_write(trace_id, history): write_calls["n"] += 1 # 中途失败几次但最终一次(snapshot_end)成功 if write_calls["n"] < 3: raise RuntimeError("transient db error") db[trace_id] = history persistence = PostgresStatePersistence( trace_id="t-flaky", write_history=flaky_write, read_history=read ) persistence.set_graph_types(workflow_graph) deps, _ = _make_deps(skill_outputs=[("ok", True)]) workflow_data = { "work_link": [ {"step": 1, "name": "s1", "action": "do", "node": "skill_individual", "agent_id": "a1"}, ] } final = await run_workflow_graph( workflow_data, "t-flaky", deps=deps, persistence=persistence ) # 即便中途多次写入失败,graph 仍然跑完 assert final == WorkflowStatus.COMPLETED.value # ─── hydrate / resume:从 DB 续跑 ────────────────────────────────────── @pytest.mark.asyncio async def test_hydrate_returns_false_when_no_record(): """DB 没记录时 hydrate 应返回 False,调用方走 fresh start。""" db, write, read = _make_in_memory_io() persistence = PostgresStatePersistence( trace_id="t-empty", write_history=write, read_history=read ) persistence.set_graph_types(workflow_graph) assert await persistence.hydrate() is False @pytest.mark.asyncio async def test_resume_continues_from_persisted_history(monkeypatch): """先用 Graph.iter 跑半截留下"created"snapshot,然后用 resume 把剩余节点跑完。 断言:resume 路径只触发 1 次 skill(第 2 步),第 1 步是从 hydrate 恢复的、 不再被重新执行。 """ from kilostar.core.work.workflow.workflow_engine import ( Initialize, WorkflowGraphState, ) db, write, read = _make_in_memory_io() persistence_a = PostgresStatePersistence( trace_id="t-resume", write_history=write, read_history=read ) persistence_a.set_graph_types(workflow_graph) deps_a, sink_a = _make_deps(skill_outputs=[("step-1-ok", True), ("step-2-ok", True)]) workflow_data = { "work_link": [ {"step": 1, "name": "s1", "action": "do", "node": "skill_individual", "agent_id": "a1"}, {"step": 2, "name": "s2", "action": "do", "node": "skill_individual", "agent_id": "a1"}, ] } # 用 iter 手动驱动:跑 Initialize → Dispatch → SkillStep(跑完第 1 步), # 然后停下,让最后一个还没跑的节点(应该是 Dispatch)保持 created state = WorkflowGraphState( trace_id="t-resume", blackboard={}, work_link=list(workflow_data["work_link"]), original_command="", ) async with workflow_graph.iter( Initialize(), state=state, deps=deps_a, persistence=persistence_a, ) as run: # Initialize → Dispatch → SkillStep → Dispatch (这一步停下) await run.next() # 跑 Initialize,next_node = Dispatch await run.next() # 跑 Dispatch,next_node = SkillStep await run.next() # 跑 SkillStep,next_node = Dispatch(第 2 个) # 此时不再调 next,第 2 个 Dispatch 仍是 created 状态 assert len(sink_a["skill_calls"]) == 1, "中断时只应跑了 1 步" # 第二阶段:用同份 history resume persistence_b = PostgresStatePersistence( trace_id="t-resume", write_history=write, read_history=read ) persistence_b.set_graph_types(workflow_graph) assert await persistence_b.hydrate() is True deps_b, sink_b = _make_deps(skill_outputs=[("step-2-resumed", True)]) final = await resume_workflow_graph( "t-resume", deps=deps_b, persistence=persistence_b ) assert final == WorkflowStatus.COMPLETED.value # 关键断言:resume 只执行了第 2 步,第 1 步没被重复 assert len(sink_b["skill_calls"]) == 1 assert sink_b["skill_calls"][0][0] == "s2" # ─── HumanApproval idempotent resume ────────────────────────────────────── @pytest.mark.asyncio async def test_human_approval_idempotent_on_resume(): """HumanApproval 节点在 resume 后不应重复给前端推 put_pending。 流程: 1. 第一次跑:进 HumanApproval → put_pending 1 次 → 用户没回复 → 中断 2. 第二次跑(resume):用户已回复 approve → HumanApproval 第二次进入但 不应该再 put_pending;只读 reply 通过。 """ from kilostar.core.work.workflow.workflow_engine import ( Initialize, WorkflowGraphState, ) db, write, read = _make_in_memory_io() persistence_a = PostgresStatePersistence( trace_id="t-hitl", write_history=write, read_history=read ) persistence_a.set_graph_types(workflow_graph) # 第一次:reply 队列空 → get_received 返回空串 → HumanApproval 走拒绝路径之前 # 我们要在 put_pending 后停下,所以用 iter 手动驱动,跑到 HumanApproval 内部前停住。 # 简化策略:直接让第一次跑完到 Finalize FAILED(reply=""),第二次 resume 验证幂等 # 不太合适——FAILED 后没法 resume。改用:第一次 reply=""(拒绝),用 graph_state # 里 approvals_notified 字段来直接验证幂等性更纯粹。 deps_a, sink_a = _make_deps( skill_outputs=[("step-1", True)], received_replies=[""], # 空 reply → HumanApproval 拒绝 ) workflow_data_step = { "step": 1, "name": "approve-me", "action": "do", "node": "skill_individual", "agent_id": "a1", "require_approval": True, } # 第一次跑:单步 require_approval=True,reply="" → put_pending 1 次 + Finalize FAILED state = WorkflowGraphState( trace_id="t-hitl", blackboard={}, work_link=[dict(workflow_data_step)], original_command="", ) async with workflow_graph.iter( Initialize(), state=state, deps=deps_a, persistence=persistence_a, ) as run: async for _ in run: pass pending_count_first = len(sink_a["pending"]) # 至少有审批提示 + 最终 FAILED 提示这两类 pending approval_msgs_first = [ m for _, m in sink_a["pending"] if "需要人工审批" in m ] assert len(approval_msgs_first) == 1, "第一次跑必须发审批提示一次" # 关键:state 已经记下这次通知 assert state.approvals_notified == [0] # 第二次:直接构造一个 state,approvals_notified 里已有 0,再次进 HumanApproval # 模拟"resume 后 graph 重新进入审批节点"——put_pending 不应再发 deps_b, sink_b = _make_deps( received_replies=["approve"], # 这次用户回复 approve ) state_resume = WorkflowGraphState( trace_id="t-hitl", blackboard={}, work_link=[dict(workflow_data_step)], original_command="", approvals_notified=[0], # 关键:之前已通知过 ) persistence_b = PostgresStatePersistence( trace_id="t-hitl-resume", write_history=write, read_history=read ) persistence_b.set_graph_types(workflow_graph) # 直接从 Dispatch 起跑(跳过 Initialize 避免它再次 update_workflow_status) from kilostar.core.work.workflow.workflow_engine import Dispatch async with workflow_graph.iter( Dispatch(), state=state_resume, deps=deps_b, persistence=persistence_b, ) as run: async for _ in run: pass approval_msgs_second = [ m for _, m in sink_b["pending"] if "需要人工审批" in m ] # 关键断言:resume 路径上不应该再出现"需要人工审批"提示 assert len(approval_msgs_second) == 0 # approve 通过 → require_approval 置 False → 真正跑到 SkillStep assert len(sink_b["skill_calls"]) == 1 assert sink_b["skill_calls"][0][0] == "approve-me" # ─── mermaid 高亮 ───────────────────────────────────────────────────────── @pytest.mark.asyncio async def test_mermaid_highlight_visited_nodes_from_history(): """跑一次 graph 后,dump_json 出来的 history 应足以提取出 visited 节点类名。 这个测试模拟 ``/graph`` API 的过滤逻辑: - 只取 status="success" 的 NodeSnapshot - id 形如 "ClassName:hash",截前缀 - 去重保序 然后用提取出的 visited 调 ``mermaid_code(highlighted_nodes=...)``,应能在 输出里看到对应 ``class X highlighted`` 行。 """ import json db, write, read = _make_in_memory_io() persistence = PostgresStatePersistence( trace_id="t-mermaid", write_history=write, read_history=read ) persistence.set_graph_types(workflow_graph) deps, _ = _make_deps(skill_outputs=[("ok", True)]) workflow_data = { "work_link": [ {"step": 1, "name": "s1", "action": "do", "node": "skill_individual", "agent_id": "a1"}, ] } final = await run_workflow_graph( workflow_data, "t-mermaid", deps=deps, persistence=persistence ) assert final == WorkflowStatus.COMPLETED.value # 从 DB 读 history(实际是 list[dict]) history = db["t-mermaid"] assert isinstance(history, list) # 复刻 API 的过滤逻辑 seen: set[str] = set() visited: list[str] = [] for entry in history: if not isinstance(entry, dict) or entry.get("kind") != "node": continue if entry.get("status") != "success": continue sid = entry.get("id") or "" cls_name = sid.split(":", 1)[0] if sid else "" if cls_name and cls_name not in seen: seen.add(cls_name) visited.append(cls_name) # 至少应该 visit 过 Initialize / Dispatch / SkillStep assert "Initialize" in visited assert "Dispatch" in visited assert "SkillStep" in visited mermaid = workflow_graph.mermaid_code(highlighted_nodes=visited) # 高亮一定会带 classDef + class 行 assert "classDef" in mermaid assert "class Initialize" in mermaid assert "class SkillStep" in mermaid