"""Unit tests for ``TaskDatabase``: covers the create / get / list / update_status paths."""

from __future__ import annotations

from unittest.mock import AsyncMock, MagicMock

import pytest

from kilostar.core.postgres_database.module.task import TaskDatabase


def _make_db():
    """Build a ``TaskDatabase`` backed by a fully mocked session factory.

    Returns:
        A ``(database, session)`` tuple. ``session`` is the ``AsyncMock`` that
        the mocked session maker returns, and also the value its
        ``__aenter__`` resolves to, so tests can stub and inspect it directly.
    """
    mock_session = AsyncMock()
    # Entering the session as an async context manager yields the mock itself.
    mock_session.__aenter__ = AsyncMock(return_value=mock_session)
    # A falsy __aexit__ result means exceptions are not suppressed.
    mock_session.__aexit__ = AsyncMock(return_value=False)
    factory = MagicMock(return_value=mock_session)
    return TaskDatabase(factory), mock_session


@pytest.mark.anyio
async def test_create_task_persists_row():
    """``create_task`` adds exactly one row with the given fields and commits once."""
    database, mock_session = _make_db()
    mock_session.commit = AsyncMock()
    # Plain MagicMock: attributes of an AsyncMock are async by default, and
    # ``add`` is asserted below as a regular (non-awaited) call.
    mock_session.add = MagicMock()

    await database.create_task(
        task_id="t1",
        user_id="alice",
        command="写一份周报",
        title="Q2 周报",
        chat_id="chat-1",
        status="completed",
        result_summary="已生成报告",
    )

    mock_session.add.assert_called_once()
    persisted = mock_session.add.call_args.args[0]
    expected_fields = {
        "task_id": "t1",
        "user_id": "alice",
        "title": "Q2 周报",
        "status": "completed",
    }
    for field_name, expected_value in expected_fields.items():
        assert getattr(persisted, field_name) == expected_value
    mock_session.commit.assert_awaited_once()


@pytest.mark.anyio
async def test_get_task_returns_none_when_missing():
    """``get_task`` returns ``None`` when the query result holds no row."""
    database, mock_session = _make_db()
    empty_result = MagicMock()
    empty_result.scalar_one_or_none.return_value = None
    mock_session.execute = AsyncMock(return_value=empty_result)

    found = await database.get_task("missing")

    assert found is None


@pytest.mark.anyio
async def test_list_tasks_by_user_filters_status():
    """When ``status`` is passed, the query should take the status-filter branch.

    A single awaited ``execute`` call is treated as proof that the path ran.
    """
    database, mock_session = _make_db()
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = []
    mock_session.execute = AsyncMock(return_value=query_result)

    tasks = await database.list_tasks_by_user(
        user_id="alice", status="completed", limit=10
    )

    assert tasks == []
    mock_session.execute.assert_awaited_once()


@pytest.mark.anyio
async def test_list_tasks_by_user_no_status():
    """Listing without a ``status`` argument still issues exactly one query."""
    database, mock_session = _make_db()
    query_result = MagicMock()
    query_result.scalars.return_value.all.return_value = []
    mock_session.execute = AsyncMock(return_value=query_result)

    await database.list_tasks_by_user(user_id="alice")

    mock_session.execute.assert_awaited_once()


@pytest.mark.anyio
async def test_update_status_with_summary():
    """``update_status`` with a summary executes one statement and commits once."""
    database, mock_session = _make_db()
    mock_session.commit = AsyncMock()
    mock_session.execute = AsyncMock()

    await database.update_status("t1", status="failed", result_summary="出错")

    mock_session.execute.assert_awaited_once()
    mock_session.commit.assert_awaited_once()


@pytest.mark.anyio
async def test_update_status_without_summary():
    """``update_status`` without a summary executes one statement and commits once."""
    database, mock_session = _make_db()
    mock_session.commit = AsyncMock()
    mock_session.execute = AsyncMock()

    await database.update_status("t1", status="running")

    mock_session.execute.assert_awaited_once()
    mock_session.commit.assert_awaited_once()