236 lines
8.8 KiB
Python
236 lines
8.8 KiB
Python
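"""Local SQLite tables and DAOs for the data_analytics plugin.

Note: this plugin uses its own ``DeclarativeBase``, completely independent of
the core PostgreSQL models, so the two metadata namespaces never get mixed up.

All data is stored in ``data/plugin/data_analytics/_data/data_analytics.db``.
"""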
|
|
|
|
from __future__ import annotations

from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import DateTime, String, Text, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
|
|
|
|
|
class Base(DeclarativeBase):
    """Declarative base private to the data_analytics plugin.

    Being a separate ``DeclarativeBase`` subclass, it owns its own metadata,
    so the plugin's SQLite tables stay isolated from the core PostgreSQL models.
    """
|
|
|
|
|
|
class S3Credential(Base):
    """A user-owned S3 credential record.

    When written through ``CredentialDAO``, ``access_key`` and ``secret_key``
    hold the output of ``encrypt_dict_secrets`` rather than plaintext.
    Timestamps are stored as naive UTC datetimes.
    """

    __tablename__ = "s3_credential"

    cred_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    display_name: Mapped[str] = mapped_column(String(100), nullable=False)
    endpoint_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    region: Mapped[str] = mapped_column(String(50), default="us-east-1")
    access_key: Mapped[str] = mapped_column(String(255), nullable=False)
    secret_key: Mapped[str] = mapped_column(String(255), nullable=False)
    # datetime.utcnow() is deprecated since Python 3.12; this produces the same
    # naive-UTC value without emitting a DeprecationWarning.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
|
|
|
|
|
|
class AnalysisJob(Base):
    """One analysis-job record belonging to a user.

    ``status`` starts as ``"pending"``; ``result`` and ``org_task_id`` are
    filled in later via ``JobDAO.update``. Timestamps are naive UTC datetimes.
    """

    __tablename__ = "analysis_job"

    job_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    # Optional reference to an S3Credential.cred_id (no FK constraint declared).
    cred_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    description: Mapped[str] = mapped_column(Text, nullable=False)
    status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
    org_task_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    result: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # datetime.utcnow() is deprecated since Python 3.12; this produces the same
    # naive-UTC value without emitting a DeprecationWarning.
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
        index=True,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
|
|
|
|
|
|
class CredentialDAO:
    """DAO for S3 credentials: secrets are encrypted on write, decrypted on read.

    ``access_key`` / ``secret_key`` are persisted as the output of
    ``encrypt_dict_secrets``. Every dict returned to callers goes through
    ``_row_to_dict``, which either masks the secrets (default) or decrypts them.
    """

    # NOTE(review): not referenced inside this class; the encrypt/decrypt
    # helpers are called without it. Kept for any external readers.
    SENSITIVE_KEYS = ("access_key", "secret_key")

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Store the async session factory used by every DAO call."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: S3Credential, *, include_secrets: bool) -> dict:
        """Serialise a credential row.

        Args:
            row: The ORM row; its key columns hold encrypted values.
            include_secrets: If True, return both keys decrypted (for internal
                tool use, e.g. handing them to boto3). If False, return a masked
                access key and omit the secret key entirely.
        """
        d = {
            "cred_id": row.cred_id,
            "user_id": row.user_id,
            "display_name": row.display_name,
            "endpoint_url": row.endpoint_url,
            "region": row.region,
            "access_key": row.access_key,
            "secret_key": row.secret_key,
            "created_at": row.created_at.isoformat() if row.created_at else None,
            "updated_at": row.updated_at.isoformat() if row.updated_at else None,
        }
        if include_secrets:
            return decrypt_dict_secrets(d)
        # Masked view: decrypt only the access key to show its first 4 / last 2
        # characters. `or ""` guards against a None result so len() can't fail.
        ak = decrypt_dict_secrets({"access_key": d["access_key"]}).get("access_key") or ""
        d["access_key"] = f"{ak[:4]}***{ak[-2:]}" if len(ak) > 6 else "***"
        d.pop("secret_key", None)
        return d

    async def list_by_user(self, user_id: str) -> List[dict]:
        """Return all credentials owned by ``user_id`` in masked form."""
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.user_id == user_id)
            rows = (await s.execute(stmt)).scalars().all()
            return [self._row_to_dict(r, include_secrets=False) for r in rows]

    async def get(
        self,
        cred_id: str,
        *,
        include_secrets: bool = False,
        user_id: Optional[str] = None,
    ) -> Optional[dict]:
        """Fetch a single credential by id, or None if it does not exist.

        Args:
            cred_id: Primary key of the credential.
            include_secrets: Return decrypted keys instead of the masked view.
            user_id: Optional owner filter. When given, a credential owned by a
                different user is reported as not found (same rule as ``delete``).
        """
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
            if user_id is not None:
                stmt = stmt.where(S3Credential.user_id == user_id)
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                return None
            return self._row_to_dict(row, include_secrets=include_secrets)

    async def upsert(
        self,
        cred_id: str,
        user_id: str,
        display_name: str,
        access_key: str,
        secret_key: str,
        endpoint_url: Optional[str] = None,
        region: str = "us-east-1",
    ) -> dict:
        """Create or update a credential owned by ``user_id``.

        The keys are encrypted before they reach the database. Returns the
        masked view of the stored row.

        Raises:
            PermissionError: ``cred_id`` already exists but belongs to a
                different user. Without this check, any user who knew the id
                could overwrite another user's credentials.
        """
        encrypted = encrypt_dict_secrets(
            {"access_key": access_key, "secret_key": secret_key}
        )
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                row = S3Credential(cred_id=cred_id, user_id=user_id)
                s.add(row)
            elif row.user_id != user_id:
                raise PermissionError(
                    f"credential {cred_id!r} is owned by another user"
                )
            # Same field assignments for the insert and the update path.
            row.display_name = display_name
            row.endpoint_url = endpoint_url
            row.region = region
            row.access_key = encrypted["access_key"]
            row.secret_key = encrypted["secret_key"]
            await s.commit()
            await s.refresh(row)
            return self._row_to_dict(row, include_secrets=False)

    async def delete(self, cred_id: str, user_id: str) -> bool:
        """Delete ``cred_id`` if owned by ``user_id``; return whether a row was removed."""
        async with self._sm() as s:
            stmt = select(S3Credential).where(
                S3Credential.cred_id == cred_id, S3Credential.user_id == user_id
            )
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                return False
            await s.delete(row)
            await s.commit()
            return True
|
|
|
|
|
|
class JobDAO:
    """Data-access object for analysis-job rows in the plugin's SQLite store."""

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Keep the async session factory; each call opens its own session."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: AnalysisJob) -> dict:
        """Convert a job row to a plain dict, with timestamps as ISO-8601 strings."""
        plain_fields = (
            "job_id",
            "user_id",
            "cred_id",
            "description",
            "status",
            "org_task_id",
            "result",
        )
        payload = {field: getattr(row, field) for field in plain_fields}
        for stamp_field in ("created_at", "updated_at"):
            stamp = getattr(row, stamp_field)
            payload[stamp_field] = stamp.isoformat() if stamp else None
        return payload

    async def create(
        self,
        job_id: str,
        user_id: str,
        description: str,
        cred_id: Optional[str] = None,
    ) -> dict:
        """Insert a new job (status falls back to the column default) and return it."""
        job = AnalysisJob(
            job_id=job_id,
            user_id=user_id,
            description=description,
            cred_id=cred_id,
        )
        async with self._sm() as session:
            session.add(job)
            await session.commit()
            await session.refresh(job)
            return self._row_to_dict(job)

    async def update(
        self,
        job_id: str,
        *,
        status: Optional[str] = None,
        result: Optional[str] = None,
        org_task_id: Optional[str] = None,
    ) -> Optional[dict]:
        """Patch the given fields of a job; None arguments are left untouched.

        Returns the updated job dict, or None when ``job_id`` does not exist.
        """
        patch = {"status": status, "result": result, "org_task_id": org_task_id}
        async with self._sm() as session:
            query = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
            job = (await session.execute(query)).scalar_one_or_none()
            if job is None:
                return None
            for field, value in patch.items():
                if value is not None:
                    setattr(job, field, value)
            session.add(job)
            await session.commit()
            await session.refresh(job)
            return self._row_to_dict(job)

    async def list_by_user(self, user_id: str, limit: int = 50) -> List[dict]:
        """Return up to ``limit`` jobs of ``user_id``, newest first."""
        query = (
            select(AnalysisJob)
            .where(AnalysisJob.user_id == user_id)
            .order_by(AnalysisJob.created_at.desc())
            .limit(limit)
        )
        async with self._sm() as session:
            jobs = (await session.execute(query)).scalars().all()
            return [self._row_to_dict(job) for job in jobs]

    async def get(self, job_id: str) -> Optional[dict]:
        """Return one job as a dict, or None if ``job_id`` is unknown."""
        async with self._sm() as session:
            query = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
            job = (await session.execute(query)).scalar_one_or_none()
            if job is None:
                return None
            return self._row_to_dict(job)
|