Pretor/pretor/api/platform/frontend.py

45 lines
1.9 KiB
Python

from fastapi import APIRouter, Request, Depends, HTTPException, status, WebSocket, WebSocketDisconnect
from pydantic import BaseModel
from pretor.utils.access import Accessor, TokenData
from pretor.api.platform.event import PretorEvent
from loguru import logger
# Router for the client adapter; every route below is mounted under this prefix
# and grouped under the "client" tag.
client_router = APIRouter(
    tags=["client"],
    prefix="/api/v1/adapter/client",
)
class Message(BaseModel):
    """Request body accepted by the client message POST route."""

    # Text the client is submitting.
    message: str
@client_router.post("")
async def create_message(message: Message,
                         request: Request,
                         token_date: TokenData = Depends(Accessor.get_current_user)):
    """Accept a message from an authenticated client and pass it to the supervisory node.

    The text is wrapped in a ``PretorEvent`` tagged ``platform="client"``, carrying the
    caller's id and name from the resolved token. It is then submitted to the supervisory
    node's ``working`` method, and the node's reply decides the response.

    Args:
        message: Request body; only its ``message`` text is used.
        request: Used to reach ``app.state.supervisory_node`` and
            ``app.state.global_state_machine``.
        token_date: Authenticated caller, resolved by ``Accessor.get_current_user``.
            The name is kept unchanged (it presumably means ``token_data``) so that
            existing keyword callers keep working.

    Returns:
        ``{"message": event.event_id}`` when the node replies "任务已创建" ("task
        created"). In that case the event is first registered with the global state
        machine, and the id is what this module's ``/ws/{event_id}`` route expects.
        Otherwise returns ``{"message": <reply>}`` carrying the node's reply verbatim.

    Raises:
        HTTPException: 500 when the node replies "未知相应类型" ("unknown response
            type").
    """
    logger.info(f"收到消息,来源:客户端,消息内容:{message.message}")
    event = PretorEvent(platform="client",
                        user_id=str(token_date.user_id),
                        user_name=token_date.user_name,
                        message=message.message)
    supervisory_node = request.app.state.supervisory_node
    # Keep the reply in its own name instead of rebinding the ``message`` parameter.
    reply = await supervisory_node.working.remote(event)
    if reply == "任务已创建":
        global_state_machine = request.app.state.global_state_machine
        # Await the registration for two reasons: the event should exist in the state
        # machine before its id reaches the client (who may open the websocket right
        # away), and a failure should surface here rather than vanish with an
        # un-awaited ObjectRef.
        await global_state_machine.add.remote(event)
        return {"message": event.event_id}
    if reply == "未知相应类型":
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="模型回复错误")
    return {"message": reply}
@client_router.websocket("/ws/{event_id}")
async def websocket_endpoint(websocket: WebSocket, event_id: str):
    """Relay pending text for one event to the client and store the client's replies.

    After accepting the connection, loop until the client disconnects:
    1. Fetch the next pending text for ``event_id`` from the global state machine.
    2. Send it to the client.
    3. Wait for the client's reply.
    4. Hand the reply back via ``put_received``.

    Args:
        websocket: The client connection; it is accepted immediately.
        event_id: Event identifier, as returned by the client POST route.
    """
    await websocket.accept()
    global_state_machine = websocket.app.state.global_state_machine
    # NOTE(review): unlike the POST route, this endpoint does no authentication.
    # Anyone who knows an event_id can read and answer its pending text. Confirm
    # that this is intended.
    try:
        while True:
            # This is the same handle the POST route uses via ``add.remote(...)``. It
            # appears to be a Ray actor handle, so its methods must be invoked through
            # ``.remote(...)`` and the returned ObjectRef awaited; actor methods cannot
            # be called directly.
            pending = await global_state_machine.get_pending.remote(event_id)
            await websocket.send_text(pending)
            response = await websocket.receive_text()
            await global_state_machine.put_received.remote(event_id, response)
    except WebSocketDisconnect:
        # The client closing the socket is the normal way this loop ends. Record it
        # instead of swallowing it silently.
        logger.info(f"客户端断开连接,event_id:{event_id}")