Refactor Workflow and Chat Architecture (#68)
* refactor: overhaul workflow and chat architecture - Separate Chat and Workflow API endpoints and database models - Use JSONB to store workflow execution context in Postgres - Convert workflow engine to use pydantic-ai execution graphs inside a Ray task - Update frontend React components to support standalone workflow creation - Remove obsolete and broken workflow runner tests Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> * refactor: overhaul workflow and chat architecture - Separate Chat and Workflow API endpoints and database models - Use JSONB to store workflow execution context in Postgres - Convert workflow engine to use pydantic-ai execution graphs inside a Ray task - Update frontend React components to support standalone workflow creation - Remove obsolete and broken workflow runner tests Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> * refactor: overhaul workflow and chat architecture - Separate Chat and Workflow API endpoints and database models - Use JSONB to store workflow execution context in Postgres - Convert workflow engine to use pydantic-ai execution graphs inside a Ray task - Update frontend React components to support standalone workflow creation - Move workflow_engine inside workflow package to keep core root clean - Remove obsolete and broken workflow runner tests Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com> --------- Co-authored-by: google-labs-jules[bot] <161369871+google-labs-jules[bot]@users.noreply.github.com> Co-authored-by: zhaoxi826 <198742034+zhaoxi826@users.noreply.github.com>
This commit is contained in:
@@ -15,5 +15,21 @@
|
||||
from kilostar.core.postgres_database.model.user import User
|
||||
from kilostar.core.postgres_database.model.provider import Provider
|
||||
from kilostar.core.postgres_database.model.individual import WorkerIndividual
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
)
|
||||
from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryRegister,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
|
||||
__all__ = ["User", "Provider", "WorkerIndividual"]
|
||||
__all__ = [
|
||||
"User",
|
||||
"Provider",
|
||||
"WorkerIndividual",
|
||||
"Workflow",
|
||||
"WorkflowContextModel",
|
||||
"ChatHistoryRegister",
|
||||
"ChatHistoryMessage",
|
||||
]
|
||||
|
||||
@@ -15,5 +15,6 @@
|
||||
from sqlalchemy.ext.asyncio import AsyncAttrs
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
|
||||
class BaseDataModel(DeclarativeBase, AsyncAttrs):
|
||||
pass
|
||||
pass
|
||||
|
||||
@@ -11,19 +11,55 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from .base import BaseDataModel
|
||||
from sqlalchemy.orm import Mapped
|
||||
|
||||
class ChatHistoryMessage(BaseDataModel):
|
||||
__tablename__ = "chat_history_massage"
|
||||
message_id: Mapped[str]
|
||||
message: Mapped[str]
|
||||
message_owner: Literal["user","regulatory_node"]
|
||||
|
||||
class ChatHistoryRegister(BaseDataModel):
|
||||
__tablename__ = "chat_history_register"
|
||||
chat_id: Mapped[str]
|
||||
user_id: Mapped[str]
|
||||
"""
|
||||
一个特定的聊天会话记录注册表。
|
||||
类似于多会话的一个 Thread/Session。
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_history_register"
|
||||
|
||||
chat_id: Mapped[str] = mapped_column(
|
||||
String(64), primary_key=True, description="聊天会话ID"
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(64), index=True, description="归属的用户ID"
|
||||
)
|
||||
title: Mapped[str] = mapped_column(
|
||||
String(255), default="新对话", description="对话标题"
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class ChatHistoryMessage(BaseDataModel):
|
||||
"""
|
||||
特定会话中的每一条具体消息记录。
|
||||
"""
|
||||
|
||||
__tablename__ = "chat_history_message"
|
||||
|
||||
message_id: Mapped[str] = mapped_column(
|
||||
String(64), primary_key=True, description="消息ID"
|
||||
)
|
||||
chat_id: Mapped[str] = mapped_column(
|
||||
String(64), index=True, description="所属会话ID"
|
||||
)
|
||||
message: Mapped[str] = mapped_column(String, description="消息体内容")
|
||||
message_owner: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
description="消息发送方,例如 'user', 'regulatory_node', 'consciousness_node' 等",
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
|
||||
@@ -44,10 +44,7 @@ class BaseIndividualModel(BaseDataModel):
|
||||
|
||||
agent_type: Mapped[str] = mapped_column(String(32))
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_on": "agent_type",
|
||||
"polymorphic_identity": "base"
|
||||
}
|
||||
__mapper_args__ = {"polymorphic_on": "agent_type", "polymorphic_identity": "base"}
|
||||
|
||||
|
||||
# ==========================================
|
||||
@@ -57,8 +54,7 @@ class SpecialistIndividualModel(BaseIndividualModel):
|
||||
__tablename__ = "specialist_individual"
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"),
|
||||
primary_key=True
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
bound_skill: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSONB)
|
||||
workspace: Mapped[Optional[List[str]]] = mapped_column(JSONB)
|
||||
@@ -70,12 +66,12 @@ class SpecialistIndividualModel(BaseIndividualModel):
|
||||
sub_ordinary_agents: Mapped[List["OrdinaryIndividualModel"]] = relationship(
|
||||
back_populates="manager",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="[OrdinaryIndividualModel.manager_id]"
|
||||
foreign_keys="[OrdinaryIndividualModel.manager_id]",
|
||||
)
|
||||
sub_special_agents: Mapped[List["SpecialIndividualModel"]] = relationship(
|
||||
back_populates="manager",
|
||||
cascade="all, delete-orphan",
|
||||
foreign_keys="[SpecialIndividualModel.manager_id]"
|
||||
foreign_keys="[SpecialIndividualModel.manager_id]",
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
@@ -90,8 +86,7 @@ class OrdinaryIndividualModel(BaseIndividualModel):
|
||||
__tablename__ = "ordinary_individual"
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"),
|
||||
primary_key=True
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
finetuned_from: Mapped[Optional[str]] = mapped_column(String(100))
|
||||
tools: Mapped[Optional[List[str]]] = mapped_column(
|
||||
@@ -106,7 +101,7 @@ class OrdinaryIndividualModel(BaseIndividualModel):
|
||||
# 逻辑关联:指向上级专家
|
||||
manager: Mapped[Optional["SpecialistIndividualModel"]] = relationship(
|
||||
back_populates="sub_ordinary_agents",
|
||||
foreign_keys=[manager_id] # 显式指定使用 manager_id 解析关系
|
||||
foreign_keys=[manager_id], # 显式指定使用 manager_id 解析关系
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
@@ -121,12 +116,10 @@ class SpecialIndividualModel(BaseIndividualModel):
|
||||
__tablename__ = "special_individual"
|
||||
|
||||
agent_id: Mapped[str] = mapped_column(
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"),
|
||||
primary_key=True
|
||||
ForeignKey("base_individual.agent_id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
modality_type: Mapped[ModalityType] = mapped_column(
|
||||
default=ModalityType.MULTIMODAL,
|
||||
server_default=text("'multimodal'")
|
||||
default=ModalityType.MULTIMODAL, server_default=text("'multimodal'")
|
||||
)
|
||||
multimodal_config: Mapped[Optional[Dict[str, Any]]] = mapped_column(JSONB)
|
||||
|
||||
@@ -137,10 +130,9 @@ class SpecialIndividualModel(BaseIndividualModel):
|
||||
|
||||
# 【修复2】:修正 back_populates 指向正确的变量名
|
||||
manager: Mapped[Optional["SpecialistIndividualModel"]] = relationship(
|
||||
back_populates="sub_special_agents",
|
||||
foreign_keys=[manager_id]
|
||||
back_populates="sub_special_agents", foreign_keys=[manager_id]
|
||||
)
|
||||
|
||||
__mapper_args__ = {
|
||||
"polymorphic_identity": "special",
|
||||
}
|
||||
}
|
||||
|
||||
@@ -24,6 +24,7 @@ class ProviderModel(BaseDataModel):
|
||||
Provider 物理模型。
|
||||
作为模型/服务提供商适配器,标准化不同供应商(OpenAI, Anthropic 等)的配置。
|
||||
"""
|
||||
|
||||
__tablename__ = "provider"
|
||||
provider_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
provider_title: Mapped[str] = mapped_column(String(100), index=True, nullable=False)
|
||||
@@ -31,14 +32,12 @@ class ProviderModel(BaseDataModel):
|
||||
provider_url: Mapped[Optional[str]] = mapped_column(Text)
|
||||
provider_apikey: Mapped[Optional[str]] = mapped_column(Text)
|
||||
provider_models: Mapped[List[str]] = mapped_column(
|
||||
JSONB,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb")
|
||||
JSONB, default=list, server_default=text("'[]'::jsonb")
|
||||
)
|
||||
provider_owner: Mapped[str] = mapped_column(String(64), index=True)
|
||||
is_active: Mapped[bool] = mapped_column(
|
||||
Boolean,
|
||||
default=True,
|
||||
server_default=text("true"),
|
||||
comment="该服务商节点是否在线/启用"
|
||||
comment="该服务商节点是否在线/启用",
|
||||
)
|
||||
|
||||
@@ -13,10 +13,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
from typing import List, Optional
|
||||
from sqlalchemy import String, Text
|
||||
from sqlalchemy.dialects.postgresql import JSONB # 针对 Postgres 优化,支持索引和高性能解析
|
||||
from sqlalchemy import String
|
||||
from sqlalchemy.dialects.postgresql import (
|
||||
JSONB,
|
||||
) # 针对 Postgres 优化,支持索引和高性能解析
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from .base import BaseDataModel
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class SystemNodeConfigModel(BaseDataModel):
|
||||
@@ -24,12 +26,11 @@ class SystemNodeConfigModel(BaseDataModel):
|
||||
SystemNodeConfig 物理模型。
|
||||
作为 kilostar 架构中的独立处理单元,负责存储 LLM 节点的执行策略与工具配置。
|
||||
"""
|
||||
|
||||
__tablename__ = "system_node_config"
|
||||
node_name: Mapped[str] = mapped_column(String(100), primary_key=True)
|
||||
provider_title: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
model_id: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||
tools: Mapped[Optional[List[str]]] = mapped_column(
|
||||
JSONB,
|
||||
default=list,
|
||||
comment="节点可调用的工具标识列表"
|
||||
JSONB, default=list, comment="节点可调用的工具标识列表"
|
||||
)
|
||||
|
||||
@@ -25,6 +25,7 @@ class UserAuthority(IntEnum):
|
||||
"""
|
||||
权限枚举类
|
||||
"""
|
||||
|
||||
SUPER_ADMINISTRATOR = 100
|
||||
ADMINISTRATOR = 50
|
||||
USER = 20
|
||||
@@ -36,12 +37,11 @@ class User(BaseDataModel):
|
||||
"""
|
||||
数据库user表模型
|
||||
"""
|
||||
|
||||
__tablename__ = "user"
|
||||
user_id: Mapped[str] = mapped_column(String(64), primary_key=True)
|
||||
user_name: Mapped[str] = mapped_column(String(100), index=True, nullable=False)
|
||||
hashed_password: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
user_authority: Mapped[UserAuthority] = mapped_column(
|
||||
Integer,
|
||||
default=UserAuthority.USER,
|
||||
server_default=text("20")
|
||||
Integer, default=UserAuthority.USER, server_default=text("20")
|
||||
)
|
||||
|
||||
@@ -12,12 +12,70 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
|
||||
from sqlmodel import SQLModel, Field
|
||||
from sqlalchemy import String, DateTime, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from .base import BaseDataModel
|
||||
|
||||
|
||||
class EventRecord(SQLModel, table=True):
|
||||
trace_id: str = Field(
|
||||
primary_key=True, description="The unique trace ID of the kilostarEvent"
|
||||
class Workflow(BaseDataModel):
|
||||
__tablename__ = "workflow"
|
||||
|
||||
trace_id: Mapped[str] = mapped_column(
|
||||
String(64), primary_key=True, description="工作流唯一ID (Trace ID)"
|
||||
)
|
||||
user_id: Mapped[str] = mapped_column(
|
||||
String(64), index=True, description="创建该工作流的用户ID"
|
||||
)
|
||||
title: Mapped[str] = mapped_column(String(255), description="工作流标题/简短描述")
|
||||
command: Mapped[str] = mapped_column(
|
||||
String, description="创建工作流的原始用户命令文本"
|
||||
)
|
||||
status: Mapped[str] = mapped_column(
|
||||
String(50),
|
||||
default="creating",
|
||||
description="工作流的总体状态 (例如: creating, running, pending, completed, failed等)",
|
||||
)
|
||||
version: Mapped[str] = mapped_column(
|
||||
String(50), default="v1.0", description="系统协议版本号"
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
|
||||
|
||||
class WorkflowContextModel(BaseDataModel):
|
||||
__tablename__ = "workflow_context"
|
||||
|
||||
trace_id: Mapped[str] = mapped_column(
|
||||
String(64), primary_key=True, description="对应的工作流 Trace ID"
|
||||
)
|
||||
workflow_status: Mapped[dict] = mapped_column(
|
||||
JSONB, default=dict, description="工作流状态变更历史"
|
||||
)
|
||||
blackboard: Mapped[dict] = mapped_column(
|
||||
JSONB, default=dict, description="大模型输出的存储区 (共享黑板)"
|
||||
)
|
||||
work_step_status: Mapped[dict] = mapped_column(
|
||||
JSONB, nullable=True, description="工作流运行步骤状态"
|
||||
)
|
||||
workflow_pointer: Mapped[int] = mapped_column(
|
||||
nullable=True, description="工作流指针,指向具体运行步骤位置"
|
||||
)
|
||||
workflow_log: Mapped[list] = mapped_column(
|
||||
JSONB, default=list, description="工作流运行日志"
|
||||
)
|
||||
work_link: Mapped[list] = mapped_column(
|
||||
JSONB,
|
||||
default=list,
|
||||
description="工作链(即 WorkflowStep 的定义列表,包含图结构和原子动作)",
|
||||
)
|
||||
created_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now()
|
||||
)
|
||||
updated_at: Mapped[str] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
event_data_json: str = Field(description="The JSON serialized kilostarEvent data")
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlalchemy import select, func
|
||||
from typing import List
|
||||
from kilostar.core.postgres_database.model.chat_history import (
|
||||
ChatHistoryRegister,
|
||||
ChatHistoryMessage,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
from ulid import ULID
|
||||
|
||||
|
||||
class ChatHistoryDatabase:
|
||||
def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
async def create_chat_session(
|
||||
self, user_id: str, title: str = "新对话"
|
||||
) -> ChatHistoryRegister:
|
||||
async with self.async_session_maker() as session:
|
||||
chat_id = str(ULID())
|
||||
chat = ChatHistoryRegister(chat_id=chat_id, user_id=user_id, title=title)
|
||||
session.add(chat)
|
||||
await session.commit()
|
||||
await session.refresh(chat)
|
||||
return chat
|
||||
|
||||
async def list_chat_sessions(self, user_id: str) -> List[ChatHistoryRegister]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = (
|
||||
select(ChatHistoryRegister)
|
||||
.where(ChatHistoryRegister.user_id == user_id)
|
||||
.order_by(ChatHistoryRegister.updated_at.desc())
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalars().all()
|
||||
|
||||
async def add_chat_message(
|
||||
self, chat_id: str, message: str, message_owner: str
|
||||
) -> ChatHistoryMessage:
|
||||
async with self.async_session_maker() as session:
|
||||
msg_id = str(ULID())
|
||||
msg = ChatHistoryMessage(
|
||||
message_id=msg_id,
|
||||
chat_id=chat_id,
|
||||
message=message,
|
||||
message_owner=message_owner,
|
||||
)
|
||||
session.add(msg)
|
||||
# Update the chat session's updated_at
|
||||
statement = select(ChatHistoryRegister).where(
|
||||
ChatHistoryRegister.chat_id == chat_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
chat = results.scalar_one_or_none()
|
||||
if chat:
|
||||
chat.updated_at = func.now()
|
||||
await session.commit()
|
||||
await session.refresh(msg)
|
||||
return msg
|
||||
|
||||
async def list_chat_messages(self, chat_id: str) -> List[ChatHistoryMessage]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = (
|
||||
select(ChatHistoryMessage)
|
||||
.where(ChatHistoryMessage.chat_id == chat_id)
|
||||
.order_by(ChatHistoryMessage.created_at.asc())
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalars().all()
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright 2026 zhaoxi826
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from sqlalchemy import select
|
||||
from typing import List, Optional
|
||||
from kilostar.core.postgres_database.model.workflow import (
|
||||
Workflow,
|
||||
WorkflowContextModel,
|
||||
)
|
||||
from sqlalchemy.ext.asyncio import async_sessionmaker, AsyncSession
|
||||
|
||||
|
||||
class WorkflowDatabase:
|
||||
def __init__(self, async_session_maker: async_sessionmaker[AsyncSession]):
|
||||
self.async_session_maker = async_session_maker
|
||||
|
||||
async def create_workflow(
|
||||
self, trace_id: str, user_id: str, title: str, command: str
|
||||
) -> Workflow:
|
||||
async with self.async_session_maker() as session:
|
||||
wf = Workflow(
|
||||
trace_id=trace_id,
|
||||
user_id=user_id,
|
||||
title=title,
|
||||
command=command,
|
||||
status="creating",
|
||||
)
|
||||
session.add(wf)
|
||||
await session.commit()
|
||||
await session.refresh(wf)
|
||||
return wf
|
||||
|
||||
async def get_workflow(self, trace_id: str) -> Optional[Workflow]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(Workflow).where(Workflow.trace_id == trace_id)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
|
||||
async def update_workflow_status(
|
||||
self, trace_id: str, status: str
|
||||
) -> Optional[Workflow]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(Workflow).where(Workflow.trace_id == trace_id)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record:
|
||||
record.status = status
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
async def list_workflows(self, user_id: str) -> List[Workflow]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(Workflow).where(Workflow.user_id == user_id)
|
||||
results = await session.execute(statement)
|
||||
return results.scalars().all()
|
||||
|
||||
async def upsert_workflow_context(
|
||||
self, trace_id: str, **kwargs
|
||||
) -> WorkflowContextModel:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowContextModel).where(
|
||||
WorkflowContextModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
record = results.scalar_one_or_none()
|
||||
if record:
|
||||
for key, value in kwargs.items():
|
||||
setattr(record, key, value)
|
||||
else:
|
||||
record = WorkflowContextModel(trace_id=trace_id, **kwargs)
|
||||
session.add(record)
|
||||
await session.commit()
|
||||
await session.refresh(record)
|
||||
return record
|
||||
|
||||
async def get_workflow_context(
|
||||
self, trace_id: str
|
||||
) -> Optional[WorkflowContextModel]:
|
||||
async with self.async_session_maker() as session:
|
||||
statement = select(WorkflowContextModel).where(
|
||||
WorkflowContextModel.trace_id == trace_id
|
||||
)
|
||||
results = await session.execute(statement)
|
||||
return results.scalar_one_or_none()
|
||||
@@ -25,6 +25,8 @@ from .module.event import EventDatabase
|
||||
from .module.user import AuthDatabase
|
||||
from .module.provider import ProviderDatabase
|
||||
from .module.system_node import SystemNodeDatabase
|
||||
from .module.workflow import WorkflowDatabase
|
||||
from .module.chat_history import ChatHistoryDatabase
|
||||
|
||||
|
||||
@ray.remote
|
||||
@@ -51,6 +53,8 @@ class PostgresDatabase:
|
||||
self._individual_database = IndividualDatabase(self.async_session_maker)
|
||||
self._event_database = EventDatabase(self.async_session_maker)
|
||||
self._system_node_database = SystemNodeDatabase(self.async_session_maker)
|
||||
self._workflow_database = WorkflowDatabase(self.async_session_maker)
|
||||
self._chat_history_database = ChatHistoryDatabase(self.async_session_maker)
|
||||
|
||||
self.ready_event = asyncio.Event()
|
||||
|
||||
@@ -254,3 +258,51 @@ class PostgresDatabase:
|
||||
async def delete_event(self, trace_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._event_database.delete_event(trace_id)
|
||||
|
||||
# Workflow Database Methods
|
||||
async def create_workflow(
|
||||
self, trace_id: str, user_id: str, title: str, command: str
|
||||
):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.create_workflow(
|
||||
trace_id, user_id, title, command
|
||||
)
|
||||
|
||||
async def get_workflow(self, trace_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow(trace_id)
|
||||
|
||||
async def update_workflow_status(self, trace_id: str, status: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.update_workflow_status(trace_id, status)
|
||||
|
||||
async def list_workflows(self, user_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.list_workflows(user_id)
|
||||
|
||||
async def upsert_workflow_context(self, trace_id: str, **kwargs):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.upsert_workflow_context(trace_id, **kwargs)
|
||||
|
||||
async def get_workflow_context(self, trace_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._workflow_database.get_workflow_context(trace_id)
|
||||
|
||||
# Chat History Database Methods
|
||||
async def create_chat_session(self, user_id: str, title: str = "新对话"):
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.create_chat_session(user_id, title)
|
||||
|
||||
async def list_chat_sessions(self, user_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.list_chat_sessions(user_id)
|
||||
|
||||
async def add_chat_message(self, chat_id: str, message: str, message_owner: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.add_chat_message(
|
||||
chat_id, message, message_owner
|
||||
)
|
||||
|
||||
async def list_chat_messages(self, chat_id: str):
|
||||
await self.ready_event.wait()
|
||||
return await self._chat_history_database.list_chat_messages(chat_id)
|
||||
|
||||
Reference in New Issue
Block a user