feat(system):优化后端
1.新增后端测试 2.增加了后端的加密 3.增加了i18n(国际化)
This commit is contained in:
@@ -23,12 +23,16 @@ from kilostar.core.postgres_database.model.individual import (
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryRegister,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
# 兼容旧代码的别名
|
||||
Provider = ProviderModel
|
||||
@@ -49,9 +53,13 @@ __all__ = [
|
||||
"SpecialIndividualModel",
|
||||
"Workflow",
|
||||
"WorkflowContextModel",
|
||||
"WorkflowGraphStateModel",
|
||||
"ChatHistoryRegister",
|
||||
"ChatHistoryMessage",
|
||||
"SystemNodeConfigModel",
|
||||
"SystemNodeConfig",
|
||||
"MCPServerModel",
|
||||
"ToolConfigModel",
|
||||
"CustomToolsetModel",
|
||||
"AgentType",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import String, Text, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class CustomToolsetModel(BaseDataModel):
|
||||
"""用户自定义工具组:把若干个非 system / 非 mcp 的工具插件打包成一个 toolset。
|
||||
|
||||
``tools`` 字段保存工具名列表(即 ``plugin/tool_plugin/`` 下的目录名);
|
||||
GSM 启动时按列表把对应工具函数装进同一个 ``FunctionToolset``。
|
||||
"""
|
||||
|
||||
__tablename__ = "custom_toolset"
|
||||
|
||||
toolset_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
description: Mapped[Optional[str]] = mapped_column(Text)
|
||||
owner_id: Mapped[Optional[str]] = mapped_column(String(64), index=True)
|
||||
tools: Mapped[List[str]] = mapped_column(
|
||||
JSONB, default=list, comment="工具名列表,仅允许非 system/非 mcp 的工具"
|
||||
)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -0,0 +1,29 @@
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class MCPServerModel(BaseDataModel):
|
||||
"""MCP 服务器注册表,记录 stdio/sse/http 三种 transport 的连接配置。"""
|
||||
|
||||
__tablename__ = "mcp_server"
|
||||
|
||||
server_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
transport: Mapped[str] = mapped_column(String(16), nullable=False)
|
||||
command: Mapped[Optional[str]] = mapped_column(String(255))
|
||||
args: Mapped[list] = mapped_column(JSONB, default=list)
|
||||
url: Mapped[Optional[str]] = mapped_column(String(500))
|
||||
tool_prefix: Mapped[Optional[str]] = mapped_column(String(64))
|
||||
env: Mapped[dict] = mapped_column(JSONB, default=dict, comment="敏感字段已 Fernet 加密")
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -0,0 +1,22 @@
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class ToolConfigModel(BaseDataModel):
|
||||
"""工具运行期配置(如 Tavily API key);config 内的敏感字段已 Fernet 加密。"""
|
||||
|
||||
__tablename__ = "tool_config"
|
||||
|
||||
tool_name: Mapped[str] = mapped_column(String(100), primary_key=True)
|
||||
config: Mapped[dict] = mapped_column(JSONB, default=dict)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
@@ -79,3 +79,28 @@ class WorkflowContextModel(BaseDataModel):
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class WorkflowGraphStateModel(BaseDataModel):
|
||||
"""pydantic_graph 持久化 blob 的存储表。
|
||||
|
||||
与 ``workflow_context`` 解耦——后者面向"业务展示 / 用户可读",前者面向
|
||||
"graph 引擎自身的状态恢复"。一份 trace_id 一行,jsonb 直接存 history 全量。
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_graph_state"
|
||||
|
||||
trace_id: Mapped[str] = mapped_column(
|
||||
String(64), primary_key=True, comment="对应的工作流 Trace ID"
|
||||
)
|
||||
history: Mapped[list] = mapped_column(
|
||||
JSONB,
|
||||
default=list,
|
||||
comment="pydantic_graph FullStatePersistence.history 的 JSON 序列化",
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
@@ -0,0 +1,81 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
|
||||
class CustomToolsetDatabase:
|
||||
"""用户自定义工具组 DAO。``tools`` 字段是工具名列表,业务层负责保证只放非 system/非 mcp 的工具。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: CustomToolsetModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"toolset_id": row.toolset_id,
|
||||
"name": row.name,
|
||||
"description": row.description,
|
||||
"owner_id": row.owner_id,
|
||||
"tools": list(row.tools or []),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(
|
||||
self,
|
||||
toolset_id: str,
|
||||
name: str,
|
||||
tools: List[str],
|
||||
description: Optional[str] = None,
|
||||
owner_id: Optional[str] = None,
|
||||
) -> Dict[str, Any]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.name = name
|
||||
row.description = description
|
||||
row.owner_id = owner_id
|
||||
row.tools = list(tools)
|
||||
else:
|
||||
row = CustomToolsetModel(
|
||||
toolset_id=toolset_id,
|
||||
name=name,
|
||||
description=description,
|
||||
owner_id=owner_id,
|
||||
tools=list(tools),
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, toolset_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(CustomToolsetModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, toolset_id: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(CustomToolsetModel).where(
|
||||
CustomToolsetModel.toolset_id == toolset_id
|
||||
)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -0,0 +1,79 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class MCPServerDatabase:
|
||||
"""MCP 服务器配置 DAO;写入前自动加密 ``env`` 中的敏感字段,读出后自动解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: MCPServerModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"server_id": row.server_id,
|
||||
"name": row.name,
|
||||
"transport": row.transport,
|
||||
"command": row.command,
|
||||
"args": row.args or [],
|
||||
"url": row.url,
|
||||
"tool_prefix": row.tool_prefix,
|
||||
"env": decrypt_dict_secrets(row.env or {}),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(self, server_id: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
env = encrypt_dict_secrets(config.get("env") or {})
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.name = config.get("name", row.name)
|
||||
row.transport = config.get("transport", row.transport)
|
||||
row.command = config.get("command")
|
||||
row.args = config.get("args") or []
|
||||
row.url = config.get("url")
|
||||
row.tool_prefix = config.get("tool_prefix")
|
||||
row.env = env
|
||||
else:
|
||||
row = MCPServerModel(
|
||||
server_id=server_id,
|
||||
name=config.get("name", server_id),
|
||||
transport=config.get("transport", "stdio"),
|
||||
command=config.get("command"),
|
||||
args=config.get("args") or [],
|
||||
url=config.get("url"),
|
||||
tool_prefix=config.get("tool_prefix"),
|
||||
env=env,
|
||||
)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, server_id: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(MCPServerModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, server_id: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(MCPServerModel).where(MCPServerModel.server_id == server_id)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -17,10 +17,37 @@ from typing import List
|
||||
from kilostar.core.postgres_database.model.provider import ProviderModel
|
||||
from sqlalchemy import select
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.utils.crypto import (
|
||||
CryptoError,
|
||||
decrypt_secret,
|
||||
encrypt_secret,
|
||||
is_encrypted,
|
||||
)
|
||||
from kilostar.utils.logger import get_logger
|
||||
|
||||
logger = get_logger("provider_dao")
|
||||
|
||||
|
||||
def _decrypt_apikey(value):
|
||||
if not value:
|
||||
return value
|
||||
if not is_encrypted(value):
|
||||
return value
|
||||
try:
|
||||
return decrypt_secret(value)
|
||||
except CryptoError as e:
|
||||
logger.error(f"Provider apikey 解密失败: {e}")
|
||||
return value
|
||||
|
||||
|
||||
def _encrypt_apikey(value):
|
||||
if not value or is_encrypted(value):
|
||||
return value
|
||||
return encrypt_secret(value)
|
||||
|
||||
|
||||
class ProviderDatabase:
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改。"""
|
||||
"""Provider 表的 DAO:模型 Provider 的增删查改;``provider_apikey`` 透明 Fernet 加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
@@ -37,11 +64,10 @@ class ProviderDatabase:
|
||||
provider_id=provider.provider_id,
|
||||
provider_title=provider.provider_title,
|
||||
provider_url=provider.provider_url,
|
||||
provider_apikey=provider.provider_apikey,
|
||||
provider_apikey=_decrypt_apikey(provider.provider_apikey),
|
||||
provider_models=provider.provider_models,
|
||||
provider_type=provider.provider_type,
|
||||
provider_owner=provider.provider_owner,
|
||||
provider_status=provider.provider_status,
|
||||
is_active=provider.is_active,
|
||||
)
|
||||
for provider in results
|
||||
@@ -50,7 +76,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def add_provider(self, **kwargs) -> None:
|
||||
"""新建一条 Provider 记录;字段通过 kwargs 直接传给 ProviderModel。"""
|
||||
"""新建一条 Provider 记录;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = ProviderModel(**kwargs)
|
||||
session.add(provider)
|
||||
@@ -67,7 +95,9 @@ class ProviderDatabase:
|
||||
|
||||
@database_exception
|
||||
async def update_provider(self, provider_id: str, **kwargs) -> None:
|
||||
"""部分更新指定 Provider 的字段;不存在时返回 None,否则返回刷新后的对象。"""
|
||||
"""部分更新指定 Provider 的字段;``provider_apikey`` 写入前自动加密。"""
|
||||
if "provider_apikey" in kwargs:
|
||||
kwargs["provider_apikey"] = _encrypt_apikey(kwargs["provider_apikey"])
|
||||
async with self.async_session_maker() as session:
|
||||
provider = await session.get(ProviderModel, provider_id)
|
||||
if provider is not None:
|
||||
@@ -76,5 +106,7 @@ class ProviderDatabase:
|
||||
session.add(provider)
|
||||
await session.commit()
|
||||
await session.refresh(provider)
|
||||
# 解密返回,方便上游使用
|
||||
provider.provider_apikey = _decrypt_apikey(provider.provider_apikey)
|
||||
return provider
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,58 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from sqlalchemy import delete, select
|
||||
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.utils.crypto import decrypt_dict_secrets, encrypt_dict_secrets
|
||||
|
||||
|
||||
class ToolConfigDatabase:
|
||||
"""工具运行期配置 DAO;config 中的敏感字段(key/token/secret/password 系列)自动加解密。"""
|
||||
|
||||
def __init__(self, async_session_maker):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
@staticmethod
|
||||
def _to_dict(row: ToolConfigModel) -> Dict[str, Any]:
|
||||
return {
|
||||
"tool_name": row.tool_name,
|
||||
"config": decrypt_dict_secrets(row.config or {}),
|
||||
}
|
||||
|
||||
@database_exception
|
||||
async def upsert(self, tool_name: str, config: Dict[str, Any]) -> Dict[str, Any]:
|
||||
encrypted = encrypt_dict_secrets(config or {})
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
if row:
|
||||
row.config = encrypted
|
||||
else:
|
||||
row = ToolConfigModel(tool_name=tool_name, config=encrypted)
|
||||
session.add(row)
|
||||
await session.commit()
|
||||
await session.refresh(row)
|
||||
return self._to_dict(row)
|
||||
|
||||
@database_exception
|
||||
async def get(self, tool_name: str) -> Optional[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
row = (await session.execute(stmt)).scalar_one_or_none()
|
||||
return self._to_dict(row) if row else None
|
||||
|
||||
@database_exception
|
||||
async def list_all(self) -> List[Dict[str, Any]]:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = select(ToolConfigModel)
|
||||
rows = (await session.execute(stmt)).scalars().all()
|
||||
return [self._to_dict(r) for r in rows]
|
||||
|
||||
@database_exception
|
||||
async def delete(self, tool_name: str) -> bool:
|
||||
async with self.async_session_maker() as session:
|
||||
stmt = delete(ToolConfigModel).where(ToolConfigModel.tool_name == tool_name)
|
||||
result = await session.execute(stmt)
|
||||
await session.commit()
|
||||
return result.rowcount > 0
|
||||
@@ -17,6 +17,7 @@ from typing import List, Optional
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
WorkflowGraphStateModel,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
from kilostar.core.postgres_database.database_exception import database_exception
|
||||
@@ -101,3 +102,58 @@ class WorkflowDatabase:
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
# ─── pydantic_graph 持久化(resume 用)─────────────────────────────
|
||||
|
||||
@database_exception
|
||||
async def upsert_workflow_graph_state(
|
||||
self, trace_id: str, history: list
|
||||
) -> WorkflowGraphStateModel:
|
||||
"""落 pydantic_graph FullStatePersistence.history 的 JSON 视图。
|
||||
|
||||
每个节点边界都会被引擎调一次,覆盖式写入;回滚到任一历史点是 graph
|
||||
引擎自身的能力,DB 这层只保留最新版本。
|
||||
"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record:
|
||||
record.history = history
|
||||
else:
|
||||
record = WorkflowGraphStateModel(
|
||||
trace_id=trace_id, history=history
|
||||
)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
@database_exception
|
||||
async def get_workflow_graph_state(
|
||||
self, trace_id: str
|
||||
) -> Optional[WorkflowGraphStateModel]:
|
||||
"""读取 graph 持久化 history;不存在返回 ``None``。"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
@database_exception
|
||||
async def delete_workflow_graph_state(self, trace_id: str) -> bool:
|
||||
"""删除某个工作流的 graph 持久化记录(用于显式清理)。"""
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowGraphStateModel).where(
|
||||
WorkflowGraphStateModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record is None:
|
||||
return False
|
||||
await session.delete(record)
|
||||
await session.commit()
|
||||
return True
|
||||
|
||||
@@ -38,6 +38,9 @@ from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
|
||||
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
|
||||
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
|
||||
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
|
||||
|
||||
from .module.individual import IndividualDatabase
|
||||
from .module.user import AuthDatabase
|
||||
@@ -45,6 +48,9 @@ from .module.provider import ProviderDatabase
|
||||
from .module.system_node import SystemNodeDatabase
|
||||
from .module.workflow import WorkflowDatabase
|
||||
from .module.chat_history import ChatHistoryDatabase
|
||||
from .module.mcp_server import MCPServerDatabase
|
||||
from .module.tool_config import ToolConfigDatabase
|
||||
from .module.custom_toolset import CustomToolsetDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -76,6 +82,9 @@ class PostgresDatabase:
|
||||
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
|
||||
self._workflow_database = WorkflowDatabase(self.async_session_maker)
|
||||
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
|
||||
self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
|
||||
self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
|
||||
self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
|
||||
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
@@ -91,6 +100,15 @@ class PostgresDatabase:
|
||||
finally:
|
||||
self.ready_event.set()
|
||||
|
||||
async def ping(self) -> bool:
|
||||
"""轻量探活:等待 ready 后执行 ``SELECT 1``。"""
|
||||
from sqlalchemy import text
|
||||
|
||||
await self.ready_event.wait()
|
||||
async with self.async_engine.connect() as conn:
|
||||
await conn.execute(text("SELECT 1"))
|
||||
return True
|
||||
|
||||
# Auth Database Methods
|
||||
async def add_user(self, user_name: str, hashed_password: str):
|
||||
"""新建一名用户。"""
|
||||
@@ -242,6 +260,24 @@ class PostgresDatabase:
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow_context(trace_id)
|
||||
|
||||
# Workflow Graph State (pydantic_graph 持久化)
|
||||
async def upsert_workflow_graph_state(self, trace_id: str, history: list):
|
||||
"""覆盖式写入 graph 持久化 history(pydantic_graph 节点边界自动调用)。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.upsert_workflow_graph_state(
|
||||
trace_id, history
|
||||
)
|
||||
|
||||
async def get_workflow_graph_state(self, trace_id: str):
|
||||
"""读取 graph 持久化记录,用于跨进程 resume。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow_graph_state(trace_id)
|
||||
|
||||
async def delete_workflow_graph_state(self, trace_id: str):
|
||||
"""显式清理 graph 持久化记录(已完成/失败的 workflow 释放空间)。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.delete_workflow_graph_state(trace_id)
|
||||
|
||||
# Chat History Database Methods
|
||||
async def create_chat_session(self, user_id: str, title: str = "新对话"):
|
||||
"""新建一个聊天会话。"""
|
||||
@@ -264,3 +300,79 @@ class PostgresDatabase:
|
||||
"""返回某个聊天会话的全部消息。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.list_chat_messages(chat_id)
|
||||
|
||||
# MCP Server Database Methods
|
||||
async def upsert_mcp_server(self, server_id: str, config: dict):
|
||||
"""插入或更新一条 MCP 服务器配置;env 中敏感字段自动加密。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._mcp_server_database.upsert(server_id, config)
|
||||
|
||||
async def get_mcp_server_db(self, server_id: str):
|
||||
"""读取单条 MCP 服务器配置;env 自动解密。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._mcp_server_database.get(server_id)
|
||||
|
||||
async def list_mcp_servers_db(self):
|
||||
"""读取全部 MCP 服务器配置。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._mcp_server_database.list_all()
|
||||
|
||||
async def delete_mcp_server_db(self, server_id: str):
|
||||
"""删除某条 MCP 服务器配置。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._mcp_server_database.delete(server_id)
|
||||
|
||||
# Tool Config Database Methods
|
||||
async def upsert_tool_config(self, tool_name: str, config: dict):
|
||||
"""插入或更新某工具的运行期配置;敏感字段自动加密。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._tool_config_database.upsert(tool_name, config)
|
||||
|
||||
async def get_tool_config_db(self, tool_name: str):
|
||||
"""读取某工具的运行期配置;敏感字段自动解密。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._tool_config_database.get(tool_name)
|
||||
|
||||
async def list_tool_configs_db(self):
|
||||
"""读取全部工具的运行期配置。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._tool_config_database.list_all()
|
||||
|
||||
async def delete_tool_config_db(self, tool_name: str):
|
||||
"""删除某工具的运行期配置。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._tool_config_database.delete(tool_name)
|
||||
|
||||
# Custom Toolset Database Methods
|
||||
async def upsert_custom_toolset(
|
||||
self,
|
||||
toolset_id: str,
|
||||
name: str,
|
||||
tools: list,
|
||||
description: str = None,
|
||||
owner_id: str = None,
|
||||
):
|
||||
"""插入或更新一个用户自定义工具组。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._custom_toolset_database.upsert(
|
||||
toolset_id=toolset_id,
|
||||
name=name,
|
||||
tools=tools,
|
||||
description=description,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
async def get_custom_toolset(self, toolset_id: str):
|
||||
"""按 ID 读取一个自定义工具组。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._custom_toolset_database.get(toolset_id)
|
||||
|
||||
async def list_custom_toolsets(self):
|
||||
"""读取全部自定义工具组。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._custom_toolset_database.list_all()
|
||||
|
||||
async def delete_custom_toolset(self, toolset_id: str):
|
||||
"""删除一个自定义工具组。"""
|
||||
await self.ready_event.wait()
|
||||
return await self._custom_toolset_database.delete(toolset_id)
|
||||
|
||||
Reference in New Issue
Block a user