"""
WebSocket endpoint for streaming agent responses.
Protocol (client -> server):
{"type": "message", "content": "..."}
Protocol (server -> client):
{"type": "stream_start"}
{"type": "stream_delta", "delta": "..."} # text chunk
{"type": "tool_call", "tool": "...", "args": {...}, "result": "...", "success": bool}
{"type": "stream_end", "content": "..."} # full assembled response
{"type": "error", "message": "..."}
"""
import json
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from navi.api.deps import get_agent, get_session_store
from navi.core import Agent, InMemorySessionStore, StreamEnd, TextDelta, ToolEvent
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound
# Module-level structured logger shared by the WebSocket handlers below.
log = structlog.get_logger()

# Router holding this module's WebSocket routes, grouped under the "websocket" tag.
router = APIRouter(tags=["websocket"])
def _parse_client_message(raw: str) -> str:
    """Validate one raw client frame and return its message content.

    A valid frame is a JSON object of the form
    ``{"type": "message", "content": "<non-empty string>"}``.

    Raises:
        ValueError: carrying a client-facing description of the problem when
            the frame is not valid JSON or does not have the expected shape.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON") from None
    # json.loads may yield a list, string, number, etc.; only a dict has
    # .get(), and only string content is meaningful to the agent.
    if (
        not isinstance(data, dict)
        or data.get("type") != "message"
        or not isinstance(data.get("content"), str)
        or not data["content"]
    ):
        raise ValueError("Expected {type: 'message', content: '...'}")
    return data["content"]


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Stream agent responses for one session over a WebSocket.

    If the session does not exist, the socket is closed with code 4004
    ("Session not found") and no agent is built. Otherwise each client
    ``message`` frame is run through ``Agent.run_stream`` and the resulting
    events are relayed using the protocol in the module docstring. Malformed
    frames and agent errors are reported as ``error`` frames, and the
    connection stays open for the next message. A client disconnect ends the
    handler.
    """
    session_store = get_session_store()

    # Validate the session before building the agent.
    if await session_store.get(session_id) is None:
        # Under ASGI, a close sent before accept() becomes an HTTP 403 and the
        # close code/reason are dropped; accept first so 4004 is delivered.
        await websocket.accept()
        await websocket.close(code=4004, reason="Session not found")
        return

    await websocket.accept()
    log.info("ws.connected", session_id=session_id)

    # NOTE(review): FastAPI does support Depends() in WebSocket routes, so this
    # manual construction (and the currently unused ``get_agent`` import) may
    # be replaceable with a dependency — confirm against navi.api.deps.
    from navi.api.deps import _registries

    tools, profiles, backends = _registries()
    agent = Agent(session_store, profiles, tools, backends)

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                user_content = _parse_client_message(raw)
            except ValueError as exc:
                await websocket.send_json({"type": "error", "message": str(exc)})
                continue

            await websocket.send_json({"type": "stream_start"})
            try:
                async for event in agent.run_stream(session_id, user_content):
                    if isinstance(event, TextDelta):
                        await websocket.send_json({"type": "stream_delta", "delta": event.delta})
                    elif isinstance(event, ToolEvent):
                        await websocket.send_json({
                            "type": "tool_call",
                            "tool": event.tool_name,
                            "args": event.arguments,
                            "result": event.result,
                            "success": event.success,
                        })
                    elif isinstance(event, StreamEnd):
                        await websocket.send_json({"type": "stream_end", "content": event.full_content})
            except WebSocketDisconnect:
                # The client went away mid-stream (e.g. send_json raised). Let
                # the outer handler log a normal disconnect instead of treating
                # it as an agent failure and sending on a closed socket.
                raise
            except SessionNotFound:
                await websocket.send_json({"type": "error", "message": "Session not found"})
            except (MaxIterationsReached, NaviError) as e:
                await websocket.send_json({"type": "error", "message": str(e)})
            except Exception:
                # Full details go to the server log only; raw exception text
                # can expose internals to an untrusted client.
                log.exception("ws.agent_error", session_id=session_id)
                await websocket.send_json({"type": "error", "message": "Internal error"})
    except WebSocketDisconnect:
        log.info("ws.disconnected", session_id=session_id)