fix: chat stream 走 regulatory agent 支持工具调用,修复 workflow ValidationError

1. chat.py stream 端点改为调用 regulatory_node.stream_working()(pydantic-ai
   run_stream),支持工具调用 + 逐 token 流式输出
2. regulatory_node 新增 stream_working 方法,通过 asyncio.Queue 推送 token
3. ConsciousnessNodeDeps.available_skills 加默认值 None,修复 ForWorkflowInput/
   ForregulatoryInput 路径的 ValidationError

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
2026-06-05 07:37:59 +00:00
parent 6f1bc27101
commit b61524e5d9
3 changed files with 57 additions and 62 deletions
+19 -56
View File
@@ -14,7 +14,6 @@
import json import json
import asyncio import asyncio
import httpx
from fastapi import APIRouter, Depends, HTTPException, Request from fastapi import APIRouter, Depends, HTTPException, Request
from fastapi.responses import StreamingResponse from fastapi.responses import StreamingResponse
from pydantic import BaseModel from pydantic import BaseModel
@@ -164,7 +163,7 @@ async def stream_chat_message(
request: Request, request: Request,
token_data: TokenData = Depends(Accessor.get_current_user), token_data: TokenData = Depends(Accessor.get_current_user),
): ):
"""SSE 流式聊天端点:逐 token 推送 AI 回复""" """SSE 流式聊天端点:通过 regulatory_node agent 流式输出,支持工具调用"""
postgres_database = ray_actor_hook("postgres_database").postgres_database postgres_database = ray_actor_hook("postgres_database").postgres_database
# 存用户消息 # 存用户消息
@@ -172,71 +171,35 @@ async def stream_chat_message(
chat_id=chat_id, message=request_body.message, message_owner="user" chat_id=chat_id, message=request_body.message, message_owner="user"
) )
# 获取 regulatory_node 的 provider 配置 # 构造 MessageRequest payload
node_config = await postgres_database.get_system_node_config.remote("regulatory_node") payload = MessageRequest(
if not node_config: platform="client",
raise HTTPException(status_code=500, detail="Regulatory node not configured") user_name=token_data.user_id,
platform_id=chat_id,
message=request_body.message,
)
# 获取 provider 详情 regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot token_queue = asyncio.Queue()
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine # stream_working.remote() returns an asyncio.Task in standalone mode
snapshot = await fetch_snapshot(gsm_actor=global_state_machine) stream_task = regulatory_node.stream_working.remote(payload, token_queue)
provider = snapshot.providers.get(node_config.provider_title)
if not provider:
raise HTTPException(status_code=500, detail="Provider not available")
# 加载历史消息作为上下文
history_msgs = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
messages = []
system_prompt = "你是 KiloStar 助手,友善、简洁地回答用户的问题。"
if node_config.persona_id:
tpl = await postgres_database.get_template.remote(node_config.persona_id)
if tpl and tpl.system_prompt:
system_prompt += "\n" + tpl.system_prompt
messages.append({"role": "system", "content": system_prompt})
for msg in history_msgs:
role = "user" if msg.message_owner == "user" else "assistant"
messages.append({"role": role, "content": msg.message})
async def event_generator(): async def event_generator():
full_response = "" full_response = ""
try: try:
async with httpx.AsyncClient(timeout=120.0) as client: while True:
url = provider.provider_url.rstrip("/") + "/chat/completions"
payload = {
"model": node_config.model_id,
"messages": messages,
"stream": True,
}
async with client.stream(
"POST",
url,
json=payload,
headers={
"Authorization": f"Bearer {provider.provider_apikey}",
"Content-Type": "application/json",
},
) as resp:
async for line in resp.aiter_lines():
if await request.is_disconnected(): if await request.is_disconnected():
break stream_task.cancel()
if not line.startswith("data: "):
continue
data_str = line[6:]
if data_str.strip() == "[DONE]":
break break
try: try:
chunk = json.loads(data_str) token = await asyncio.wait_for(token_queue.get(), timeout=0.5)
delta = chunk.get("choices", [{}])[0].get("delta", {}) except asyncio.TimeoutError:
token = delta.get("content", "") continue
if token: if token is None:
break
full_response += token full_response += token
yield f"data: {json.dumps({'token': token})}\n\n" yield f"data: {json.dumps({'token': token})}\n\n"
except (json.JSONDecodeError, IndexError, KeyError):
continue
except Exception as e: except Exception as e:
from kilostar.utils.logger import get_logger from kilostar.utils.logger import get_logger
get_logger("chat_stream").exception(f"Stream error: {e}") get_logger("chat_stream").exception(f"Stream error: {e}")
@@ -28,7 +28,7 @@ class ConsciousnessNodeDeps(DepsModel):
"""ConsciousnessNode 在 pydantic-ai Agent 中使用的依赖:原始指令、当前指令以及可用 Skill 列表。""" """ConsciousnessNode 在 pydantic-ai Agent 中使用的依赖:原始指令、当前指令以及可用 Skill 列表。"""
original_command: str original_command: str
command: str command: str
available_skills: Optional[List[str]] available_skills: Optional[List[str]] = None
class ConsciousnessNodeInput(RequestModel): class ConsciousnessNodeInput(RequestModel):
"""ConsciousnessNode 各类入参的共同基类,仅用于打 schema 标签。""" """ConsciousnessNode 各类入参的共同基类,仅用于打 schema 标签。"""
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import datetime import datetime
from typing import Union from typing import Union
from kilostar.utils.standalone_proxy import actor_class from kilostar.utils.standalone_proxy import actor_class
@@ -125,6 +126,37 @@ class RegulatoryNode:
""" """
return await self._run(payload) return await self._run(payload)
async def stream_working(self, payload: MessageRequest, token_queue: "asyncio.Queue") -> None:
    """Stream the agent's text reply for ``payload`` into ``token_queue``.

    The agent is run through ``run_stream`` and every text delta it yields
    is put on the queue as soon as it arrives. ``None`` is always the last
    item put on the queue and acts as the end-of-stream marker; this holds
    when no agent is configured and when the run raises.

    Args:
        payload: Incoming message. ``platform``, ``user_name`` and
            ``message`` are read from it.
        token_queue: Queue that receives the text deltas, an optional
            error notice, and finally ``None``.
    """
    # Payload fields and the timestamp are captured before the agent check,
    # outside the try block, so a malformed payload raises to the caller.
    source_platform = payload.platform
    sender = payload.user_name
    prompt_text = payload.message
    now_text = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")

    # No agent available: emit only the terminator and stop.
    if self.agent is None:
        await token_queue.put(None)
        return

    try:
        run_deps = RegulatoryNodeDeps(
            platform=source_platform,
            user_name=sender,
            time=now_text,
        )
        stream_ctx = self.agent.run_stream(
            user_prompt=prompt_text,
            deps=run_deps,
            output_type=str,
        )
        async with stream_ctx as stream_result:
            text_deltas = stream_result.stream_text(delta=True)
            async for piece in text_deltas:
                await token_queue.put(piece)
    except Exception as e:
        self.logger.exception(f"RegulatoryNode.stream_working failed: {e}")
        # Surface the failure to the consumer as a final text chunk.
        error_notice = f"\n\n[错误: {str(e)}]"
        await token_queue.put(error_notice)
    finally:
        # Terminator is sent on success, on failure, and on cancellation.
        await token_queue.put(None)
async def _run( async def _run(
self, payload: MessageRequest self, payload: MessageRequest
) -> Union[MessageResponse, None]: ) -> Union[MessageResponse, None]: