Files
2026-07-01 09:22:26 +00:00

236 lines
8.8 KiB
Python

"""data_analytics 插件本地 SQLite 表与 DAO。
注意:本插件用的 ``DeclarativeBase`` 跟核心 PG 完全独立,避免元数据空间串场。
所有数据落到 ``data/plugin/data_analytics/_data/data_analytics.db``。
"""
from __future__ import annotations

from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import DateTime, String, Text, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
class Base(DeclarativeBase):
    """Declarative base private to the data_analytics plugin.

    Keeps the plugin's table metadata in its own namespace, isolated from the
    core PostgreSQL models.
    """
class S3Credential(Base):
    """ORM model for one stored S3 credential (table ``s3_credential``).

    The model itself performs no encryption; ``CredentialDAO`` in this module
    encrypts ``access_key`` and ``secret_key`` before writing them here.
    """

    __tablename__ = "s3_credential"

    cred_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # User the credential belongs to; indexed because list/delete filter on it.
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    display_name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional endpoint override (nullable).
    endpoint_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    region: Mapped[str] = mapped_column(String(50), default="us-east-1")
    access_key: Mapped[str] = mapped_column(String(255), nullable=False)
    secret_key: Mapped[str] = mapped_column(String(255), nullable=False)

    # ``datetime.utcnow`` is deprecated since Python 3.12.  The callables below
    # yield the same naive UTC value, so stored timestamps keep their format.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
class AnalysisJob(Base):
    """ORM model for one analysis job record (table ``analysis_job``)."""

    __tablename__ = "analysis_job"

    job_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # User the job belongs to; indexed because listings filter on it.
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    # Optional credential id; no foreign-key constraint is declared.
    cred_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    description: Mapped[str] = mapped_column(Text, nullable=False)
    # New rows start as "pending"; the set of other values is not defined here.
    status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
    # NOTE(review): presumably the id of a task in an external system - confirm.
    org_task_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    result: Mapped[Optional[str]] = mapped_column(Text, nullable=True)

    # ``datetime.utcnow`` is deprecated since Python 3.12.  The callables below
    # yield the same naive UTC value, so stored timestamps keep their format.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
        index=True,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
class CredentialDAO:
    """DAO for S3 credentials: secrets are encrypted on write, decrypted on read."""

    # Names of the credential fields that hold secrets.
    SENSITIVE_KEYS = ("access_key", "secret_key")

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Keep the session factory; each method opens its own session from it."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: S3Credential, *, include_secrets: bool) -> dict:
        """Convert a credential row into a plain dict.

        Args:
            row: The credential row to convert.
            include_secrets: When true, the whole dict is passed through
                ``decrypt_dict_secrets`` and returned.  When false, the
                access key is masked and ``secret_key`` is removed.

        Returns:
            The dict form of ``row``; timestamps are ISO-8601 strings or
            ``None`` when unset.
        """
        created = row.created_at
        updated = row.updated_at
        payload = {
            "cred_id": row.cred_id,
            "user_id": row.user_id,
            "display_name": row.display_name,
            "endpoint_url": row.endpoint_url,
            "region": row.region,
            "access_key": row.access_key,
            "secret_key": row.secret_key,
            "created_at": created.isoformat() if created else None,
            "updated_at": updated.isoformat() if updated else None,
        }

        if include_secrets:
            # Meant for internal tool use: plaintext is handed back for boto3.
            return decrypt_dict_secrets(payload)

        # Decrypt only the access key so a masked preview can be shown.
        decrypted = decrypt_dict_secrets({"access_key": payload["access_key"]})
        plain_key = decrypted.get("access_key", "")
        if len(plain_key) > 6:
            payload["access_key"] = plain_key[:4] + "***" + plain_key[-2:]
        else:
            # Too short to reveal any character.
            payload["access_key"] = "***"
        payload.pop("secret_key", None)
        return payload

    async def list_by_user(self, user_id: str) -> List[dict]:
        """Return every credential of ``user_id`` with secrets masked."""
        query = select(S3Credential).where(S3Credential.user_id == user_id)
        async with self._sm() as session:
            found = (await session.execute(query)).scalars().all()
            return [
                self._row_to_dict(cred, include_secrets=False) for cred in found
            ]

    async def get(self, cred_id: str, *, include_secrets: bool = False) -> Optional[dict]:
        """Return the credential ``cred_id`` as a dict, or ``None`` if absent.

        The lookup uses ``cred_id`` only.  ``include_secrets`` is forwarded to
        ``_row_to_dict``.
        """
        query = select(S3Credential).where(S3Credential.cred_id == cred_id)
        async with self._sm() as session:
            record = (await session.execute(query)).scalar_one_or_none()
            if record is not None:
                return self._row_to_dict(record, include_secrets=include_secrets)
            return None

    async def upsert(
        self,
        cred_id: str,
        user_id: str,
        display_name: str,
        access_key: str,
        secret_key: str,
        endpoint_url: Optional[str] = None,
        region: str = "us-east-1",
    ) -> dict:
        """Create the credential ``cred_id`` or overwrite its fields.

        Both keys are encrypted before they reach the database.  On update the
        stored ``user_id`` is left as it is.

        Returns:
            The saved credential as a dict with secrets masked.
        """
        sealed = encrypt_dict_secrets(
            {"access_key": access_key, "secret_key": secret_key}
        )
        lookup = select(S3Credential).where(S3Credential.cred_id == cred_id)
        async with self._sm() as session:
            record = (await session.execute(lookup)).scalar_one_or_none()
            if record is None:
                record = S3Credential(
                    cred_id=cred_id,
                    user_id=user_id,
                    display_name=display_name,
                    endpoint_url=endpoint_url,
                    region=region,
                    access_key=sealed["access_key"],
                    secret_key=sealed["secret_key"],
                )
            else:
                # NOTE(review): the row is matched by cred_id alone and its
                # user_id is not compared with the argument, whereas delete()
                # filters on both - confirm callers check ownership.
                record.display_name = display_name
                record.endpoint_url = endpoint_url
                record.region = region
                record.access_key = sealed["access_key"]
                record.secret_key = sealed["secret_key"]
            session.add(record)
            await session.commit()
            # Reload so generated timestamps are present on the instance.
            await session.refresh(record)
            return self._row_to_dict(record, include_secrets=False)

    async def delete(self, cred_id: str, user_id: str) -> bool:
        """Delete the credential if it exists and belongs to ``user_id``.

        Returns:
            ``True`` when a row was deleted, ``False`` when nothing matched.
        """
        owned = select(S3Credential).where(
            S3Credential.cred_id == cred_id, S3Credential.user_id == user_id
        )
        async with self._sm() as session:
            target = (await session.execute(owned)).scalar_one_or_none()
            removed = target is not None
            if removed:
                await session.delete(target)
                await session.commit()
            return removed
class JobDAO:
    """DAO for analysis job records stored in the ``analysis_job`` table.

    Each method opens its own session from the injected session factory and
    returns plain dicts instead of ORM rows.
    """

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Keep the session factory used to open one session per call."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: AnalysisJob) -> dict:
        """Convert a job row into a plain dict.

        Timestamps are rendered with ``isoformat()``; an unset timestamp
        becomes ``None``.
        """
        created = row.created_at
        updated = row.updated_at
        return {
            "job_id": row.job_id,
            "user_id": row.user_id,
            "cred_id": row.cred_id,
            "description": row.description,
            "status": row.status,
            "org_task_id": row.org_task_id,
            "result": row.result,
            "created_at": created.isoformat() if created else None,
            "updated_at": updated.isoformat() if updated else None,
        }

    async def create(
        self,
        job_id: str,
        user_id: str,
        description: str,
        cred_id: Optional[str] = None,
    ) -> dict:
        """Insert a new job row and return it as a dict.

        ``status`` and the timestamps are not passed here, so they come from
        the column defaults declared on ``AnalysisJob``.
        """
        async with self._sm() as session:
            job = AnalysisJob(
                job_id=job_id,
                user_id=user_id,
                description=description,
                cred_id=cred_id,
            )
            session.add(job)
            await session.commit()
            # Reload so default-populated columns are present on the instance.
            await session.refresh(job)
            return self._row_to_dict(job)

    async def update(
        self,
        job_id: str,
        *,
        status: Optional[str] = None,
        result: Optional[str] = None,
        org_task_id: Optional[str] = None,
    ) -> Optional[dict]:
        """Update selected fields of a job.

        A keyword left as ``None`` means "keep the stored value", so a field
        cannot be reset to ``None`` through this method.

        Returns:
            The updated job as a dict, or ``None`` when ``job_id`` is unknown.
        """
        async with self._sm() as session:
            query = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
            job = (await session.execute(query)).scalar_one_or_none()
            if job is None:
                return None
            if status is not None:
                job.status = status
            if result is not None:
                job.result = result
            if org_task_id is not None:
                job.org_task_id = org_task_id
            # The row was loaded through this session, so it is already
            # tracked; no extra add() is required before committing.
            await session.commit()
            await session.refresh(job)
            return self._row_to_dict(job)

    async def list_by_user(self, user_id: str, limit: int = 50) -> List[dict]:
        """Return up to ``limit`` jobs of ``user_id``, newest first."""
        async with self._sm() as session:
            query = (
                select(AnalysisJob)
                .where(AnalysisJob.user_id == user_id)
                .order_by(AnalysisJob.created_at.desc())
                .limit(limit)
            )
            jobs = (await session.execute(query)).scalars().all()
            return [self._row_to_dict(job) for job in jobs]

    async def get(self, job_id: str) -> Optional[dict]:
        """Return the job ``job_id`` as a dict, or ``None`` if it does not exist."""
        async with self._sm() as session:
            query = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
            job = (await session.execute(query)).scalar_one_or_none()
            # Explicit None check, matching update() and CredentialDAO.get().
            if job is None:
                return None
            return self._row_to_dict(job)