存档
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""data_analytics organization 实现。"""
|
||||
@@ -0,0 +1,235 @@
|
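"""Local SQLite tables and DAOs for the data_analytics plugin.

Note: the ``DeclarativeBase`` used by this plugin is completely independent of
the core PostgreSQL one, so the two metadata namespaces never get mixed up.

All data is stored in ``data/plugin/data_analytics/_data/data_analytics.db``.
"""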
|
||||
|
||||
from __future__ import annotations

from datetime import datetime, timezone
from typing import List, Optional

from sqlalchemy import DateTime, String, Text, select
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
    """Declarative base private to the data_analytics plugin.

    Gives the plugin its own metadata namespace, isolated from the core
    PostgreSQL models.
    """
|
||||
|
||||
|
||||
class S3Credential(Base):
    """ORM model for one row of the ``s3_credential`` table."""

    __tablename__ = "s3_credential"

    cred_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Owner of the credential; ``CredentialDAO`` filters on this column.
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    display_name: Mapped[str] = mapped_column(String(100), nullable=False)
    # Optional custom endpoint; NULL is allowed.
    endpoint_url: Mapped[Optional[str]] = mapped_column(String(255), nullable=True)
    region: Mapped[str] = mapped_column(String(50), default="us-east-1")
    # Both key columns hold the output of ``encrypt_dict_secrets`` as written
    # by ``CredentialDAO.upsert``.
    access_key: Mapped[str] = mapped_column(String(255), nullable=False)
    secret_key: Mapped[str] = mapped_column(String(255), nullable=False)
    # Timestamps are naive UTC. ``datetime.utcnow`` is deprecated since
    # Python 3.12, so the same naive value is derived from an aware "now".
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
|
||||
|
||||
|
||||
class AnalysisJob(Base):
    """ORM model for one row of the ``analysis_job`` table."""

    __tablename__ = "analysis_job"

    job_id: Mapped[str] = mapped_column(String(64), primary_key=True)
    # Owner of the job; ``JobDAO.list_by_user`` filters on this column.
    user_id: Mapped[str] = mapped_column(String(64), index=True, nullable=False)
    # Credential used by the job; NULL is allowed.
    cred_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    description: Mapped[str] = mapped_column(Text, nullable=False)
    # New rows start as "pending".
    status: Mapped[str] = mapped_column(String(20), default="pending", index=True)
    # Id of the organization task submitted for this job, once known.
    org_task_id: Mapped[Optional[str]] = mapped_column(String(64), nullable=True)
    result: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
    # Timestamps are naive UTC. ``datetime.utcnow`` is deprecated since
    # Python 3.12, so the same naive value is derived from an aware "now".
    created_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
        index=True,
    )
    updated_at: Mapped[datetime] = mapped_column(
        DateTime,
        default=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        onupdate=lambda: datetime.now(timezone.utc).replace(tzinfo=None),
        nullable=False,
    )
|
||||
|
||||
|
||||
class CredentialDAO:
    """DAO for S3 credentials: encrypts on write, decrypts on read."""

    SENSITIVE_KEYS = ("access_key", "secret_key")

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Store the session factory used to open one session per call."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: S3Credential, *, include_secrets: bool) -> dict:
        """Convert a credential row into a plain dict.

        Args:
            row: The ORM row to convert.
            include_secrets: When ``False`` the access key is masked and the
                secret key is removed. When ``True`` the dict is returned
                after being passed through ``decrypt_dict_secrets``.

        Returns:
            A dict of the row's columns with timestamps as ISO-8601 strings
            (or ``None`` when a timestamp is unset).
        """
        d = {
            "cred_id": row.cred_id,
            "user_id": row.user_id,
            "display_name": row.display_name,
            "endpoint_url": row.endpoint_url,
            "region": row.region,
            "access_key": row.access_key,
            "secret_key": row.secret_key,
            "created_at": row.created_at.isoformat() if row.created_at else None,
            "updated_at": row.updated_at.isoformat() if row.updated_at else None,
        }
        if include_secrets:
            # Plaintext variant, used internally by tools (e.g. for boto3).
            return decrypt_dict_secrets(d)

        decrypted = decrypt_dict_secrets({"access_key": d["access_key"]})
        # Guard against a missing/None value so masking never raises.
        ak = decrypted.get("access_key") or ""
        # Show only the first 4 and last 2 characters; short keys are fully hidden.
        d["access_key"] = (ak[:4] + "***" + ak[-2:]) if len(ak) > 6 else "***"
        d.pop("secret_key", None)
        return d

    async def list_by_user(self, user_id: str) -> List[dict]:
        """Return all credentials owned by ``user_id`` with secrets masked."""
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.user_id == user_id)
            rows = (await s.execute(stmt)).scalars().all()
            return [self._row_to_dict(r, include_secrets=False) for r in rows]

    async def get(self, cred_id: str, *, include_secrets: bool = False) -> Optional[dict]:
        """Return the credential with ``cred_id``, or ``None`` if it does not exist.

        This lookup is not scoped to a user; callers check ownership themselves.
        """
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                return None
            return self._row_to_dict(row, include_secrets=include_secrets)

    async def upsert(
        self,
        cred_id: str,
        user_id: str,
        display_name: str,
        access_key: str,
        secret_key: str,
        endpoint_url: Optional[str] = None,
        region: str = "us-east-1",
    ) -> dict:
        """Create the credential ``cred_id`` or update it in place.

        Args:
            cred_id: Primary key of the credential.
            user_id: Owner. Used for new rows and checked against existing rows.
            display_name: Human-readable label.
            access_key: Plaintext access key; encrypted before storage.
            secret_key: Plaintext secret key; encrypted before storage.
            endpoint_url: Optional custom endpoint.
            region: Region name, ``"us-east-1"`` by default.

        Returns:
            The stored credential as a dict with secrets masked.

        Raises:
            ValueError: If ``cred_id`` already exists but belongs to another user.
        """
        encrypted = encrypt_dict_secrets(
            {"access_key": access_key, "secret_key": secret_key}
        )
        async with self._sm() as s:
            stmt = select(S3Credential).where(S3Credential.cred_id == cred_id)
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                row = S3Credential(cred_id=cred_id, user_id=user_id)
                s.add(row)
            elif row.user_id != user_id:
                # Same ownership rule as ``delete``: never let one user
                # overwrite another user's stored keys.
                raise ValueError("凭证不属于当前用户")
            row.display_name = display_name
            row.endpoint_url = endpoint_url
            row.region = region
            row.access_key = encrypted["access_key"]
            row.secret_key = encrypted["secret_key"]
            await s.commit()
            # Reload so defaults / onupdate timestamps are reflected in the result.
            await s.refresh(row)
            return self._row_to_dict(row, include_secrets=False)

    async def delete(self, cred_id: str, user_id: str) -> bool:
        """Delete ``cred_id`` if it belongs to ``user_id``.

        Returns:
            ``True`` when a row was deleted, ``False`` when no matching row exists.
        """
        async with self._sm() as s:
            stmt = select(S3Credential).where(
                S3Credential.cred_id == cred_id, S3Credential.user_id == user_id
            )
            row = (await s.execute(stmt)).scalar_one_or_none()
            if row is None:
                return False
            await s.delete(row)
            await s.commit()
            return True
|
||||
|
||||
|
||||
class JobDAO:
    """Data-access object for analysis job records."""

    # Columns copied verbatim into the dict form, in output order.
    _PLAIN_FIELDS = (
        "job_id",
        "user_id",
        "cred_id",
        "description",
        "status",
        "org_task_id",
        "result",
    )

    def __init__(self, sm: async_sessionmaker[AsyncSession]):
        """Store the session factory used to open one session per call."""
        self._sm = sm

    @staticmethod
    def _row_to_dict(row: AnalysisJob) -> dict:
        """Convert a job row into a plain dict with ISO-8601 timestamps."""
        payload = {name: getattr(row, name) for name in JobDAO._PLAIN_FIELDS}
        for stamp in ("created_at", "updated_at"):
            value = getattr(row, stamp)
            payload[stamp] = value.isoformat() if value else None
        return payload

    @staticmethod
    async def _find(session: AsyncSession, job_id: str) -> Optional[AnalysisJob]:
        """Load the job with ``job_id`` in ``session``, or ``None`` if absent."""
        query = select(AnalysisJob).where(AnalysisJob.job_id == job_id)
        return (await session.scalars(query)).one_or_none()

    async def create(
        self,
        job_id: str,
        user_id: str,
        description: str,
        cred_id: Optional[str] = None,
    ) -> dict:
        """Insert a new job row and return it as a dict."""
        job = AnalysisJob(
            job_id=job_id,
            user_id=user_id,
            description=description,
            cred_id=cred_id,
        )
        async with self._sm() as session:
            session.add(job)
            await session.commit()
            # Reload so column defaults show up in the result.
            await session.refresh(job)
            return self._row_to_dict(job)

    async def update(
        self,
        job_id: str,
        *,
        status: Optional[str] = None,
        result: Optional[str] = None,
        org_task_id: Optional[str] = None,
    ) -> Optional[dict]:
        """Update the given fields of a job; ``None`` arguments are left untouched.

        Returns:
            The updated job as a dict, or ``None`` when ``job_id`` does not exist.
        """
        changes = {"status": status, "result": result, "org_task_id": org_task_id}
        async with self._sm() as session:
            job = await self._find(session, job_id)
            if job is None:
                return None
            for column, value in changes.items():
                if value is not None:
                    setattr(job, column, value)
            session.add(job)
            await session.commit()
            await session.refresh(job)
            return self._row_to_dict(job)

    async def list_by_user(self, user_id: str, limit: int = 50) -> List[dict]:
        """Return up to ``limit`` jobs of ``user_id``, newest first."""
        query = (
            select(AnalysisJob)
            .where(AnalysisJob.user_id == user_id)
            .order_by(AnalysisJob.created_at.desc())
            .limit(limit)
        )
        async with self._sm() as session:
            jobs = (await session.scalars(query)).all()
            return [self._row_to_dict(job) for job in jobs]

    async def get(self, job_id: str) -> Optional[dict]:
        """Return the job with ``job_id`` as a dict, or ``None`` if absent."""
        async with self._sm() as session:
            job = await self._find(session, job_id)
            if not job:
                return None
            return self._row_to_dict(job)
|
||||
@@ -0,0 +1,135 @@
|
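"""data_analytics organization: manages this plugin's SQLite metadata and injects the credential context.

Credentials reach the tools through the ``S3_CREDS_VAR`` ContextVar so that they
do not pollute the agent tool signature (the tools the agent sees take no
credential parameter, so the model cannot pass one by mistake).

The API layer calls the DAOs across actors through the ``cred_*`` / ``job_*``
proxy methods exposed by this class, which guarantees that in distributed mode
actors never share a SQLAlchemy session directly.
"""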
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import contextvars
|
||||
import uuid
|
||||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
from kilostar.plugin_runtime.base_organization import BaseOrganization
|
||||
from kilostar.plugin_runtime.event import OrgEvent
|
||||
|
||||
from .db import Base, CredentialDAO, JobDAO
|
||||
|
||||
# Plaintext S3 credentials of the task currently being handled; tools read
# them with ``S3_CREDS_VAR.get()``. ``None`` means no credential is active.
S3_CREDS_VAR: contextvars.ContextVar[Optional[Dict[str, Any]]] = (
    contextvars.ContextVar("data_analytics_s3_creds", default=None)
)
|
||||
|
||||
|
||||
class DataAnalyticsOrganization(BaseOrganization):
    """Data-analysis organization that works against S3.

    Owns the plugin-local database, exposes ``cred_*`` / ``job_*`` proxy
    methods for the API layer, and publishes the active S3 credential to
    tools through ``S3_CREDS_VAR`` while a task is being handled.
    """

    async def setup(self) -> None:
        """Initialise the local database and create the shared DAO instances."""
        await super().setup()
        await self.init_local_db([Base])
        # DAO instances shared by tools and by the API proxy methods below.
        self.cred_dao = CredentialDAO(self._session_maker)
        self.job_dao = JobDAO(self._session_maker)

    async def on_first_install(self) -> None:
        """Log a hint that S3 credentials still need to be configured."""
        self.logger.info(
            "data_analytics installed; configure S3 credentials in dashboard."
        )

    async def react(
        self,
        task_description: str,
        ctx: Dict[str, Any],
        emit: Callable[[OrgEvent], Any],
    ) -> Any:
        """Handle one task with the requested S3 credential made available to tools.

        If ``ctx`` carries a ``cred_id`` (and the DAO is initialised), the
        credential is loaded in plaintext, stored in ``S3_CREDS_VAR`` for the
        duration of the call, and its display name is written to
        ``ctx["s3_cred_display"]``. Otherwise the variable is set to ``None``.

        Raises:
            RuntimeError: If ``cred_id`` is given but no such credential exists.
        """
        cred: Optional[Dict[str, Any]] = None
        cred_id = ctx.get("cred_id")
        if cred_id and getattr(self, "cred_dao", None) is not None:
            cred = await self.cred_dao.get(cred_id, include_secrets=True)
            if cred is None:
                raise RuntimeError(f"S3 凭证 {cred_id} 不存在")
            ctx["s3_cred_display"] = cred.get("display_name")

        token = S3_CREDS_VAR.set(cred)
        try:
            return await super().react(task_description, ctx, emit)
        finally:
            # Do not leave plaintext secrets in the context once the task is done.
            S3_CREDS_VAR.reset(token)

    # --- Credential proxies (called by the API layer) ---------------------

    async def cred_list(self, user_id: str) -> List[dict]:
        """Return the credentials of ``user_id`` with secrets masked."""
        return await self.cred_dao.list_by_user(user_id)

    async def cred_create(
        self,
        user_id: str,
        display_name: str,
        access_key: str,
        secret_key: str,
        endpoint_url: Optional[str] = None,
        region: str = "us-east-1",
    ) -> dict:
        """Store a new credential under a freshly generated id and return it masked."""
        cred_id = uuid.uuid4().hex
        return await self.cred_dao.upsert(
            cred_id=cred_id,
            user_id=user_id,
            display_name=display_name,
            access_key=access_key,
            secret_key=secret_key,
            endpoint_url=endpoint_url,
            region=region,
        )

    async def cred_delete(self, cred_id: str, user_id: str) -> bool:
        """Delete ``cred_id`` if it belongs to ``user_id``; return whether a row was removed."""
        return await self.cred_dao.delete(cred_id, user_id)

    # --- Job proxies ------------------------------------------------------

    async def job_create(
        self, user_id: str, cred_id: str, description: str
    ) -> dict:
        """Record a job, submit it as an organization task and return their ids.

        Returns:
            ``{"job_id": ..., "task_id": ..., "status": "running"}``.

        Raises:
            ValueError: If the credential does not exist or is not owned by
                ``user_id``.
        """
        # Verify that the credential belongs to the caller.
        cred = await self.cred_dao.get(cred_id, include_secrets=False)
        if cred is None or cred.get("user_id") != user_id:
            raise ValueError("凭证不存在或不属于当前用户")

        job_id = uuid.uuid4().hex
        await self.job_dao.create(
            job_id=job_id,
            user_id=user_id,
            description=description,
            cred_id=cred_id,
        )
        # Submit the organization task; its id is stored on the job so the
        # frontend can follow the event stream.
        try:
            task_id = await self.submit(
                description, {"user_id": user_id, "cred_id": cred_id, "job_id": job_id}
            )
        except Exception as exc:
            # Without this the job row would stay "pending" forever.
            await self.job_dao.update(job_id, status="failed", result=str(exc))
            raise
        await self.job_dao.update(job_id, status="running", org_task_id=task_id)
        return {"job_id": job_id, "task_id": task_id, "status": "running"}

    async def job_list(self, user_id: str) -> List[dict]:
        """Return the most recent jobs of ``user_id``."""
        return await self.job_dao.list_by_user(user_id)

    async def job_get(self, job_id: str, user_id: str) -> Optional[dict]:
        """Return one job of ``user_id`` enriched with the live task state.

        Returns ``None`` when the job does not exist or belongs to another
        user. When the organization task has reached a terminal state that is
        not yet stored, the state and its result/error are written back to
        SQLite so they can still be queried after a restart.
        """
        row = await self.job_dao.get(job_id)
        if row is None or row.get("user_id") != user_id:
            return None

        org_task_id = row.get("org_task_id")
        if not org_task_id:
            return row

        # Attach the latest organization task state.
        ts = await self.status(org_task_id)
        if ts is None:
            return row

        task_state = ts.get("status")
        row["task_status"] = task_state
        row["task_result"] = ts.get("result")
        row["task_error"] = ts.get("error")

        if task_state in ("completed", "failed") and row.get("status") != task_state:
            payload = ts.get("result") if task_state == "completed" else ts.get("error")
            result_text = str(payload) if payload is not None else None
            await self.job_dao.update(job_id, status=task_state, result=result_text)
            row["status"] = task_state
            # Keep the returned dict in sync with what was just persisted
            # (the DAO leaves ``result`` unchanged when given ``None``).
            if result_text is not None:
                row["result"] = result_text
        return row
|
||||
Reference in New Issue
Block a user