# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
"""Task DAO: minimal persistence layer for the control node's short tasks."""

from __future__ import annotations

from typing import List, Optional

from sqlalchemy import select, desc, update
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession

from kilostar.core.postgres_database.model.task import Task
from kilostar.core.postgres_database.database_exception import database_exception


class TaskDatabase:
    """Async data-access helpers for ``Task`` rows.

    Each public method opens its own session from ``async_session_maker``.
    Write methods commit before returning. Every public method is wrapped
    in ``database_exception``; see that module for how errors are mapped.
    """

    def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]):
        """Store the session factory used by every query.

        Args:
            async_session_maker: Factory producing ``AsyncSession`` objects.
        """
        self.async_session_maker = async_session_maker

    @database_exception
    async def create_task(
        self,
        task_id: str,
        user_id: str,
        command: str,
        title: str,
        chat_id: Optional[str] = None,
        status: str = "completed",
        result_summary: Optional[str] = None,
        artifact_refs: Optional[list] = None,
    ) -> None:
        """Insert a new task row and commit it.

        Args:
            task_id: Identifier stored in ``Task.task_id``.
            user_id: Owner of the task.
            command: Command the task represents.
            title: Human-readable title.
            chat_id: Optional chat the task belongs to.
            status: Initial status; defaults to ``"completed"``.
            result_summary: Optional summary text.
            artifact_refs: Optional list stored as-is in ``Task.artifact_refs``.
        """
        async with self.async_session_maker() as session:
            row = Task(
                task_id=task_id,
                user_id=user_id,
                chat_id=chat_id,
                command=command,
                title=title,
                status=status,
                result_summary=result_summary,
                artifact_refs=artifact_refs,
            )
            session.add(row)
            await session.commit()

    @database_exception
    async def update_status(
        self,
        task_id: str,
        status: str,
        result_summary: Optional[str] = None,
    ) -> None:
        """Set the status (and optionally the summary) of a task, then commit.

        Args:
            task_id: Task whose row is updated. If no row matches, the
                UPDATE affects nothing and no error is raised here.
            status: New status value.
            result_summary: When not ``None``, also overwrites
                ``result_summary``; when ``None``, the stored summary is kept.
        """
        async with self.async_session_maker() as session:
            values: dict[str, str] = {"status": status}
            # Only touch result_summary when the caller supplied one, so a
            # status-only update does not wipe an existing summary.
            if result_summary is not None:
                values["result_summary"] = result_summary
            stmt = update(Task).where(Task.task_id == task_id).values(**values)
            await session.execute(stmt)
            await session.commit()

    @database_exception
    async def get_task(self, task_id: str) -> Optional[dict]:
        """Fetch a single task by id.

        Args:
            task_id: Identifier to look up.

        Returns:
            The task serialised by ``_row_to_dict``, or ``None`` if no row
            matches.
        """
        async with self.async_session_maker() as session:
            stmt = select(Task).where(Task.task_id == task_id)
            row = (await session.execute(stmt)).scalar_one_or_none()
            # scalar_one_or_none() signals "no match" with None; test for that
            # explicitly instead of relying on the ORM object's truthiness.
            if row is None:
                return None
            return _row_to_dict(row)

    @database_exception
    async def list_tasks_by_user(
        self,
        user_id: str,
        status: Optional[str] = None,
        limit: int = 20,
        offset: int = 0,
    ) -> List[dict]:
        """List a user's tasks, newest ``created_at`` first.

        Args:
            user_id: Owner whose tasks are listed.
            status: When truthy, only tasks with this exact status are
                returned. ``None`` or ``""`` means no status filter.
            limit: Maximum number of rows to return.
            offset: Number of rows to skip, for pagination.

        Returns:
            A list of dicts produced by ``_row_to_dict``.
        """
        async with self.async_session_maker() as session:
            stmt = select(Task).where(Task.user_id == user_id)
            if status:
                stmt = stmt.where(Task.status == status)
            stmt = stmt.order_by(desc(Task.created_at)).offset(offset).limit(limit)
            rows = (await session.execute(stmt)).scalars().all()
            return [_row_to_dict(r) for r in rows]


def _row_to_dict(row: Task) -> dict:
    """Serialise a ``Task`` ORM row into a plain dict.

    A missing or empty ``artifact_refs`` becomes ``[]``. Timestamps are
    rendered with ``str()``, or ``None`` when absent.
    """
    return {
        "task_id": row.task_id,
        "user_id": row.user_id,
        "chat_id": row.chat_id,
        "command": row.command,
        "title": row.title,
        "status": row.status,
        "result_summary": row.result_summary,
        "artifact_refs": row.artifact_refs or [],
        "created_at": str(row.created_at) if row.created_at is not None else None,
        "updated_at": str(row.updated_at) if row.updated_at is not None else None,
    }