6d658b4f4d
- 工具系统从 kilostar/plugin/tool_plugin/ 迁移到 data/toolset/(manifest.json 声明式) - 新增 plugin_runtime 模块:BaseOrganization / GlobalPluginManager / loader / tool_bridge - 新增 org_task + org_task_event 表及 DAO(alembic 0009) - 新增 /api/v1/plugin 路由(submit/status/stream/install/reload) - 新增 data/plugin/example_dept 示例重型插件 - regulatory_node 支持聊天历史上下文注入 - send_file 改为 artifact 存盘 + SSE 推送下载链接 - 前端 WorkflowFileCard 组件 + ToolSettings README 渲染 - utils 整理:合并 access/role_check、standalone_proxy→ray_compat、删除废弃模块 - 项目结构文档移至 docs/STRUCTURE.md 并详细展开 Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
271 lines
9.5 KiB
Python
271 lines
9.5 KiB
Python
# Copyright 2026 zhaoxi826
|
||
#
|
||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||
# you may not use this file except in compliance with the License.
|
||
# You may obtain a copy of the License at
|
||
#
|
||
# http://www.apache.org/licenses/LICENSE-2.0
|
||
#
|
||
# Unless required by applicable law or agreed to in writing, software
|
||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||
# See the License for the specific language governing permissions and
|
||
# limitations under the License.
|
||
|
||
import json
|
||
import asyncio
|
||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||
from fastapi.responses import StreamingResponse
|
||
from pydantic import BaseModel
|
||
from kilostar.utils.ray_hook import ray_actor_hook
|
||
from kilostar.utils.access import Accessor, TokenData
|
||
from kilostar.core.individual.regulatory_node.template import (
|
||
MessageRequest,
|
||
MessageResponse,
|
||
)
|
||
|
||
# Router that groups every chat endpoint under the /api/v1/chat prefix.
chat_router = APIRouter(
    prefix="/api/v1/chat",
    tags=["chat"],
)
|
||
|
||
# Maximum number of past turns injected into a single request (one user
# message plus one assistant reply counts as one turn); keeps the prompt from
# blowing up the token budget.
_HISTORY_MAX_TURNS: int = 20
|
||
|
||
|
||
def _build_message_history(rows) -> list:
    """Convert stored chat rows into a pydantic-ai ``message_history`` list.

    Only the trailing ``_HISTORY_MAX_TURNS * 2`` rows are considered, in the
    order they were given (the caller is expected to pass them oldest first).
    Rows whose ``message`` is empty are dropped. A row owned by ``"user"``
    becomes ``ModelRequest(parts=[UserPromptPart])``; a row owned by
    ``"regulatory_node"`` (the assistant) becomes
    ``ModelResponse(parts=[TextPart])``. Rows with any other owner are ignored.
    """
    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        TextPart,
        UserPromptPart,
    )

    def _as_user_turn(content):
        return ModelRequest(parts=[UserPromptPart(content=content)])

    def _as_assistant_turn(content):
        return ModelResponse(parts=[TextPart(content=content)])

    # Owner label -> converter; owners missing from this table are skipped.
    converters = {
        "user": _as_user_turn,
        "regulatory_node": _as_assistant_turn,
    }

    window = rows[-(_HISTORY_MAX_TURNS * 2):]
    converted: list = []
    for entry in window:
        content = entry.message
        convert = converters.get(entry.message_owner)
        if content and convert is not None:
            converted.append(convert(content))
    return converted
|
||
|
||
|
||
async def _load_message_history(chat_id: str) -> list:
    """Load the stored messages of ``chat_id`` and convert them to model history.

    A falsy result from the database actor is treated as an empty history.
    """
    db = ray_actor_hook("postgres_database").postgres_database
    stored = await db.list_chat_messages.remote(chat_id=chat_id)
    if not stored:
        stored = []
    return _build_message_history(stored)
|
||
|
||
|
||
def _extract_reply(resp: MessageResponse | None) -> str | None:
    """Pull the user-facing reply text out of a ``RegulatoryNode.working`` result.

    The node's only output type is ``MessageResponse`` (chat, simple task or
    report). When no response object is returned the node is treated as
    silent: ``None`` is returned so the caller does not write anything back to
    the chat history.
    """
    return resp.reply_message if resp is not None else None
|
||
|
||
|
||
async def _ask_regulatory(
    *,
    user_id: str,
    chat_id: str,
    message: str,
    message_history: list | None = None,
) -> str | None:
    """Send one chat message to the regulatory node and return its reply text.

    Single entry point used by the chat endpoints. The message is wrapped in a
    ``MessageRequest`` with ``platform="client"``, the user id as
    ``user_name`` and the chat id as ``platform_id``, then passed to
    ``regulatory_node.working`` together with ``message_history``.

    Returns:
        The reply text, or ``None`` when the node produced no response.
    """
    node = ray_actor_hook("regulatory_node").regulatory_node
    node_request = MessageRequest(
        platform="client",
        user_name=user_id,
        platform_id=chat_id,
        message=message,
    )
    node_response: MessageResponse | None = await node.working.remote(
        node_request, message_history
    )
    return _extract_reply(node_response)
|
||
|
||
|
||
class CreateChatRequest(BaseModel):
    """Request body for creating a new chat session."""

    # Display title of the session; a default title is used when omitted.
    title: str = "新对话"
    # First user message; stored and sent to the regulatory node on creation.
    initial_message: str
|
||
|
||
|
||
class SendMessageRequest(BaseModel):
    """Request body carrying one user message for an existing chat session."""

    # Text of the user's message.
    message: str
|
||
|
||
|
||
@chat_router.post("")
async def create_chat_session(
    request: CreateChatRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Create a chat session, store its first message and answer it.

    The session is created for the authenticated user, the initial message is
    stored as a ``"user"`` message, and the regulatory node is asked for a
    reply (without any prior history). A non-empty reply is stored as a
    ``"regulatory_node"`` message.

    Returns:
        ``{"chat_id": ..., "reply": ...}`` where ``reply`` may be ``None``.
    """
    db = ray_actor_hook("postgres_database").postgres_database
    owner_id = token_data.user_id
    first_message = request.initial_message

    chat = await db.create_chat_session.remote(user_id=owner_id, title=request.title)

    # Persist the user's opening message.
    await db.add_chat_message.remote(
        chat_id=chat.chat_id,
        message=first_message,
        message_owner="user",
    )

    # Let the regulatory node handle the simple task / conversation.
    reply_text = await _ask_regulatory(
        user_id=owner_id,
        chat_id=chat.chat_id,
        message=first_message,
    )

    # Persist the reply only when the node actually produced one.
    if reply_text:
        await db.add_chat_message.remote(
            chat_id=chat.chat_id,
            message=reply_text,
            message_owner="regulatory_node",
        )

    return {"chat_id": chat.chat_id, "reply": reply_text}
|
||
|
||
|
||
@chat_router.get("")
async def list_chat_sessions(
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """List the chat sessions that belong to the authenticated user.

    Returns:
        ``{"sessions": ...}`` with whatever the database actor returns for
        this user id.
    """
    db = ray_actor_hook("postgres_database").postgres_database
    owned_sessions = await db.list_chat_sessions.remote(user_id=token_data.user_id)
    return {"sessions": owned_sessions}
|
||
|
||
|
||
@chat_router.get("/{chat_id}")
async def get_chat_history(
    chat_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
    """Return the stored messages of one chat session owned by the caller.

    Args:
        chat_id: Identifier of the chat session to read.
        token_data: Authenticated user resolved from the access token.

    Returns:
        ``{"messages": ...}`` with the messages stored for ``chat_id``.

    Raises:
        HTTPException: 404 when the session does not exist, 403 when it
            belongs to a different user.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database

    # Authorise before reading: without this check any authenticated user
    # could read another user's conversation by guessing its chat_id. Mirrors
    # the ownership check performed when deleting a session.
    session = await postgres_database.get_chat_session.remote(chat_id=chat_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    if session.user_id != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    messages = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
    return {"messages": messages}
|
||
|
||
|
||
@chat_router.post("/{chat_id}/reply")
async def send_chat_message(
    chat_id: str,
    request: SendMessageRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Append a user message to an existing chat and return the node's reply.

    Args:
        chat_id: Identifier of the chat session to continue.
        request: Body carrying the user's message.
        token_data: Authenticated user resolved from the access token.

    Returns:
        ``{"reply": ...}`` where ``reply`` is ``None`` when the regulatory
        node produced no response.

    Raises:
        HTTPException: 404 when the session does not exist, 403 when it
            belongs to a different user.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database

    # Authorise before touching the conversation: without this check any
    # authenticated user could write into another user's chat and have its
    # history fed to the model. Mirrors the check done when deleting a session.
    session = await postgres_database.get_chat_session.remote(chat_id=chat_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    if session.user_id != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    # Load the history first (it must not contain the current input), then
    # store the user message, so the new message is not duplicated in history.
    message_history = await _load_message_history(chat_id)
    await postgres_database.add_chat_message.remote(
        chat_id=chat_id, message=request.message, message_owner="user"
    )

    # Ask the regulatory node, giving it the prior conversation as context.
    response_msg = await _ask_regulatory(
        user_id=token_data.user_id,
        chat_id=chat_id,
        message=request.message,
        message_history=message_history,
    )

    # Store the reply only when the node actually produced one.
    if response_msg:
        await postgres_database.add_chat_message.remote(
            chat_id=chat_id, message=response_msg, message_owner="regulatory_node"
        )

    return {"reply": response_msg}
|
||
|
||
|
||
@chat_router.delete("/{chat_id}")
async def delete_chat_session(
    chat_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Delete a chat session that belongs to the authenticated user.

    Returns:
        ``{"message": "success"}`` once the session has been deleted.

    Raises:
        HTTPException: 404 when the session does not exist, 403 when it
            belongs to a different user.
    """
    db = ray_actor_hook("postgres_database").postgres_database
    target = await db.get_chat_session.remote(chat_id=chat_id)

    if not target:
        raise HTTPException(status_code=404, detail="Chat session not found")

    # Only the owner of the session is allowed to remove it.
    caller_owns_session = target.user_id == token_data.user_id
    if not caller_owns_session:
        raise HTTPException(status_code=403, detail="Forbidden")

    await db.delete_chat_session.remote(chat_id=chat_id)
    return {"message": "success"}
|
||
|
||
|
||
@chat_router.post("/{chat_id}/stream")
async def stream_chat_message(
    chat_id: str,
    request_body: SendMessageRequest,
    request: Request,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """SSE streaming chat endpoint.

    In standalone mode the reply is streamed token by token; in distributed
    mode the endpoint falls back to sending the whole reply as a single token
    event. Every stream ends with a ``{"done": true, "full_message": ...}``
    event, and a non-empty reply is stored as a ``"regulatory_node"`` message.

    Args:
        chat_id: Identifier of the chat session to continue.
        request_body: Body carrying the user's message.
        request: Incoming request, polled to detect client disconnects.
        token_data: Authenticated user resolved from the access token.

    Raises:
        HTTPException: 404 when the session does not exist, 403 when it
            belongs to a different user.
    """
    from kilostar.utils.ray_compat import _STANDALONE

    postgres_database = ray_actor_hook("postgres_database").postgres_database

    # Authorise before touching the conversation: without this check any
    # authenticated user could write into another user's chat and have its
    # history fed to the model. Mirrors the check done when deleting a session.
    session = await postgres_database.get_chat_session.remote(chat_id=chat_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    if session.user_id != token_data.user_id:
        raise HTTPException(status_code=403, detail="Forbidden")

    # History is loaded before the new message is stored so that the current
    # input does not show up twice in the model context.
    message_history = await _load_message_history(chat_id)

    await postgres_database.add_chat_message.remote(
        chat_id=chat_id, message=request_body.message, message_owner="user"
    )

    payload = MessageRequest(
        platform="client",
        user_name=token_data.user_id,
        platform_id=chat_id,
        message=request_body.message,
    )

    regulatory_node = ray_actor_hook("regulatory_node").regulatory_node

    if not _STANDALONE:

        async def fallback_generator():
            # Distributed mode: wait for the whole reply, then emit it as one
            # token event followed by the closing "done" event.
            resp = await regulatory_node.working.remote(payload, message_history)
            full_response = resp.reply_message if resp else ""
            if full_response:
                await postgres_database.add_chat_message.remote(
                    chat_id=chat_id, message=full_response, message_owner="regulatory_node"
                )
                yield f"data: {json.dumps({'token': full_response})}\n\n"
            yield f"data: {json.dumps({'done': True, 'full_message': full_response})}\n\n"

        return StreamingResponse(fallback_generator(), media_type="text/event-stream")

    # Standalone mode: the node pushes tokens into this queue and the
    # generator below drains it; a ``None`` item marks the end of the stream.
    token_queue = asyncio.Queue()
    stream_task = regulatory_node.stream_working.remote(payload, token_queue, message_history)

    async def event_generator():
        full_response = ""
        try:
            while True:
                # Stop producing (and cancel the node call) once the client is gone.
                if await request.is_disconnected():
                    stream_task.cancel()
                    break
                try:
                    # Short timeout so the disconnect check above runs regularly.
                    token = await asyncio.wait_for(token_queue.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    # NOTE(review): if the producer fails without ever
                    # enqueuing the ``None`` sentinel, this loop keeps polling
                    # until the client disconnects - confirm that
                    # stream_working always enqueues the sentinel.
                    continue
                if token is None:
                    break
                full_response += token
                yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as e:
            from kilostar.utils.logger import get_logger

            get_logger("chat_stream").exception(f"Stream error: {e}")
            # Only substitute an apology when nothing was streamed yet;
            # otherwise keep the partial reply that the client already saw.
            if not full_response:
                full_response = "抱歉,生成回复时出错。"
                yield f"data: {json.dumps({'token': full_response})}\n\n"

        if full_response:
            await postgres_database.add_chat_message.remote(
                chat_id=chat_id,
                message=full_response,
                message_owner="regulatory_node",
            )
        yield f"data: {json.dumps({'done': True, 'full_message': full_response})}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")
|