# Copyright 2026 zhaoxi826
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import asyncio

import httpx
from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel

from kilostar.utils.ray_hook import ray_actor_hook
from kilostar.utils.access import Accessor, TokenData
from kilostar.core.individual.regulatory_node.template import (
    MessageRequest,
    MessageResponse,
)

chat_router = APIRouter(prefix="/api/v1/chat", tags=["chat"])


def _extract_reply(resp: MessageResponse | None) -> str | None:
    """Return the user-facing reply text from ``RegulatoryNode.working`` output.

    The RegulatoryNode's output_type is now only ``MessageResponse``
    (chat / simple task / report). A missing response is treated as the node
    having degraded to silence, in which case callers do not write anything
    back to the chat history.
    """
    if resp is None:
        return None
    return resp.reply_message


async def _ask_regulatory(
    *, user_id: str, chat_id: str, message: str
) -> str | None:
    """Single entry point used by the chat endpoints to call the RegulatoryNode.

    Builds a ``MessageRequest`` with ``platform="client"`` and returns the
    reply text, or ``None`` when the node produced no response.
    """
    regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
    payload = MessageRequest(
        platform="client",
        user_name=user_id,
        platform_id=chat_id,
        message=message,
    )
    resp: MessageResponse | None = await regulatory_node.working.remote(payload)
    return _extract_reply(resp)


async def _require_owned_session(postgres_database, chat_id: str, user_id: str):
    """Load a chat session and ensure it belongs to ``user_id``.

    Raises:
        HTTPException: 404 when the session does not exist, 403 when it is
            owned by a different user.
    """
    session = await postgres_database.get_chat_session.remote(chat_id=chat_id)
    if not session:
        raise HTTPException(status_code=404, detail="Chat session not found")
    if session.user_id != user_id:
        raise HTTPException(status_code=403, detail="Forbidden")
    return session


def _extract_stream_token(data_str: str) -> str:
    """Return the ``choices[0].delta.content`` text of one streamed JSON chunk.

    Any malformed chunk (invalid JSON, unexpected shape, ``null`` fields,
    non-string content) yields ``""`` so that a single bad chunk is skipped
    instead of aborting the whole stream.
    """
    try:
        chunk = json.loads(data_str)
        choices = chunk.get("choices") or [{}]
        delta = choices[0].get("delta") or {}
        token = delta.get("content") or ""
    except (json.JSONDecodeError, IndexError, KeyError, AttributeError, TypeError):
        return ""
    return token if isinstance(token, str) else ""


class CreateChatRequest(BaseModel):
    """Request body for creating a chat session."""

    # Title of the new session.
    title: str = "新对话"
    # First user message; stored and forwarded to the regulatory node.
    initial_message: str


class SendMessageRequest(BaseModel):
    """Request body carrying one user message."""

    message: str


@chat_router.post("")
async def create_chat_session(
    request: CreateChatRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Create a chat session, store the first message and return the reply."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    chat = await postgres_database.create_chat_session.remote(
        user_id=token_data.user_id, title=request.title
    )

    # Persist the user's message.
    await postgres_database.add_chat_message.remote(
        chat_id=chat.chat_id, message=request.initial_message, message_owner="user"
    )

    # Let the regulatory node handle the simple task / conversation.
    response_msg = await _ask_regulatory(
        user_id=token_data.user_id,
        chat_id=chat.chat_id,
        message=request.initial_message,
    )

    # Persist the reply only when the node actually produced one.
    if response_msg:
        await postgres_database.add_chat_message.remote(
            chat_id=chat.chat_id,
            message=response_msg,
            message_owner="regulatory_node",
        )

    return {"chat_id": chat.chat_id, "reply": response_msg}


@chat_router.get("")
async def list_chat_sessions(
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """List the chat sessions of the authenticated user."""
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    sessions = await postgres_database.list_chat_sessions.remote(
        user_id=token_data.user_id
    )
    return {"sessions": sessions}


@chat_router.get("/{chat_id}")
async def get_chat_history(
    chat_id: str, token_data: TokenData = Depends(Accessor.get_current_user)
):
    """Return the messages of a chat session owned by the caller.

    Raises:
        HTTPException: 404 / 403, see ``_require_owned_session``.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    # Same ownership rule as delete_chat_session: only the owner may read.
    await _require_owned_session(postgres_database, chat_id, token_data.user_id)
    messages = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
    return {"messages": messages}


@chat_router.post("/{chat_id}/reply")
async def send_chat_message(
    chat_id: str,
    request: SendMessageRequest,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Store a user message in an owned chat and return the node's reply.

    Raises:
        HTTPException: 404 / 403, see ``_require_owned_session``.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    # Same ownership rule as delete_chat_session: only the owner may write.
    await _require_owned_session(postgres_database, chat_id, token_data.user_id)

    # Persist the user's message.
    await postgres_database.add_chat_message.remote(
        chat_id=chat_id, message=request.message, message_owner="user"
    )

    # Call the regulatory node.
    response_msg = await _ask_regulatory(
        user_id=token_data.user_id,
        chat_id=chat_id,
        message=request.message,
    )

    # Persist the reply only when the node actually produced one.
    if response_msg:
        await postgres_database.add_chat_message.remote(
            chat_id=chat_id, message=response_msg, message_owner="regulatory_node"
        )

    return {"reply": response_msg}


@chat_router.delete("/{chat_id}")
async def delete_chat_session(
    chat_id: str,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """Delete a chat session owned by the caller.

    Raises:
        HTTPException: 404 / 403, see ``_require_owned_session``.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    await _require_owned_session(postgres_database, chat_id, token_data.user_id)
    await postgres_database.delete_chat_session.remote(chat_id=chat_id)
    return {"message": "success"}


@chat_router.post("/{chat_id}/stream")
async def stream_chat_message(
    chat_id: str,
    request_body: SendMessageRequest,
    request: Request,
    token_data: TokenData = Depends(Accessor.get_current_user),
):
    """SSE streaming chat endpoint: pushes the AI reply token by token.

    Emits ``data: {"token": ...}`` events while the provider streams, then a
    final ``data: {"done": true, "full_message": ...}`` event.

    Raises:
        HTTPException: 404 / 403 from the ownership check; 500 when the
            regulatory node config or its provider is unavailable.
    """
    postgres_database = ray_actor_hook("postgres_database").postgres_database
    # Same ownership rule as delete_chat_session: only the owner may write.
    await _require_owned_session(postgres_database, chat_id, token_data.user_id)

    # Persist the user's message.
    await postgres_database.add_chat_message.remote(
        chat_id=chat_id, message=request_body.message, message_owner="user"
    )

    # Fetch the provider configuration of the regulatory node.
    node_config = await postgres_database.get_system_node_config.remote(
        "regulatory_node"
    )
    if not node_config:
        raise HTTPException(status_code=500, detail="Regulatory node not configured")

    # Resolve the provider details from the global state machine snapshot.
    from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot

    global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
    snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
    provider = snapshot.providers.get(node_config.provider_title)
    if not provider:
        raise HTTPException(status_code=500, detail="Provider not available")

    # Load the stored messages as context; this includes the message saved above.
    history_msgs = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
    system_prompt = "你是 KiloStar 助手,友善、简洁地回答用户的问题。"
    if node_config.persona_id:
        tpl = await postgres_database.get_template.remote(node_config.persona_id)
        if tpl and tpl.system_prompt:
            system_prompt += "\n" + tpl.system_prompt
    messages = [{"role": "system", "content": system_prompt}]
    # Every non-user owner is presented to the model as the assistant.
    messages.extend(
        {
            "role": "user" if msg.message_owner == "user" else "assistant",
            "content": msg.message,
        }
        for msg in history_msgs
    )

    async def event_generator():
        """Relay provider chunks as SSE events and persist the final reply."""
        pieces: list[str] = []
        try:
            async with httpx.AsyncClient(timeout=120.0) as client:
                url = provider.provider_url.rstrip("/") + "/chat/completions"
                payload = {
                    "model": node_config.model_id,
                    "messages": messages,
                    "stream": True,
                }
                async with client.stream(
                    "POST",
                    url,
                    json=payload,
                    headers={
                        "Authorization": f"Bearer {provider.provider_apikey}",
                        "Content-Type": "application/json",
                    },
                ) as resp:
                    # A 4xx/5xx body has no "data:" lines; without this check the
                    # failure would silently look like an empty reply.
                    resp.raise_for_status()
                    async for line in resp.aiter_lines():
                        if await request.is_disconnected():
                            break
                        # The space after "data:" is optional in the SSE format.
                        if not line.startswith("data:"):
                            continue
                        data_str = line[5:].strip()
                        if data_str == "[DONE]":
                            break
                        token = _extract_stream_token(data_str)
                        if token:
                            pieces.append(token)
                            yield f"data: {json.dumps({'token': token})}\n\n"
        except Exception as e:
            from kilostar.utils.logger import get_logger

            get_logger("chat_stream").exception(f"Stream error: {e}")
            # Only send the apology when nothing was streamed; otherwise keep
            # the partial reply that the client has already received.
            if not pieces:
                pieces.append("抱歉,生成回复时出错。")
                yield f"data: {json.dumps({'token': pieces[0]})}\n\n"

        # Stream finished: store whatever was sent to the client.
        full_response = "".join(pieces)
        if full_response:
            await postgres_database.add_chat_message.remote(
                chat_id=chat_id,
                message=full_response,
                message_owner="regulatory_node",
            )
        yield f"data: {json.dumps({'done': True, 'full_message': full_response})}\n\n"

    return StreamingResponse(event_generator(), media_type="text/event-stream")