存档
This commit is contained in:
@@ -0,0 +1,235 @@
|
||||
"""data_analytics 插件本地 SQLite 表与 DAO。
|
||||
|
||||
注意:本插件用的 ``DeclarativeBase`` 跟核心 PG 完全独立,避免元数据空间串场。
|
||||
所有数据落到 ``data/plugin/data_analytics/_data/data_analytics.db``。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import DateTime, String, Text, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column
|
||||
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
"""data_analytics 插件私有的元数据空间,跟核心 PG 隔离。"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class S3Credential(Base):
|
||||
__tablename__ = "s3_credential"
|
||||
|
||||
cred_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
|
||||
display_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
endpoint_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
|
||||
region: Mapped[str] = mapped_column(String(50), default="us-east-1")
|
||||
access_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
secret_key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
|
||||
class AnalysisJob(Base):
|
||||
__tablename__ = "analysis_job"
|
||||
|
||||
job_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
|
||||
cred_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
|
||||
org_task_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
|
||||
result: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, nullable=False, index=True
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime, default=datetime.utcnow, onupdate=datetime.utcnow, nullable=False
|
||||
)
|
||||
|
||||
|
||||
class CredentialDAO:
|
||||
"""S3 凭证 DAO:写入时自动加密,读取时自动解密。"""
|
||||
|
||||
SENSITIVE_KEYS = ("access_key", "secret_key")
|
||||
|
||||
def __init__(self, sm: async_sessionmaker[AsyncSession]):
|
||||
self._sm = sm
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: S3Credential, *, include_secrets: bool) -> dict:
|
||||
d = {
|
||||
"cred_id": row.cred_id,
|
||||
"user_id": row.user_id,
|
||||
"display_name": row.display_name,
|
||||
"endpoint_url": row.endpoint_url,
|
||||
"region": row.region,
|
||||
"access_key": row.access_key,
|
||||
"secret_key": row.secret_key,
|
||||
"created_at": row.created_at.isoformat() if row.created_at else None,
|
||||
"updated_at": row.updated_at.isoformat() if row.updated_at else None,
|
||||
}
|
||||
if not include_secrets:
|
||||
ak = decrypt_dict_secrets({"access_key": d["access_key"]}).get("access_key", "")
|
||||
d["access_key"] = (ak[:4] + "***" + ak[-2:]) if len(ak) > 6 else "***"
|
||||
d.pop("secret_key", None)
|
||||
return d
|
||||
# include_secrets=True 用于工具内部,返回明文给 boto3
|
||||
return decrypt_dict_secrets(d)
|
||||
|
||||
async def list_by_user(self, user_id: str) -> List[dict]:
|
||||
async with self._sm() as s:
|
||||
stmt = select(S3Credential).where(S3Credential.user_id == user_id)
|
||||
rows = (await s.execute(stmt)).scalars().all()
|
||||
return [self._row_to_dict(r, include_secrets=False) for r in rows]
|
||||
|
||||
async def get(self, cred_id: str, *, include_secrets: bool = False) -> Optional[dict]:
|
||||
async with self._sm() as s:
|
||||
stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
|
||||
row = (await s.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
return self._row_to_dict(row, include_secrets=include_secrets)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
cred_id: str,
|
||||
user_id: str,
|
||||
display_name: str,
|
||||
access_key: str,
|
||||
secret_key: str,
|
||||
endpoint_url: Optional[str] = None,
|
||||
region: str = "us-east-1",
|
||||
) -> dict:
|
||||
encrypted = encrypt_dict_secrets(
|
||||
{"access_key": access_key, "secret_key": secret_key}
|
||||
)
|
||||
async with self._sm() as s:
|
||||
stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
|
||||
existing = (await s.execute(stmt)).scalar_one_or_none()
|
||||
if existing is not None:
|
||||
existing.display_name = display_name
|
||||
existing.endpoint_url = endpoint_url
|
||||
existing.region = region
|
||||
existing.access_key = encrypted["access_key"]
|
||||
existing.secret_key = encrypted["secret_key"]
|
||||
s.add(existing)
|
||||
await s.commit()
|
||||
await s.refresh(existing)
|
||||
return self._row_to_dict(existing, include_secrets=False)
|
||||
row = S3Credential(
|
||||
cred_id=cred_id,
|
||||
user_id=user_id,
|
||||
display_name=display_name,
|
||||
endpoint_url=endpoint_url,
|
||||
region=region,
|
||||
access_key=encrypted["access_key"],
|
||||
secret_key=encrypted["secret_key"],
|
||||
)
|
||||
s.add(row)
|
||||
await s.commit()
|
||||
await s.refresh(row)
|
||||
return self._row_to_dict(row, include_secrets=False)
|
||||
|
||||
async def delete(self, cred_id: str, user_id: str) -> bool:
|
||||
async with self._sm() as s:
|
||||
stmt = select(S3Credential).where(
|
||||
S3Credential.cred_id == cred_id, S3Credential.user_id == user_id
|
||||
)
|
||||
row = (await s.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
return False
|
||||
await s.delete(row)
|
||||
await s.commit()
|
||||
return True
|
||||
|
||||
|
||||
class JobDAO:
|
||||
"""分析任务记录 DAO。"""
|
||||
|
||||
def __init__(self, sm: async_sessionmaker[AsyncSession]):
|
||||
self._sm = sm
|
||||
|
||||
@staticmethod
|
||||
def _row_to_dict(row: AnalysisJob) -> dict:
|
||||
return {
|
||||
"job_id": row.job_id,
|
||||
"user_id": row.user_id,
|
||||
"cred_id": row.cred_id,
|
||||
"description": row.description,
|
||||
"status": row.status,
|
||||
"org_task_id": row.org_task_id,
|
||||
"result": row.result,
|
||||
"created_at": row.created_at.isoformat() if row.created_at else None,
|
||||
"updated_at": row.updated_at.isoformat() if row.updated_at else None,
|
||||
}
|
||||
|
||||
async def create(
|
||||
self,
|
||||
job_id: str,
|
||||
user_id: str,
|
||||
description: str,
|
||||
cred_id: Optional[str] = None,
|
||||
) -> dict:
|
||||
async with self._sm() as s:
|
||||
row = AnalysisJob(
|
||||
job_id=job_id,
|
||||
user_id=user_id,
|
||||
description=description,
|
||||
cred_id=cred_id,
|
||||
)
|
||||
s.add(row)
|
||||
await s.commit()
|
||||
await s.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def update(
|
||||
self,
|
||||
job_id: str,
|
||||
*,
|
||||
status: Optional[str] = None,
|
||||
result: Optional[str] = None,
|
||||
org_task_id: Optional[str] = None,
|
||||
) -> Optional[dict]:
|
||||
async with self._sm() as s:
|
||||
stmt = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
|
||||
row = (await s.execute(stmt)).scalar_one_or_none()
|
||||
if row is None:
|
||||
return None
|
||||
if status is not None:
|
||||
row.status = status
|
||||
if result is not None:
|
||||
row.result = result
|
||||
if org_task_id is not None:
|
||||
row.org_task_id = org_task_id
|
||||
s.add(row)
|
||||
await s.commit()
|
||||
await s.refresh(row)
|
||||
return self._row_to_dict(row)
|
||||
|
||||
async def list_by_user(self, user_id: str, limit: int = 50) -> List[dict]:
|
||||
async with self._sm() as s:
|
||||
stmt = (
|
||||
select(AnalysisJob)
|
||||
.where(AnalysisJob.user_id == user_id)
|
||||
.order_by(AnalysisJob.created_at.desc())
|
||||
.limit(limit)
|
||||
)
|
||||
rows = (await s.execute(stmt)).scalars().all()
|
||||
return [self._row_to_dict(r) for r in rows]
|
||||
|
||||
async def get(self, job_id: str) -> Optional[dict]:
|
||||
async with self._sm() as s:
|
||||
stmt = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
|
||||
row = (await s.execute(stmt)).scalar_one_or_none()
|
||||
return self._row_to_dict(row) if row else None
|
||||
Reference in New Issue
Block a user