"""
WebSocket endpoint for streaming agent responses.
Protocol (client -> server):
{"type": "message", "content": "...", "images": ["<base64>", ...]} # "images" is optional; a leading data-URI prefix is stripped
Protocol (server -> client):
{"type": "stream_start"}
{"type": "thinking_delta", "delta": "..."} # reasoning chunk
{"type": "thinking_end"} # reasoning done
{"type": "stream_delta", "delta": "..."} # text chunk
{"type": "tool_call", "tool": "...", "args": {...}, "result": "...", "success": bool}
{"type": "context_compressed", "messages_before": ..., "messages_after": ...} # message counts from a ContextCompressed event
{"type": "stream_end", "content": "...", "context_tokens": ..., "max_context_tokens": ...} # full assembled response
{"type": "error", "message": "..."}
"""
import json
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from navi.api.deps import get_agent, get_session_store
from navi.core import Agent, ContextCompressed, InMemorySessionStore, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound
# Router that the WebSocket endpoint below registers itself on; tagged "websocket".
router = APIRouter(tags=["websocket"])
# Module-level structlog logger used for connection lifecycle and agent-error events.
log = structlog.get_logger()
def _strip_data_uri_prefixes(images: object) -> "list[str] | None":
    """Return *images* with any leading ``data:...,`` prefix removed.

    A falsy value (missing key, ``None``, empty list) yields ``None``.

    Raises:
        ValueError: if *images* is truthy but is not a list of strings.
    """
    if not images:
        return None
    if not isinstance(images, list) or not all(isinstance(item, str) for item in images):
        raise ValueError("'images' must be a list of base64 strings")
    return [
        item.split(",", 1)[1] if item.startswith("data:") and "," in item else item
        for item in images
    ]


def _event_to_payload(event: object) -> "dict | None":
    """Map an agent stream event to its wire message, or ``None`` if unrecognised."""
    if isinstance(event, ThinkingDelta):
        return {"type": "thinking_delta", "delta": event.delta}
    if isinstance(event, ThinkingEnd):
        return {"type": "thinking_end"}
    if isinstance(event, TextDelta):
        return {"type": "stream_delta", "delta": event.delta}
    if isinstance(event, ToolEvent):
        return {
            "type": "tool_call",
            "tool": event.tool_name,
            "args": event.arguments,
            "result": event.result,
            "success": event.success,
        }
    if isinstance(event, StreamEnd):
        return {
            "type": "stream_end",
            "content": event.full_content,
            "context_tokens": event.context_tokens,
            "max_context_tokens": event.max_context_tokens,
        }
    if isinstance(event, ContextCompressed):
        return {
            "type": "context_compressed",
            "messages_before": event.messages_before,
            "messages_after": event.messages_after,
        }
    return None


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Stream agent responses for one session over a WebSocket.

    The session is looked up first; an unknown ``session_id`` closes the socket
    with code 4004 without accepting it. Otherwise the handler loops: it reads
    one JSON text frame, validates it, runs ``agent.run_stream`` and forwards
    each recognised event to the client. Malformed frames and agent errors are
    reported as ``{"type": "error", ...}`` messages and the loop continues.
    The loop ends when the client disconnects.

    Args:
        session_id: Session identifier taken from the URL path.
        websocket: The connection to serve.
    """
    session_store = get_session_store()

    # Validate that the session exists before accepting the connection.
    # NOTE(review): the ASGI spec maps a close sent before accept to an HTTP 403
    # handshake rejection, so the client may not observe code 4004 -- confirm
    # this is the intended behaviour.
    session = await session_store.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    await websocket.accept()
    log.info("ws.connected", session_id=session_id)

    # Build the agent by hand (can't use FastAPI Depends inside WebSocket directly).
    from navi.api.deps import get_memory_store, get_registries, get_workers

    tools, profiles, backends = get_registries()
    agent = Agent(session_store, profiles, tools, backends,
                  workers=get_workers(), memory_store=get_memory_store())

    try:
        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue

            # Valid JSON need not be an object: guard before calling .get(),
            # otherwise a frame such as `[]` or `"hi"` raises AttributeError.
            if (
                not isinstance(data, dict)
                or data.get("type") != "message"
                or not data.get("content")
            ):
                await websocket.send_json({"type": "error", "message": "Expected {type: 'message', content: '...'}"})
                continue

            user_content = data["content"]

            # images: optional list of base64 strings; a data-URI prefix is
            # stripped here in case the client did not remove it.
            try:
                images = _strip_data_uri_prefixes(data.get("images"))
            except ValueError as e:
                await websocket.send_json({"type": "error", "message": str(e)})
                continue

            await websocket.send_json({"type": "stream_start"})
            try:
                async for event in agent.run_stream(session_id, user_content, images=images):
                    payload = _event_to_payload(event)
                    if payload is not None:
                        await websocket.send_json(payload)
            except WebSocketDisconnect:
                # A disconnect mid-stream is not an agent failure: let the outer
                # handler log it instead of trying to send on a closed socket.
                raise
            except SessionNotFound:
                await websocket.send_json({"type": "error", "message": "Session not found"})
            except (MaxIterationsReached, NaviError) as e:
                await websocket.send_json({"type": "error", "message": str(e)})
            except Exception as e:
                log.exception("ws.agent_error", session_id=session_id)
                # NOTE(review): this echoes the exception text to the client;
                # consider a generic message if that may expose internals.
                await websocket.send_json({"type": "error", "message": f"Internal error: {e}"})
    except WebSocketDisconnect:
        log.info("ws.disconnected", session_id=session_id)