feat: 清理 control_node + 引入 task 一等公民

- control_node 标注 DEPRECATED:保留目录壳子供未来远程探针节点复用,删除调用路径与相关测试
- 新增 task 表:极简元数据持久化 regulatory_node 完成的短任务(出报告/写文件/查询整理)
- regulatory_node 自标注:MessageResponse 扩展 task_action/title/summary,_run 末尾非阻塞落库
- query_task_list 改查 task 表,符合用户对"任务列表"的直觉,与 workflow 体系解耦
- 新增 /api/v1/task/list|/{id} 只读 API(task 由 regulatory 内部触发,不开放对外创建)

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-17 16:30:19 +00:00
parent 005ce566a8
commit 4aa1dab283
20 changed files with 510 additions and 91 deletions
+1 -1
View File
@@ -147,7 +147,7 @@ KiloStar/
│ │ ├── individual/ # Agent node implementations
│ │ │ ├── consciousness_node/ # Task planning
│ │ │ ├── regulatory_node/ # Quality oversight
│ │ │ ├── control_node/ # Routing & dispatch
│ │ │ ├── control_node/ # Deprecated (name reserved for future remote-probe node)
│ │ │ └── growth_node/ # Capability expansion
│ │ ├── work/ # Work execution layer
│ │ │ ├── workflow/ # Workflow engine (pydantic-graph)
@@ -0,0 +1,46 @@
"""add task table for regulatory_node short tasks
Revision ID: 0011
Revises: 0010
Create Date: 2026-06-17
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import JSONB
revision = "0011"
down_revision = "0010"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"task",
sa.Column("task_id", sa.String(64), primary_key=True),
sa.Column("user_id", sa.String(64), index=True, nullable=False),
sa.Column("chat_id", sa.String(64), index=True, nullable=True),
sa.Column("command", sa.Text(), nullable=False),
sa.Column("title", sa.String(255), nullable=False),
sa.Column(
"status", sa.String(20), index=True, server_default="completed"
),
sa.Column("result_summary", sa.Text(), nullable=True),
sa.Column("artifact_refs", JSONB, nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
index=True,
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
server_default=sa.func.now(),
),
)
def downgrade() -> None:
op.drop_table("task")
+1 -1
View File
@@ -55,7 +55,7 @@
"name": "tavily_search",
"file": "tavily_search.py",
"is_system": false,
"action_scope": ["control_node", "consciousness_node", "regulatory_node"],
"action_scope": ["consciousness_node", "regulatory_node"],
"config_args": {
"api_key": "",
"max_results": "5",
@@ -1,7 +1,10 @@
"""query_task_list:列出当前用户的所有工作流任务
"""query_task_list:列出当前用户的短任务记录
regulatory_node 用以回答"有哪些任务/正在跑什么"。返回精简后的任务列表,
不包含 graph state、context 等大字段
regulatory_node 用以回答"之前那份报告呢""昨天那个查询结果是什么"
返回 task 表中的精简元数据列表(不含工作流的 graph state、context 等
注:此处的 "task" 是 regulatory_node 完成的轻量短任务(出报告/写文件/查询整理等),
与 workflow(多步骤复杂任务)是两套独立体系。如需查工作流进度,使用 query_workflow_status。
"""
from typing import Any, Dict, List, Optional
@@ -14,41 +17,40 @@ async def query_task_list(
status_filter: Optional[str] = None,
limit: int = 20,
) -> Dict[str, Any]:
"""列出当前用户的工作流任务
"""列出当前用户的短任务记录,按时间倒序
Args:
user_id: 用户 ID(通常由调用方从对话上下文中带入)
status_filter: 可选,按状态过滤(pending/running/completed/failed
status_filter: 可选,按状态过滤(running/completed/failed
limit: 最多返回条数,默认 20
Returns:
{
"user_id": str,
"tasks": [
{"trace_id": ..., "title": ..., "status": ..., "command": ..., "created_at": ...}
{"task_id": ..., "title": ..., "status": ...,
"result_summary": ..., "created_at": ...}
],
"total": int
}
"""
pg = ray_actor_hook("postgres_database").postgres_database
workflows = await pg.list_workflows.remote(user_id) or []
rows: List[Dict[str, Any]] = await pg.list_tasks_by_user.remote(
user_id=user_id,
status=status_filter,
limit=limit,
) or []
tasks: List[Dict[str, Any]] = []
for wf in workflows:
status = getattr(wf, "status", None)
if status_filter and status != status_filter:
continue
tasks.append(
{
"trace_id": getattr(wf, "trace_id", None),
"title": getattr(wf, "title", None),
"status": status,
"command": getattr(wf, "command", None),
"created_at": str(getattr(wf, "created_at", "")),
}
)
if len(tasks) >= limit:
break
tasks = [
{
"task_id": r.get("task_id"),
"title": r.get("title"),
"status": r.get("status"),
"result_summary": r.get("result_summary"),
"created_at": r.get("created_at"),
}
for r in rows
]
return {
"user_id": user_id,
+1 -1
View File
@@ -38,7 +38,7 @@ KiloStar/
│ │ │ ├── regulatory_node/ # 监管节点:直面用户对话、质量把关
│ │ │ │ ├── regulatory_node.py
│ │ │ │ └── template.py
│ │ │ ├── control_node/ # 控制节点:工作流节点内路由调度
│ │ │ ├── control_node/ # 控制节点(已废弃,名字保留给未来远程探针节点)
│ │ │ │ ├── control_node.py
│ │ │ │ └── template.py
│ │ │ └── growth_node/ # 生长节点:能力自扩展(占位)
+1 -1
View File
@@ -64,7 +64,7 @@ Mode is set via `KILOSTAR_MODE` env var. Entry point `main.py` branches into `st
### Backend Layout (`kilostar/`)
- `api/` — FastAPI routers (auth, chat, agent, workflow, system, resource, platform)
- `core/individual/` — 4 node types: RegulatoryNode (user-facing QA), ConsciousnessNode (planning), ControlNode (routing), GrowthNode (capability expansion)
- `core/individual/` — 4 node types: RegulatoryNode (user-facing QA + short tasks), ConsciousnessNode (workflow planning), ControlNode (deprecated; name reserved for future remote-probe node), GrowthNode (capability expansion, not yet implemented)
- `core/global_state_machine/` — Provider registry, model config state
- `core/global_workflow_manager/` — Workflow queue & recovery
- `core/postgres_database/` — DAO layer: `model/` (SQLAlchemy models), `module/` (CRUD methods), `postgres.py` (facade)
+2
View File
@@ -36,6 +36,7 @@ from .resource import resource_router
from .workflow import workflow_router
from .chat import chat_router
from .plugin import plugin_router
from .task import task_router
from kilostar.utils.error import (
KiloStarError,
BusinessError,
@@ -105,6 +106,7 @@ app.include_router(agent_router) # agent路径
app.include_router(workflow_router) # workflow路径
app.include_router(chat_router) # chat路径
app.include_router(plugin_router) # plugin路径
app.include_router(task_router) # 短任务路径
@app.exception_handler(BusinessError)
+55
View File
@@ -0,0 +1,55 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Task API:管控节点短任务的查询接口。
task 由 regulatory_node 在完成短任务时内部建立,因此本路由只暴露读取,
不开放对外创建。
"""
from typing import Optional
from fastapi import APIRouter, Depends, HTTPException
from kilostar.utils.access import Accessor, TokenData
from kilostar.utils.ray_hook import ray_actor_hook
task_router = APIRouter(prefix="/api/v1/task", tags=["task"])
@task_router.get("/list")
async def list_tasks(
status: Optional[str] = None,
limit: int = 20,
offset: int = 0,
token_data: TokenData = Depends(Accessor.get_current_user),
):
"""列出当前用户的所有短任务,按时间倒序。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database
tasks = await postgres_database.list_tasks_by_user.remote(
user_id=token_data.user_id,
status=status,
limit=limit,
offset=offset,
)
return {"tasks": tasks, "total": len(tasks)}
@task_router.get("/{task_id}")
async def get_task(
task_id: str,
token_data: TokenData = Depends(Accessor.get_current_user),
):
"""按 task_id 读取一条 task 详情。仅 owner 可访问。"""
postgres_database = ray_actor_hook("postgres_database").postgres_database
task = await postgres_database.get_task.remote(task_id)
if not task:
raise HTTPException(status_code=404, detail="task not found")
if task.get("user_id") != token_data.user_id:
raise HTTPException(status_code=403, detail="forbidden")
return task
@@ -27,10 +27,12 @@ from kilostar.utils.prompts import agent_prompt
@actor_class
class ControlNode:
"""ControlNode(控制节点):工作流中具体子任务的执行 Actor
"""ControlNode(控制节点):**已废弃**——名字保留给未来的远程探针/系统控制节点
它把 ConsciousnessNode 编排出的 ``workflow_step`` 拿来当作输入,借助
pydantic-ai Agent + 已绑定的工具集合产出 ``ForWorkflow`` 结构化输出
历史:早期设计里它是工作流的"单步执行 actor",但 workflow_engine 的 Dispatch
最终只识别 ``consciousness_node`` 和 ``skill_individual``,本类从未真正被调用过
保留目录与类壳子,避免改名带来的 git 历史断层;**不要新增对它的依赖**。
待远程探针/监控流子项目启动时,本目录将被重写为远程机器控制节点。
"""
def __init__(self):
@@ -212,7 +212,42 @@ class RegulatoryNode:
response: MessageResponse = agent_response.output
response.platform = platform
response.platform_id = payload.platform_id
await self._maybe_persist_task(payload, response)
return response
except Exception as e:
self.logger.exception(f"RegulatoryNode._run failed: {e}")
return None
return None
async def _maybe_persist_task(
self, payload: MessageRequest, response: MessageResponse
) -> None:
"""LLM 自标注 task_action=create_task 时落一条 task 记录。
失败不抛错——task 表是辅助元数据,不能拖垮主回复链路。
"""
if response.task_action != "create_task":
return
if not response.task_title or not response.task_summary:
self.logger.warning(
"task_action=create_task 但 title/summary 为空,跳过落库"
)
return
try:
import uuid
from kilostar.utils.ray_hook import ray_actor_hook
postgres_database = ray_actor_hook("postgres_database").postgres_database
task_id = uuid.uuid4().hex
chat_id = payload.platform_id if payload.platform == "client" else None
await postgres_database.create_task.remote(
task_id=task_id,
user_id=payload.user_name,
command=payload.message,
title=response.task_title,
chat_id=chat_id,
status="completed",
result_summary=response.task_summary,
artifact_refs=None,
)
except Exception as e:
self.logger.warning(f"persist task failed (non-fatal): {e}")
@@ -62,3 +62,15 @@ class MessageResponse(RegulatoryNodeResponse):
platform: Optional[Literal["client", "onebot"]] = Field(description="系统自动填入的platform")
platform_id: Optional[str] = Field(description="系统自动填入的platform_id")
reply_message: str = Field(...,description="模型回复的消息")
task_action: Optional[Literal["create_task"]] = Field(
default=None,
description="本次回复是否完成了一个值得记录的短任务。生成文件/出报告/查询整理等填 'create_task',闲聊或简单问答留空。",
)
task_title: Optional[str] = Field(
default=None,
description="task 的简短标题(task_action=create_task 时必填,<=80 字)",
)
task_summary: Optional[str] = Field(
default=None,
description="task 的结果摘要(task_action=create_task 时必填,描述产出与去向)",
)
@@ -37,6 +37,7 @@ from kilostar.core.postgres_database.model.system_event_log import SystemEventLo
from kilostar.core.postgres_database.model.persona_template import PersonaTemplate
from kilostar.core.postgres_database.model.org_task import OrgTask
from kilostar.core.postgres_database.model.org_task_event import OrgTaskEvent
from kilostar.core.postgres_database.model.task import Task
# 兼容旧代码的别名
Provider = ProviderModel
@@ -69,5 +70,6 @@ __all__ = [
"PersonaTemplate",
"OrgTask",
"OrgTaskEvent",
"Task",
"AgentType",
]
@@ -0,0 +1,54 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Task:管控节点(regulatory_node)完成的短任务记录。
与 workflow 不同,task 是上下文内能完成的轻量任务(写文件/查询/出报告等),
表里只存最终元数据 + 结果摘要 + 关联 artifact,不入库执行过程。
"""
from sqlalchemy import String, DateTime, Text, func
from sqlalchemy.orm import Mapped, mapped_column
from sqlalchemy.dialects.postgresql import JSONB
from .base import BaseDataModel
class Task(BaseDataModel):
__tablename__ = "task"
task_id: Mapped[str] = mapped_column(
String(64), primary_key=True, comment="任务唯一 IDUUID"
)
user_id: Mapped[str] = mapped_column(
String(64), index=True, comment="所属用户 ID"
)
chat_id: Mapped[str | None] = mapped_column(
String(64), index=True, nullable=True, comment="所属对话(如有)"
)
command: Mapped[str] = mapped_column(
Text, comment="用户原始指令"
)
title: Mapped[str] = mapped_column(
String(255), comment="任务简短标题(LLM 生成)"
)
status: Mapped[str] = mapped_column(
String(20), index=True, default="completed",
comment="running / completed / failed",
)
result_summary: Mapped[str | None] = mapped_column(
Text, nullable=True, comment="完成后的结果摘要"
)
artifact_refs: Mapped[list | None] = mapped_column(
JSONB, nullable=True, comment="关联的 artifact url 列表"
)
created_at: Mapped[str] = mapped_column(
DateTime(timezone=True), server_default=func.now(), index=True
)
updated_at: Mapped[str] = mapped_column(
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
)
@@ -0,0 +1,104 @@
# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Task DAO:管控节点短任务的最小持久化层。"""
from __future__ import annotations
from typing import List, Optional
from sqlalchemy import select, desc, update
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
from kilostar.core.postgres_database.model.task import Task
from kilostar.core.postgres_database.database_exception import database_exception
class TaskDatabase:
def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]):
self.async_session_maker = async_session_maker
@database_exception
async def create_task(
self,
task_id: str,
user_id: str,
command: str,
title: str,
chat_id: Optional[str] = None,
status: str = "completed",
result_summary: Optional[str] = None,
artifact_refs: Optional[list] = None,
) -> None:
async with self.async_session_maker() as session:
row = Task(
task_id=task_id,
user_id=user_id,
chat_id=chat_id,
command=command,
title=title,
status=status,
result_summary=result_summary,
artifact_refs=artifact_refs,
)
session.add(row)
await session.commit()
@database_exception
async def update_status(
self,
task_id: str,
status: str,
result_summary: Optional[str] = None,
) -> None:
async with self.async_session_maker() as session:
values = {"status": status}
if result_summary is not None:
values["result_summary"] = result_summary
stmt = update(Task).where(Task.task_id == task_id).values(**values)
await session.execute(stmt)
await session.commit()
@database_exception
async def get_task(self, task_id: str) -> Optional[dict]:
async with self.async_session_maker() as session:
stmt = select(Task).where(Task.task_id == task_id)
row = (await session.execute(stmt)).scalar_one_or_none()
if not row:
return None
return _row_to_dict(row)
@database_exception
async def list_tasks_by_user(
self,
user_id: str,
status: Optional[str] = None,
limit: int = 20,
offset: int = 0,
) -> List[dict]:
async with self.async_session_maker() as session:
stmt = select(Task).where(Task.user_id == user_id)
if status:
stmt = stmt.where(Task.status == status)
stmt = stmt.order_by(desc(Task.created_at)).offset(offset).limit(limit)
rows = (await session.execute(stmt)).scalars().all()
return [_row_to_dict(r) for r in rows]
def _row_to_dict(row: Task) -> dict:
return {
"task_id": row.task_id,
"user_id": row.user_id,
"chat_id": row.chat_id,
"command": row.command,
"title": row.title,
"status": row.status,
"result_summary": row.result_summary,
"artifact_refs": row.artifact_refs or [],
"created_at": str(row.created_at) if row.created_at else None,
"updated_at": str(row.updated_at) if row.updated_at else None,
}
@@ -59,6 +59,7 @@ from .module.custom_toolset import CustomToolsetDatabase
from .module.system_event_log import SystemEventLogDatabase
from .module.persona_template import PersonaTemplateDatabase
from .module.org_task import OrgTaskDatabase
from .module.task import TaskDatabase
@actor_class
@@ -93,6 +94,7 @@ class PostgresDatabase:
self._system_event_log_database = SystemEventLogDatabase(self.async_session_maker)
self._persona_template_database = PersonaTemplateDatabase(self.async_session_maker)
self._org_task_database = OrgTaskDatabase(self.async_session_maker)
self._task_database = TaskDatabase(self.async_session_maker)
self.ready_event = asyncio.Event()
@@ -487,3 +489,49 @@ class PostgresDatabase:
async def query_org_events(self, task_id: str, limit=200):
await self.ready_event.wait()
return await self._org_task_database.query_events(task_id, limit)
# Task Methods(管控节点短任务)
async def create_task(
self,
task_id: str,
user_id: str,
command: str,
title: str,
chat_id: str | None = None,
status: str = "completed",
result_summary: str | None = None,
artifact_refs: list | None = None,
):
await self.ready_event.wait()
return await self._task_database.create_task(
task_id=task_id,
user_id=user_id,
command=command,
title=title,
chat_id=chat_id,
status=status,
result_summary=result_summary,
artifact_refs=artifact_refs,
)
async def update_task_status(
self, task_id: str, status: str, result_summary: str | None = None
):
await self.ready_event.wait()
return await self._task_database.update_status(task_id, status, result_summary)
async def get_task(self, task_id: str):
await self.ready_event.wait()
return await self._task_database.get_task(task_id)
async def list_tasks_by_user(
self,
user_id: str,
status: str | None = None,
limit: int = 20,
offset: int = 0,
):
await self.ready_event.wait()
return await self._task_database.list_tasks_by_user(
user_id=user_id, status=status, limit=limit, offset=offset
)
+1 -1
View File
@@ -1,7 +1,7 @@
"""把组织包装成 cabinet 可调用的高阶 tool。
每个组织 → 一个 ``dispatch_to_<org>(task_description)`` 工具。
ConsciousnessNode/ControlNode 通过这个工具向部门派单,等待部门完成。
RegulatoryNode/ConsciousnessNode 通过这个工具向部门派单,等待部门完成。
"""
from __future__ import annotations
+16 -2
View File
@@ -18,7 +18,13 @@ _PROMPTS: Dict[str, Dict[str, str]] = {
"1. 准确理解用户的意图,提供专业、友好且有帮助的回复。\n"
"2. 如果你有可用工具,可以主动调用工具来辅助回答(如搜索、文件操作等)。\n"
"3. 如果你收到工作流的执行报告,请将其转化为面向用户的清晰总结。\n"
"4. 保持回复简洁、有结构,避免冗余信息。\n"
"4. 保持回复简洁、有结构,避免冗余信息。\n\n"
"【关于短任务(task)】\n"
"如果本次回复完成了一个值得记录的'短任务'(生成文件/出报告/查询整理资料/写代码片段等具体产出),\n"
"请把 task_action 设为 'create_task',并填写:\n"
"- task_title:简短标题(<=80 字,例如 'Python 学习计划'\n"
"- task_summary:结果摘要(说明产出了什么、附件去向)\n"
"闲聊、打招呼、纯问答这类不留下产出物的回复,task_action 留空(None)。\n"
"请保持专业、友好的沟通风格。"
),
"en": (
@@ -28,7 +34,13 @@ _PROMPTS: Dict[str, Dict[str, str]] = {
"1. Accurately understand user intent and provide professional, friendly, and helpful replies.\n"
"2. If tools are available, proactively use them to assist your responses (e.g., search, file operations).\n"
"3. If you receive a workflow execution report, convert it into a clear user-facing summary.\n"
"4. Keep responses concise, well-structured, and free of redundancy.\n"
"4. Keep responses concise, well-structured, and free of redundancy.\n\n"
"[About short tasks]\n"
"If this reply completes a worth-recording short task (generating files / writing reports / collecting information / producing code snippets etc.),\n"
"set task_action to 'create_task' and fill:\n"
"- task_title: short title (<=80 chars, e.g. 'Python learning plan')\n"
"- task_summary: result summary (what was produced, where attachments live)\n"
"Leave task_action empty for chit-chat / greetings / plain Q&A that produce no artifact.\n"
"Maintain a professional and friendly communication style."
),
},
@@ -72,6 +84,8 @@ _PROMPTS: Dict[str, Dict[str, str]] = {
"Ensure all output is logical, rigorous, and high-quality."
),
},
# DEPRECATED: control_node 当前未被任何路径调用,保留 prompt 占位以便未来
# 改造为远程探针/系统控制节点时直接复用 key。
"control_node": {
"zh": (
"你叫kilostar,是一个多智能体AI助手系统中的【控制节点 (Control Node)】。\n"
+1 -56
View File
@@ -76,62 +76,7 @@ async def test_regulatory_run_swallows_exception_returns_none(regulatory_instanc
assert out is None
# ─── ControlNode ────────────────────────────────────────────────────────────
@pytest.fixture
def control_instance():
from kilostar.core.individual.control_node.control_node import ControlNode
cls = ControlNode.__ray_actor_class__
obj = cls.__new__(cls)
from kilostar.utils.logger import get_logger
obj.logger = get_logger("control_node")
obj.agent = None
obj._model_settings = {}
return obj
def _make_workflow_step():
from kilostar.core.work.workflow.workflow import WorkflowStep
return WorkflowStep(
step=1,
name="do something",
action="execute the thing",
inputs=None,
outputs="result",
)
@pytest.mark.asyncio
async def test_control_working_returns_for_workflow_output(control_instance):
from kilostar.core.individual.control_node.template import (
ForWorkflow,
ForWorkflowInput,
)
step = _make_workflow_step()
expected = ForWorkflow(output="done")
agent_run_result = SimpleNamespace(output=expected)
control_instance.agent = MagicMock()
control_instance.agent.run = AsyncMock(return_value=agent_run_result)
out = await control_instance.working(ForWorkflowInput(workflow_step=step))
assert out is expected
@pytest.mark.asyncio
async def test_control_working_swallows_exception_returns_none(control_instance):
from kilostar.core.individual.control_node.template import ForWorkflowInput
step = _make_workflow_step()
control_instance.agent = MagicMock()
control_instance.agent.run = AsyncMock(side_effect=RuntimeError("boom"))
out = await control_instance.working(ForWorkflowInput(workflow_step=step))
assert out is None
# ─── ControlNode 已废弃,相关 fixture 与测试已删除(保留目录壳子供未来改写) ──
# ─── ConsciousnessNode ──────────────────────────────────────────────────────
+2 -1
View File
@@ -66,8 +66,9 @@ def test_tavily_search_metadata():
tool = _get_tool_def(manifest, "tavily_search")
assert tool["is_system"] is False
assert tool["category"] == "search"
assert "control_node" in tool["action_scope"]
assert "consciousness_node" in tool["action_scope"]
assert "regulatory_node" in tool["action_scope"]
assert "control_node" not in tool["action_scope"]
assert "api_key" in tool["config_args"]
+97
View File
@@ -0,0 +1,97 @@
"""``TaskDatabase`` 单元测试:覆盖 create / get / list / update_status 路径。"""
from __future__ import annotations
import pytest
from unittest.mock import AsyncMock, MagicMock
from kilostar.core.postgres_database.module.task import TaskDatabase
def _make_db():
session = AsyncMock()
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=False)
session_maker = MagicMock(return_value=session)
return TaskDatabase(session_maker), session
@pytest.mark.anyio
async def test_create_task_persists_row():
db, session = _make_db()
session.add = MagicMock()
session.commit = AsyncMock()
await db.create_task(
task_id="t1",
user_id="alice",
command="写一份周报",
title="Q2 周报",
chat_id="chat-1",
status="completed",
result_summary="已生成报告",
)
session.add.assert_called_once()
added = session.add.call_args[0][0]
assert added.task_id == "t1"
assert added.user_id == "alice"
assert added.title == "Q2 周报"
assert added.status == "completed"
session.commit.assert_awaited_once()
@pytest.mark.anyio
async def test_get_task_returns_none_when_missing():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalar_one_or_none.return_value = None
session.execute = AsyncMock(return_value=execute_result)
result = await db.get_task("missing")
assert result is None
@pytest.mark.anyio
async def test_list_tasks_by_user_filters_status():
"""传 status 时 SQL 应进入 status 过滤分支(execute 被调用一次即视为路径已走通)。"""
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=execute_result)
result = await db.list_tasks_by_user(user_id="alice", status="completed", limit=10)
assert result == []
session.execute.assert_awaited_once()
@pytest.mark.anyio
async def test_list_tasks_by_user_no_status():
db, session = _make_db()
execute_result = MagicMock()
execute_result.scalars.return_value.all.return_value = []
session.execute = AsyncMock(return_value=execute_result)
await db.list_tasks_by_user(user_id="alice")
session.execute.assert_awaited_once()
@pytest.mark.anyio
async def test_update_status_with_summary():
db, session = _make_db()
session.execute = AsyncMock()
session.commit = AsyncMock()
await db.update_status("t1", status="failed", result_summary="出错")
session.execute.assert_awaited_once()
session.commit.assert_awaited_once()
@pytest.mark.anyio
async def test_update_status_without_summary():
db, session = _make_db()
session.execute = AsyncMock()
session.commit = AsyncMock()
await db.update_status("t1", status="running")
session.execute.assert_awaited_once()
session.commit.assert_awaited_once()