feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
|
||||
class CustomToolsetDatabase:
|
||||
"""用户自定义工具组 DAO。``tools`` 字段是工具名列表,业务层负责保证只放非 system/非 mcp 的工具。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"toolset_id": row.toolset_id,
|
||||
"name": row.name,
|
||||
"description": row.description,
|
||||
"owner_id": row.owner_id,
|
||||
"tools": list(row.tools or []),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(
|
||||
self,
|
||||
toolset_id: str,
|
||||
name: str,
|
||||
tools: List[str],
|
||||
description: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.name = name
|
||||
row.description = description
|
||||
row.owner_id = owner_id
|
||||
row.tools = list(tools)
|
||||
else:
|
||||
row = CustomToolsetModel(
|
||||
toolset_id=toolset_id,
|
||||
name=name,
|
||||
description=description,
|
||||
owner_id=owner_id,
|
||||
tools=list(tools),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, toolset_id: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class MCPServerDatabase:
|
||||
"""MCP 服务器配置 DAO;写入前自动加密 ``env`` 中的敏感字段,读出后自动解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"server_id": row.server_id,
|
||||
"name": row.name,
|
||||
"transport": row.transport,
|
||||
"command": row.command,
|
||||
"args": row.args or [],
|
||||
"url": row.url,
|
||||
"tool_prefix": row.tool_prefix,
|
||||
"env": decrypt_dict_secrets(row.env or {}),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
env = encrypt_dict_secrets(config.get("env") or {})
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.name = config.get("name", row.name)
|
||||
row.transport = config.get("transport", row.transport)
|
||||
row.command = config.get("command")
|
||||
row.args = config.get("args") or []
|
||||
row.url = config.get("url")
|
||||
row.tool_prefix = config.get("tool_prefix")
|
||||
row.env = env
|
||||
else:
|
||||
row = MCPServerModel(
|
||||
server_id=server_id,
|
||||
name=config.get("name", server_id),
|
||||
transport=config.get("transport", "stdio"),
|
||||
command=config.get("command"),
|
||||
args=config.get("args") or [],
|
||||
url=config.get("url"),
|
||||
tool_prefix=config.get("tool_prefix"),
|
||||
env=env,
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, server_id: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -17,10 +17,37 @@ from typing import List
|
||||
from kilostar.core.postgres_database.model.provider import ProviderModel
|
||||
from sqlalchemy import select
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.utils.crypto import (
|
||||
CryptoError,
|
||||
decrypt_secret,
|
||||
encrypt_secret,
|
||||
is_encrypted,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("provider_dao")
|
||||
|
||||
|
||||
def _decrypt_apikey(value):
|
||||
if not value:
|
||||
return value
|
||||
if not is_encrypted(value):
|
||||
return value
|
||||
try:
|
||||
return decrypt_secret(value)
|
||||
except CryptoError as e:
|
||||
logger.error(f"Provider apikey 解密失败: {e}")
|
||||
return value
|
||||
|
||||
|
||||
def _encrypt_apikey(value):
|
||||
if not value or is_encrypted(value):
|
||||
return value
|
||||
return encrypt_secret(value)
|
||||
|
||||
|
||||
class ProviderDatabase:
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改;``provider_apikey`` 透明 Fernet 加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
@@ -37,11 +64,10 @@ class ProviderDatabase:
|
||||
provider_id=provider.provider_id,
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_apikey=_decrypt_apikey(provider.provider_apikey),
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner,
|
||||
provider_status=provider.provider_status,
|
||||
is_active=provider.is_active,
|
||||
)
|
||||
for provider in results
|
||||
@@ -50,7 +76,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel。"""
|
||||
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = ProviderModel(**kwargs)
|
||||
session.add(provider)
|
||||
@@ -67,7 +95,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def update_provider(self, provider_id: str, **kwargs) -> None:
|
||||
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象。"""
|
||||
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(ProviderModel, provider_id)
|
||||
if provider is not None:
|
||||
@@ -76,5 +106,7 @@ class ProviderDatabase:
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
await session.refresh(provider)
|
||||
# 解密返回,方便上游使用
|
||||
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
|
||||
return provider
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class ToolConfigDatabase:
|
||||
"""工具运行期配置 DAO;config 中的敏感字段(key/token/secret/password 系列)自动加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"tool_name": row.tool_name,
|
||||
"config": decrypt_dict_secrets(row.config or {}),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
encrypted = encrypt_dict_secrets(config or {})
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.config = encrypted
|
||||
else:
|
||||
row = ToolConfigModel(tool_name=tool_name, config=encrypted)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, tool_name: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -17,6 +17,7 @@ from typing import List, Optional
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
@@ -101,3 +102,58 @@ class WorkflowDatabase:
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
|
||||
|
||||
@database_exception
|
||||
async def upsert_workflow_graph_state(
|
||||
self, trace_id: str, history: list
|
||||
) -> WorkflowGraphStateModel:
|
||||
"""落 pydantic_graph FullStatePersistence.history 的 JSON 视图。
|
||||
|
||||
每个节点边界都会被引擎调一次,覆盖式写入;回滚到任一历史点是 graph
|
||||
引擎自身的能力,DB 这层只保留最新版本。
|
||||
"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record:
|
||||
record.history = history
|
||||
else:
|
||||
record = WorkflowGraphStateModel(
|
||||
trace_id=trace_id, history=history
|
||||
)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
@database_exception
|
||||
async def get_workflow_graph_state(
|
||||
self, trace_id: str
|
||||
) -> Optional[WorkflowGraphStateModel]:
|
||||
"""读取 graph 持久化 history;不存在返回 ``None``。"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
@database_exception
|
||||
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
|
||||
"""删除某个工作流的 graph 持久化记录(用于显式清理)。"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record is None:
|
||||
return False
|
||||
await session.delete(record)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
Reference in New Issue
Block a user