# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import asyncio
from typing import Optional

from kilostar.utils.standalone_proxy import actor_class
from sqlalchemy import text
from sqlalchemy.engine import URL
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession
from sqlalchemy.orm import sessionmaker
from kilostar.core.postgres_database.model.base import BaseDataModel

# Import every ORM model class explicitly before ``create_all`` runs so that
# each one is registered on the shared metadata. These names are intentionally
# unused in this module.
from kilostar.core.postgres_database.model.provider import ProviderModel
from kilostar.core.postgres_database.model.user import User
from kilostar.core.postgres_database.model.individual import (
    BaseIndividualModel,
    SpecialistIndividualModel,
    OrdinaryIndividualModel,
    SpecialIndividualModel,
)
from kilostar.core.postgres_database.model.workflow import (
    Workflow,
    WorkflowContextModel,
)
from kilostar.core.postgres_database.model.chat_history import (
    ChatHistoryRegister,
    ChatHistoryMessage,
)
from kilostar.core.postgres_database.model.system_node import SystemNodeConfigModel
from kilostar.core.postgres_database.model.mcp_server import MCPServerModel
from kilostar.core.postgres_database.model.tool_config import ToolConfigModel
from kilostar.core.postgres_database.model.custom_toolset import CustomToolsetModel
from kilostar.core.postgres_database.model.system_event_log import SystemEventLog

from .module.individual import IndividualDatabase
from .module.user import AuthDatabase
from .module.provider import ProviderDatabase
from .module.system_node import SystemNodeDatabase
from .module.workflow import WorkflowDatabase
from .module.chat_history import ChatHistoryDatabase
from .module.mcp_server import MCPServerDatabase
from .module.tool_config import ToolConfigDatabase
from .module.custom_toolset import CustomToolsetDatabase
from .module.system_event_log import SystemEventLogDatabase


@actor_class
class PostgresDatabase:
    """Unified database facade, wrapped by ``@actor_class``.

    Composes ten sub-databases that share one async session factory: Auth,
    Provider, Individual, SystemNode, Workflow, ChatHistory, MCPServer,
    ToolConfig, CustomToolset and SystemEventLog.

    Every public method except ``init_db`` first awaits ``ready_event``, so
    business requests are held back until ``init_db`` has completed
    successfully. Within this class ``ready_event`` is only set by a
    successful ``init_db``; if initialisation never succeeds, callers keep
    waiting.
    """

    def __init__(self):
        """Build the async engine, the session factory and the sub-databases.

        Connection settings are read from the ``POSTGRES_USER``,
        ``POSTGRES_PASSWORD``, ``POSTGRES_HOST``, ``POSTGRES_PORT`` and
        ``POSTGRES_DB`` environment variables.

        Raises:
            ValueError: if ``POSTGRES_PORT`` is set but is not an integer.
        """
        env = os.environ
        port = env.get("POSTGRES_PORT")

        # Build the URL from components rather than by string interpolation:
        # URL.create takes the values unescaped, so credentials containing
        # characters such as "@" or "%" are passed through correctly instead
        # of being mis-parsed as URL syntax.
        database_url = URL.create(
            "postgresql+asyncpg",
            username=env.get("POSTGRES_USER"),
            password=env.get("POSTGRES_PASSWORD"),
            host=env.get("POSTGRES_HOST"),
            port=int(port) if port else None,
            database=env.get("POSTGRES_DB"),
        )

        # echo=True makes the engine log every SQL statement it emits.
        self.async_engine = create_async_engine(database_url, echo=True)
        self.async_session_maker = sessionmaker(
            self.async_engine, class_=AsyncSession, expire_on_commit=False
        )

        # All sub-databases share the same session factory.
        self._auth_database = AuthDatabase(self.async_session_maker)
        self._provider_database = ProviderDatabase(self.async_session_maker)
        self._individual_database = IndividualDatabase(self.async_session_maker)
        self._system_node_database = SystemNodeDatabase(self.async_session_maker)
        self._workflow_database = WorkflowDatabase(self.async_session_maker)
        self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
        self._mcp_server_database = MCPServerDatabase(self.async_session_maker)
        self._tool_config_database = ToolConfigDatabase(self.async_session_maker)
        self._custom_toolset_database = CustomToolsetDatabase(self.async_session_maker)
        self._system_event_log_database = SystemEventLogDatabase(
            self.async_session_maker
        )

        # Gate awaited by every request method; set by init_db on success.
        self.ready_event = asyncio.Event()

    async def init_db(self) -> None:
        """Run ``create_all`` on the ORM metadata, then set ``ready_event``.

        On failure the error is printed and re-raised, and ``ready_event`` is
        left unset.
        """
        try:
            async with self.async_engine.begin() as conn:
                await conn.run_sync(BaseDataModel.metadata.create_all)
            print("✅ 数据库表创建/验证完成")
            self.ready_event.set()
        except Exception as e:
            print(f"❌ 数据库初始化失败: {e}")
            raise

    async def ping(self) -> bool:
        """Lightweight liveness probe: wait until ready, then run ``SELECT 1``.

        Returns:
            ``True`` when the statement executes; errors propagate.
        """
        await self.ready_event.wait()
        async with self.async_engine.connect() as conn:
            await conn.execute(text("SELECT 1"))
        return True

    # Auth Database Methods
    async def add_user(self, user_name: str, hashed_password: str):
        """Create a user; delegates to the auth sub-database."""
        await self.ready_event.wait()
        return await self._auth_database.add_user(user_name, hashed_password)

    async def change_password(self, user_name, old_password, new_password):
        """Change a user's password; delegates to the auth sub-database.

        The original notes say the old password is verified first; that
        check is not visible in this file.
        """
        await self.ready_event.wait()
        return await self._auth_database.change_password(
            user_name, old_password, new_password
        )

    async def delete_user(self, user_name: str):
        """Delete a user by user name."""
        await self.ready_event.wait()
        return await self._auth_database.delete_user(user_name)

    async def delete_user_by_id(self, user_id: str):
        """Delete a user by user ID."""
        await self.ready_event.wait()
        return await self._auth_database.delete_user_by_id(user_id)

    async def login_user(self, user_name: str):
        """Look up a user record by name.

        Intended for upper layers that verify the password and issue a token.
        """
        await self.ready_event.wait()
        return await self._auth_database.login_user(user_name)

    async def get_all_users(self):
        """Return all users."""
        await self.ready_event.wait()
        return await self._auth_database.get_all_users()

    async def get_user_authority(self, user_id: str):
        """Read the authority/role value of the given user."""
        await self.ready_event.wait()
        return await self._auth_database.get_user_authority(user_id)

    async def change_user_authority(self, user_id: str, new_authority):
        """Update the authority/role value of the given user."""
        await self.ready_event.wait()
        return await self._auth_database.change_user_authority(user_id, new_authority)

    # Provider Database Methods
    async def get_provider(self):
        """Return all registered model providers."""
        await self.ready_event.wait()
        return await self._provider_database.get_provider()

    async def add_provider_db(self, **kwargs):
        """Add a model provider record; ``kwargs`` are forwarded unchanged."""
        await self.ready_event.wait()
        return await self._provider_database.add_provider(**kwargs)

    async def delete_provider_db(self, provider_id: str):
        """Delete the model provider record with the given ID."""
        await self.ready_event.wait()
        return await self._provider_database.delete_provider(provider_id)

    async def update_provider_db(self, provider_id: str, **kwargs):
        """Update fields of a provider; ``kwargs`` are forwarded unchanged."""
        await self.ready_event.wait()
        return await self._provider_database.update_provider(provider_id, **kwargs)

    # System Node Database Methods
    async def upsert_system_node_config(
        self,
        node_name: str,
        provider_title: str,
        model_id: str,
        tools: Optional[list[str]] = None,
    ):
        """Insert or update the model configuration of a system node.

        The original notes give control / consciousness / regulatory as
        example node names.
        """
        await self.ready_event.wait()
        return await self._system_node_database.upsert_system_node_config(
            node_name, provider_title, model_id, tools
        )

    async def get_all_system_node_configs(self):
        """Return the model configuration of every system node."""
        await self.ready_event.wait()
        return await self._system_node_database.get_all_system_node_configs()

    # Individual Database Methods
    async def add_worker_individual(self, **kwargs):
        """Register a new worker individual; ``kwargs`` are forwarded unchanged."""
        await self.ready_event.wait()
        return await self._individual_database.add_worker_individual(**kwargs)

    async def get_worker_individual(self, agent_id: str):
        """Read a single worker individual by ``agent_id``."""
        await self.ready_event.wait()
        return await self._individual_database.get_worker_individual(agent_id)

    async def get_worker_individual_list(self, owner_id: str):
        """Read all worker individuals belonging to ``owner_id``."""
        await self.ready_event.wait()
        return await self._individual_database.get_worker_individual_list(owner_id)

    async def update_worker_individual(self, agent_id: str, **kwargs):
        """Update fields of a worker individual; ``kwargs`` are forwarded unchanged."""
        await self.ready_event.wait()
        return await self._individual_database.update_worker_individual(
            agent_id, **kwargs
        )

    async def delete_worker_individual(self, agent_id: str):
        """Delete the given worker individual."""
        await self.ready_event.wait()
        return await self._individual_database.delete_worker_individual(agent_id)

    async def get_all_worker_individual(self):
        """Return all worker individuals."""
        await self.ready_event.wait()
        return await self._individual_database.get_all_worker_individual()

    # Workflow Database Methods
    async def create_workflow(
        self, trace_id: str, user_id: str, title: str, command: str
    ):
        """Create a workflow record."""
        await self.ready_event.wait()
        return await self._workflow_database.create_workflow(
            trace_id, user_id, title, command
        )

    async def get_workflow(self, trace_id: str):
        """Read a workflow record by ``trace_id``."""
        await self.ready_event.wait()
        return await self._workflow_database.get_workflow(trace_id)

    async def update_workflow_status(self, trace_id: str, status: str):
        """Update the status of a workflow."""
        await self.ready_event.wait()
        return await self._workflow_database.update_workflow_status(trace_id, status)

    async def list_workflows(self, user_id: str):
        """Return all workflows belonging to ``user_id``."""
        await self.ready_event.wait()
        return await self._workflow_database.list_workflows(user_id)

    async def upsert_workflow_context(self, trace_id: str, **kwargs):
        """Insert or update the runtime context snapshot of a workflow."""
        await self.ready_event.wait()
        return await self._workflow_database.upsert_workflow_context(trace_id, **kwargs)

    async def get_workflow_context(self, trace_id: str):
        """Read the context snapshot of the given workflow."""
        await self.ready_event.wait()
        return await self._workflow_database.get_workflow_context(trace_id)

    # Workflow Graph State (pydantic_graph persistence)
    async def upsert_workflow_graph_state(self, trace_id: str, history: list):
        """Write the persisted graph history for a workflow.

        The original notes describe this as an overwrite-style write invoked
        at pydantic_graph node boundaries; the write itself lives in the
        workflow sub-database.
        """
        await self.ready_event.wait()
        return await self._workflow_database.upsert_workflow_graph_state(
            trace_id, history
        )

    async def get_workflow_graph_state(self, trace_id: str):
        """Read the persisted graph record (used for cross-process resume)."""
        await self.ready_event.wait()
        return await self._workflow_database.get_workflow_graph_state(trace_id)

    async def delete_workflow_graph_state(self, trace_id: str):
        """Explicitly remove the persisted graph record of a workflow.

        Intended for freeing space once a workflow has finished or failed.
        """
        await self.ready_event.wait()
        return await self._workflow_database.delete_workflow_graph_state(trace_id)

    # Chat History Database Methods
    async def create_chat_session(self, user_id: str, title: str = "新对话"):
        """Create a chat session for ``user_id`` with the given title."""
        await self.ready_event.wait()
        return await self._chat_history_database.create_chat_session(user_id, title)

    async def list_chat_sessions(self, user_id: str):
        """Return all chat sessions belonging to ``user_id``."""
        await self.ready_event.wait()
        return await self._chat_history_database.list_chat_sessions(user_id)

    async def add_chat_message(self, chat_id: str, message: str, message_owner: str):
        """Append a message to a chat session."""
        await self.ready_event.wait()
        return await self._chat_history_database.add_chat_message(
            chat_id, message, message_owner
        )

    async def list_chat_messages(self, chat_id: str):
        """Return all messages of a chat session."""
        await self.ready_event.wait()
        return await self._chat_history_database.list_chat_messages(chat_id)

    # MCP Server Database Methods
    async def upsert_mcp_server(self, server_id: str, config: dict):
        """Insert or update an MCP server configuration.

        The original notes say sensitive ``env`` fields are encrypted
        automatically; that logic is not visible in this file.
        """
        await self.ready_event.wait()
        return await self._mcp_server_database.upsert(server_id, config)

    async def get_mcp_server_db(self, server_id: str):
        """Read a single MCP server configuration.

        The original notes say ``env`` is decrypted automatically; that logic
        is not visible in this file.
        """
        await self.ready_event.wait()
        return await self._mcp_server_database.get(server_id)

    async def list_mcp_servers_db(self):
        """Read all MCP server configurations."""
        await self.ready_event.wait()
        return await self._mcp_server_database.list_all()

    async def delete_mcp_server_db(self, server_id: str):
        """Delete an MCP server configuration."""
        await self.ready_event.wait()
        return await self._mcp_server_database.delete(server_id)

    # Tool Config Database Methods
    async def upsert_tool_config(self, tool_name: str, config: dict):
        """Insert or update the runtime configuration of a tool.

        The original notes say sensitive fields are encrypted automatically;
        that logic is not visible in this file.
        """
        await self.ready_event.wait()
        return await self._tool_config_database.upsert(tool_name, config)

    async def get_tool_config_db(self, tool_name: str):
        """Read the runtime configuration of a tool.

        The original notes say sensitive fields are decrypted automatically;
        that logic is not visible in this file.
        """
        await self.ready_event.wait()
        return await self._tool_config_database.get(tool_name)

    async def list_tool_configs_db(self):
        """Read the runtime configuration of every tool."""
        await self.ready_event.wait()
        return await self._tool_config_database.list_all()

    async def delete_tool_config_db(self, tool_name: str):
        """Delete the runtime configuration of a tool."""
        await self.ready_event.wait()
        return await self._tool_config_database.delete(tool_name)

    # Custom Toolset Database Methods
    async def upsert_custom_toolset(
        self,
        toolset_id: str,
        name: str,
        tools: list,
        description: Optional[str] = None,
        owner_id: Optional[str] = None,
    ):
        """Insert or update a user-defined toolset."""
        await self.ready_event.wait()
        return await self._custom_toolset_database.upsert(
            toolset_id=toolset_id,
            name=name,
            tools=tools,
            description=description,
            owner_id=owner_id,
        )

    async def get_custom_toolset(self, toolset_id: str):
        """Read a custom toolset by ID."""
        await self.ready_event.wait()
        return await self._custom_toolset_database.get(toolset_id)

    async def list_custom_toolsets(self):
        """Read all custom toolsets."""
        await self.ready_event.wait()
        return await self._custom_toolset_database.list_all()

    async def delete_custom_toolset(self, toolset_id: str):
        """Delete a custom toolset."""
        await self.ready_event.wait()
        return await self._custom_toolset_database.delete(toolset_id)

    # System Event Log Methods
    async def insert_event_log(
        self,
        trace_id: str,
        event_type: str,
        level: str,
        message: str,
        node_name=None,
        metadata=None,
    ):
        """Record a system event.

        All arguments are forwarded by keyword to the event-log
        sub-database's ``insert_event``.
        """
        await self.ready_event.wait()
        return await self._system_event_log_database.insert_event(
            trace_id=trace_id,
            event_type=event_type,
            level=level,
            message=message,
            node_name=node_name,
            metadata=metadata,
        )

    async def query_event_logs(
        self, trace_id=None, event_type=None, level=None, limit=100, offset=0
    ):
        """Query system event logs.

        All arguments are forwarded by keyword to the event-log
        sub-database's ``query_events``. How ``None`` filters and
        ``limit`` / ``offset`` are interpreted is decided there.
        """
        await self.ready_event.wait()
        return await self._system_event_log_database.query_events(
            trace_id=trace_id,
            event_type=event_type,
            level=level,
            limit=limit,
            offset=offset,
        )