fix: chat stream 走 regulatory agent 支持工具调用,修复 workflow ValidationError
1. chat.py stream 端点改为调用 regulatory_node.stream_working()(pydantic-ai run_stream),支持工具调用 + 逐 token 流式输出 2. regulatory_node 新增 stream_working 方法,通过 asyncio.Queue 推送 token 3. ConsciousnessNodeDeps.available_skills 加默认值 None,修复 ForWorkflowInput/ ForregulatoryInput 路径的 ValidationError Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
This commit is contained in:
+24
-61
@@ -14,7 +14,6 @@
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import httpx
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||||
from fastapi.responses import StreamingResponse
|
from fastapi.responses import StreamingResponse
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
@@ -164,7 +163,7 @@ async def stream_chat_message(
|
|||||||
request: Request,
|
request: Request,
|
||||||
token_data: TokenData = Depends(Accessor.get_current_user),
|
token_data: TokenData = Depends(Accessor.get_current_user),
|
||||||
):
|
):
|
||||||
"""SSE 流式聊天端点:逐 token 推送 AI 回复。"""
|
"""SSE 流式聊天端点:通过 regulatory_node agent 流式输出,支持工具调用。"""
|
||||||
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
postgres_database = ray_actor_hook("postgres_database").postgres_database
|
||||||
|
|
||||||
# 存用户消息
|
# 存用户消息
|
||||||
@@ -172,71 +171,35 @@ async def stream_chat_message(
|
|||||||
chat_id=chat_id, message=request_body.message, message_owner="user"
|
chat_id=chat_id, message=request_body.message, message_owner="user"
|
||||||
)
|
)
|
||||||
|
|
||||||
# 获取 regulatory_node 的 provider 配置
|
# 构造 MessageRequest payload
|
||||||
node_config = await postgres_database.get_system_node_config.remote("regulatory_node")
|
payload = MessageRequest(
|
||||||
if not node_config:
|
platform="client",
|
||||||
raise HTTPException(status_code=500, detail="Regulatory node not configured")
|
user_name=token_data.user_id,
|
||||||
|
platform_id=chat_id,
|
||||||
|
message=request_body.message,
|
||||||
|
)
|
||||||
|
|
||||||
# 获取 provider 详情
|
regulatory_node = ray_actor_hook("regulatory_node").regulatory_node
|
||||||
from kilostar.core.global_state_machine.gsm_snapshot import fetch_snapshot
|
token_queue = asyncio.Queue()
|
||||||
|
|
||||||
global_state_machine = ray_actor_hook("global_state_machine").global_state_machine
|
# stream_working.remote() returns an asyncio.Task in standalone mode
|
||||||
snapshot = await fetch_snapshot(gsm_actor=global_state_machine)
|
stream_task = regulatory_node.stream_working.remote(payload, token_queue)
|
||||||
provider = snapshot.providers.get(node_config.provider_title)
|
|
||||||
if not provider:
|
|
||||||
raise HTTPException(status_code=500, detail="Provider not available")
|
|
||||||
|
|
||||||
# 加载历史消息作为上下文
|
|
||||||
history_msgs = await postgres_database.list_chat_messages.remote(chat_id=chat_id)
|
|
||||||
messages = []
|
|
||||||
|
|
||||||
system_prompt = "你是 KiloStar 助手,友善、简洁地回答用户的问题。"
|
|
||||||
if node_config.persona_id:
|
|
||||||
tpl = await postgres_database.get_template.remote(node_config.persona_id)
|
|
||||||
if tpl and tpl.system_prompt:
|
|
||||||
system_prompt += "\n" + tpl.system_prompt
|
|
||||||
messages.append({"role": "system", "content": system_prompt})
|
|
||||||
|
|
||||||
for msg in history_msgs:
|
|
||||||
role = "user" if msg.message_owner == "user" else "assistant"
|
|
||||||
messages.append({"role": role, "content": msg.message})
|
|
||||||
|
|
||||||
async def event_generator():
|
async def event_generator():
|
||||||
full_response = ""
|
full_response = ""
|
||||||
try:
|
try:
|
||||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
while True:
|
||||||
url = provider.provider_url.rstrip("/") + "/chat/completions"
|
if await request.is_disconnected():
|
||||||
payload = {
|
stream_task.cancel()
|
||||||
"model": node_config.model_id,
|
break
|
||||||
"messages": messages,
|
try:
|
||||||
"stream": True,
|
token = await asyncio.wait_for(token_queue.get(), timeout=0.5)
|
||||||
}
|
except asyncio.TimeoutError:
|
||||||
async with client.stream(
|
continue
|
||||||
"POST",
|
if token is None:
|
||||||
url,
|
break
|
||||||
json=payload,
|
full_response += token
|
||||||
headers={
|
yield f"data: {json.dumps({'token': token})}\n\n"
|
||||||
"Authorization": f"Bearer {provider.provider_apikey}",
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
) as resp:
|
|
||||||
async for line in resp.aiter_lines():
|
|
||||||
if await request.is_disconnected():
|
|
||||||
break
|
|
||||||
if not line.startswith("data: "):
|
|
||||||
continue
|
|
||||||
data_str = line[6:]
|
|
||||||
if data_str.strip() == "[DONE]":
|
|
||||||
break
|
|
||||||
try:
|
|
||||||
chunk = json.loads(data_str)
|
|
||||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
|
||||||
token = delta.get("content", "")
|
|
||||||
if token:
|
|
||||||
full_response += token
|
|
||||||
yield f"data: {json.dumps({'token': token})}\n\n"
|
|
||||||
except (json.JSONDecodeError, IndexError, KeyError):
|
|
||||||
continue
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
from kilostar.utils.logger import get_logger
|
from kilostar.utils.logger import get_logger
|
||||||
get_logger("chat_stream").exception(f"Stream error: {e}")
|
get_logger("chat_stream").exception(f"Stream error: {e}")
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ class ConsciousnessNodeDeps(DepsModel):
|
|||||||
"""ConsciousnessNode 在 pydantic-ai Agent 中使用的依赖:原始指令、当前指令以及可用 Skill 列表。"""
|
"""ConsciousnessNode 在 pydantic-ai Agent 中使用的依赖:原始指令、当前指令以及可用 Skill 列表。"""
|
||||||
original_command: str
|
original_command: str
|
||||||
command: str
|
command: str
|
||||||
available_skills: Optional[List[str]]
|
available_skills: Optional[List[str]] = None
|
||||||
|
|
||||||
class ConsciousnessNodeInput(RequestModel):
|
class ConsciousnessNodeInput(RequestModel):
|
||||||
"""ConsciousnessNode 各类入参的共同基类,仅用于打 schema 标签。"""
|
"""ConsciousnessNode 各类入参的共同基类,仅用于打 schema 标签。"""
|
||||||
|
|||||||
@@ -12,6 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Union
|
from typing import Union
|
||||||
from kilostar.utils.standalone_proxy import actor_class
|
from kilostar.utils.standalone_proxy import actor_class
|
||||||
@@ -125,6 +126,37 @@ class RegulatoryNode:
|
|||||||
"""
|
"""
|
||||||
return await self._run(payload)
|
return await self._run(payload)
|
||||||
|
|
||||||
|
async def stream_working(self, payload: MessageRequest, token_queue: "asyncio.Queue") -> None:
|
||||||
|
"""流式工具调用版本:逐 token 推送到 queue,工具调用结果也会通过 token 输出。
|
||||||
|
|
||||||
|
完成后 push None 作为终止信号。
|
||||||
|
"""
|
||||||
|
platform = payload.platform
|
||||||
|
user_name = payload.user_name
|
||||||
|
message = payload.message
|
||||||
|
time_str = datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S")
|
||||||
|
|
||||||
|
if self.agent is None:
|
||||||
|
await token_queue.put(None)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
deps = RegulatoryNodeDeps(
|
||||||
|
platform=platform,
|
||||||
|
user_name=user_name,
|
||||||
|
time=time_str
|
||||||
|
)
|
||||||
|
async with self.agent.run_stream(
|
||||||
|
user_prompt=message, deps=deps, output_type=str
|
||||||
|
) as stream_result:
|
||||||
|
async for delta in stream_result.stream_text(delta=True):
|
||||||
|
await token_queue.put(delta)
|
||||||
|
except Exception as e:
|
||||||
|
self.logger.exception(f"RegulatoryNode.stream_working failed: {e}")
|
||||||
|
await token_queue.put(f"\n\n[错误: {str(e)}]")
|
||||||
|
finally:
|
||||||
|
await token_queue.put(None)
|
||||||
|
|
||||||
async def _run(
|
async def _run(
|
||||||
self, payload: MessageRequest
|
self, payload: MessageRequest
|
||||||
) -> Union[MessageResponse, None]:
|
) -> Union[MessageResponse, None]:
|
||||||
|
|||||||
Reference in New Issue
Block a user