# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""HTTP chat endpoints backed by the RegulatoryNode actor.

Routes that talk to the node persist the user's message and the node's reply
through the ``postgres_database`` actor. Every route that takes a ``chat_id``
first checks that the session belongs to the authenticated user.
"""
import asyncio
import json

from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.individual.regulatory_node.template import (
    MessageRequest,
    MessageResponse,
)

chat_router = APIRouter(prefix="/api/v1/chat", tags=["chat"])

# Maximum number of turns (one user message + one assistant reply = one turn)
# injected as history per request, to keep the prompt from exploding in tokens.
_HISTORY_MAX_TURNS = 20

# ``message_owner`` values written to / read from the chat history table.
_USER_OWNER = "user"
_ASSISTANT_OWNER = "regulatory_node"

# Reply sent (and persisted) when streaming fails before any text was produced.
_STREAM_ERROR_REPLY = "抱歉,生成回复时出错。"


# --------------------------------------------------------------------------- #
# Small helpers
# --------------------------------------------------------------------------- #
def _sse(data: dict) -> str:
    """Encode *data* as one Server-Sent Events ``data:`` frame."""
    return f"data: {json.dumps(data)}\n\n"


def _stream_logger():
    """Return the logger used by the streaming endpoint (imported lazily)."""
    from kilostar.utils.logger import get_logger

    return get_logger("chat_stream")


def _is_finished(task) -> bool:
    """Best-effort check whether a standalone ``.remote()`` handle has completed.

    NOTE(review): in standalone mode ``stream_working.remote(...)`` is assumed to
    return a future-like object (this module calls ``.cancel()`` on it). If it
    exposes no callable ``done()``, this reports "not finished". That keeps the
    previous behaviour of waiting for the sentinel or a client disconnect.
    """
    done = getattr(task, "done", None)
    return bool(done()) if callable(done) else False


def _build_message_history(rows) -> list:
    """Convert DB ``ChatHistoryMessage`` rows into pydantic-ai ``message_history``.

    ``rows`` are expected in ascending time order and must expose
    ``message_owner`` and ``message``.

    Mapping rules:

    - Rows with empty text, or an owner other than user / regulatory_node, are
      skipped.
    - ``user`` rows become ``ModelRequest(parts=[UserPromptPart])``.
    - ``regulatory_node`` rows become ``ModelResponse(parts=[TextPart])``.

    After mapping, only the last ``_HISTORY_MAX_TURNS * 2`` messages are kept.
    Filtering happens *before* trimming, so skipped rows do not eat into the
    budget. Any assistant messages at the very start of the window are dropped,
    so the window starts on a user turn rather than half a turn.

    NOTE(review): pydantic-ai does not regenerate an agent's ``system_prompt``
    when a non-empty ``message_history`` is passed. Confirm that RegulatoryNode
    uses ``instructions`` or otherwise re-injects its system prompt.
    """
    from pydantic_ai.messages import (
        ModelRequest,
        ModelResponse,
        TextPart,
        UserPromptPart,
    )

    history: list = []
    for row in rows:
        text = row.message
        if not text:
            continue
        owner = row.message_owner
        if owner == _USER_OWNER:
            history.append(ModelRequest(parts=[UserPromptPart(content=text)]))
        elif owner == _ASSISTANT_OWNER:
            history.append(ModelResponse(parts=[TextPart(content=text)]))

    limit = _HISTORY_MAX_TURNS * 2
    if limit <= 0:
        # ``history[-0:]`` would return the whole list, not an empty one.
        return []
    history = history[-limit:]

    # Skip leading assistant replies whose user message fell outside the window.
    first_request = next(
        (i for i, msg in enumerate(history) if isinstance(msg, ModelRequest)),
        len(history),
    )
    return history[first_request:]


async def _load_message_history(chat_id: str) -> list:
    """Fetch the stored messages of *chat_id* and convert them to history."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    rows = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
    return _build_message_history(rows or [])


def _extract_reply(resp: MessageResponse | None) -> str | None:
    """Extract the user-facing reply text from ``RegulatoryNode.working`` output.

    RegulatoryNode's output type is ``MessageResponse`` (chat / simple task /
    report). ``None`` means the node chose to stay silent; callers then do not
    write anything back to the chat history.
    """
    if resp is None:
        return None
    return resp.reply_message


def _build_payload(*, user_id: str, chat_id: str, message: str) -> MessageRequest:
    """Wrap a client chat message in the request model RegulatoryNode consumes."""
    return MessageRequest(
        platform="client",
        user_name=user_id,
        platform_id=chat_id,
        message=message,
    )


async def _ask_regulatory(
    *, user_id: str, chat_id: str, message: str, message_history: list | None = None
) -> str | None:
    """Send one chat message to RegulatoryNode and return its reply text (or None)."""
    regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
    payload = _build_payload(user_id=user_id, chat_id=chat_id, message=message)
    resp: MessageResponse | None = await regulatory_node.working.remote(
        payload, message_history
    )
    return _extract_reply(resp)


async def _require_owned_session(postgres_database, *, chat_id: str, user_id: str):
    """Return the chat session *chat_id*, enforcing that *user_id* owns it.

    Raises:
        HTTPException: 404 if the session does not exist, 403 if it belongs to
            another user.
    """
    session = await postgres_database.get_chat_session.remote(chat_id=chat_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    if session.user_id != user_id:
        raise HTTPException(status_code=403, detail="Forbidden")
    return session


async def _save_user_message(postgres_database, chat_id: str, message: str) -> None:
    """Persist a message authored by the user."""
    await postgres_database.add_chat_message.remote(
        chat_id=chat_id, message=message, message_owner=_USER_OWNER
    )


async def _save_reply(postgres_database, chat_id: str, reply: str | None) -> None:
    """Persist a node reply; an empty or ``None`` reply (silent node) is skipped."""
    if reply:
        await postgres_database.add_chat_message.remote(
            chat_id=chat_id, message=reply, message_owner=_ASSISTANT_OWNER
        )


# --------------------------------------------------------------------------- #
# Request models
# --------------------------------------------------------------------------- #
class CreateChatRequest(BaseModel):
    """Body of ``POST /api/v1/chat``: optional title plus the first message."""

    title: str = "新对话"
    initial_message: str


class SendMessageRequest(BaseModel):
    """Body of the reply / stream endpoints: a single user message."""

    message: str


# --------------------------------------------------------------------------- #
# Endpoints
# --------------------------------------------------------------------------- #
@chat_router.post("")
async def create_chat_session(
    request: CreateChatRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Create a session, store its first message and return the node's reply."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    chat = await postgres_database.create_chat_session.remote(
        user_id=token_data.user_id, title=request.title
    )
    await _save_user_message(postgres_database, chat.chat_id, request.initial_message)

    # Let the regulatory node handle simple tasks / small talk.
    response_msg = await _ask_regulatory(
        user_id=token_data.user_id,
        chat_id=chat.chat_id,
        message=request.initial_message,
    )
    await _save_reply(postgres_database, chat.chat_id, response_msg)
    return {"chat_id": chat.chat_id, "reply": response_msg}


@chat_router.get("")
async def list_chat_sessions(
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """List the authenticated user's chat sessions."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    sessions = await postgres_database.list_chat_sessions.remote(
        user_id=token_data.user_id
    )
    return {"sessions": sessions}


@chat_router.get("/{chat_id}")
async def get_chat_history(
    chat_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
    """Return all stored messages of a session owned by the caller."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    await _require_owned_session(
        postgres_database, chat_id=chat_id, user_id=token_data.user_id
    )
    messages = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
    return {"messages": messages}


@chat_router.post("/{chat_id}/reply")
async def send_chat_message(
    chat_id: str,
    request: SendMessageRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Append a user message to an owned session and return the node's reply."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    await _require_owned_session(
        postgres_database, chat_id=chat_id, user_id=token_data.user_id
    )

    # Load history *before* storing the new message so it is not sent twice.
    message_history = await _load_message_history(chat_id)
    await _save_user_message(postgres_database, chat_id, request.message)

    response_msg = await _ask_regulatory(
        user_id=token_data.user_id,
        chat_id=chat_id,
        message=request.message,
        message_history=message_history,
    )
    await _save_reply(postgres_database, chat_id, response_msg)
    return {"reply": response_msg}


@chat_router.delete("/{chat_id}")
async def delete_chat_session(
    chat_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Delete a session owned by the caller (404 if missing, 403 if not owned)."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    await _require_owned_session(
        postgres_database, chat_id=chat_id, user_id=token_data.user_id
    )
    await postgres_database.delete_chat_session.remote(chat_id=chat_id)
    return {"message": "success"}


async def _whole_reply_events(
    regulatory_node, postgres_database, chat_id: str, payload, message_history: list
):
    """SSE events for distributed mode.

    The node is called once. Its full reply is sent as a single ``token``
    event, followed by the ``done`` event. A failure is reported like the
    standalone stream does: it is logged and replaced by the generic error
    reply, instead of aborting the response without a ``done`` event.
    """
    try:
        resp = await regulatory_node.working.remote(payload, message_history)
    except Exception as e:
        _stream_logger().exception(f"Stream error: {e}")
        full_response = _STREAM_ERROR_REPLY
    else:
        full_response = _extract_reply(resp) or ""
    await _save_reply(postgres_database, chat_id, full_response)
    yield _sse({"token": full_response})
    yield _sse({"done": True, "full_message": full_response})


async def _token_stream_events(
    request: Request,
    postgres_database,
    chat_id: str,
    token_queue: asyncio.Queue,
    stream_task,
):
    """SSE events for standalone mode: relay tokens from *token_queue*.

    The producer (``stream_task``) signals end-of-stream by putting ``None``
    on the queue. The loop also stops when any of the following happens:

    - the client disconnects (the producer is cancelled);
    - an unexpected error occurs;
    - the producer finishes without sending the sentinel (treated as a failure).

    Whatever text was received is persisted, and a ``done`` event is emitted.
    If the generator exits before the sentinel was seen, the producer is
    cancelled so it does not keep generating into an abandoned queue.
    """
    full_response = ""
    got_sentinel = False
    cancelled = False
    producer_died = False
    try:
        try:
            while True:
                if await request.is_disconnected():
                    stream_task.cancel()
                    cancelled = True
                    break
                try:
                    token = await asyncio.wait_for(token_queue.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    # Without this check a producer that crashed before sending
                    # the sentinel would make us poll until the client hangs up.
                    if _is_finished(stream_task) and token_queue.empty():
                        producer_died = True
                        break
                    continue
                if token is None:
                    got_sentinel = True
                    break
                full_response += token
                yield _sse({"token": token})
        except Exception as e:
            _stream_logger().exception(f"Stream error: {e}")
            if not full_response:
                full_response = _STREAM_ERROR_REPLY
                yield _sse({"token": full_response})

        if producer_died:
            _stream_logger().warning(
                "stream_working finished without sending the end-of-stream sentinel"
            )
            if not full_response:
                full_response = _STREAM_ERROR_REPLY
                yield _sse({"token": full_response})

        await _save_reply(postgres_database, chat_id, full_response)
        yield _sse({"done": True, "full_message": full_response})
    finally:
        # Also runs when the response itself is cancelled or closed mid-stream.
        if not (got_sentinel or cancelled) and not _is_finished(stream_task):
            stream_task.cancel()


@chat_router.post("/{chat_id}/stream")
async def stream_chat_message(
    chat_id: str,
    request_body: SendMessageRequest,
    request: Request,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """SSE streaming chat endpoint for a session owned by the caller.

    The two modes differ only in how the reply text is delivered:

    - Standalone mode: tokens are streamed one by one as the node produces them.
    - Distributed mode: falls back to sending the whole reply as one event.

    Both modes end with a ``{"done": true, "full_message": ...}`` event.
    """
    from kilostar.utils.ray_compat import _STANDALONE

    postgres_database = ray_actor_hook("postgres_database").postgres_database
    await _require_owned_session(
        postgres_database, chat_id=chat_id, user_id=token_data.user_id
    )

    # Load history *before* storing the new message so it is not sent twice.
    message_history = await _load_message_history(chat_id)
    await _save_user_message(postgres_database, chat_id, request_body.message)

    payload = _build_payload(
        user_id=token_data.user_id, chat_id=chat_id, message=request_body.message
    )
    regulatory_node = ray_actor_hook("regulatory_node").regulatory_node

    if not _STANDALONE:
        return StreamingResponse(
            _whole_reply_events(
                regulatory_node, postgres_database, chat_id, payload, message_history
            ),
            media_type="text/event-stream",
        )

    token_queue: asyncio.Queue = asyncio.Queue()
    # Started eagerly (as before) so a failure to launch surfaces as an HTTP error.
    stream_task = regulatory_node.stream_working.remote(
        payload, token_queue, message_history
    )
    return StreamingResponse(
        _token_stream_events(
            request, postgres_database, chat_id, token_queue, stream_task
        ),
        media_type="text/event-stream",
    )