"""
WebSocket endpoint for streaming agent responses.
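Protocol (client -> server):
{"type": "message", "content": "...", "images": ["...", ...], "files": [{"name": "...", "path": "..."}, ...]}
("images" and "files" are optional; a leading "data:...," prefix on an image is stripped.)
Protocol (server -> client):
{"type": "stream_start"}
{"type": "thinking_delta", "delta": "..."}
{"type": "thinking_end"}
{"type": "turn_thinking", "thinking": "...", "is_subagent": bool}
{"type": "stream_delta", "delta": "..."}
{"type": "tool_started", "tool": "...", "args": {...}, "is_subagent": bool}
{"type": "tool_call", "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
{"type": "stream_end", "content": "...", "context_tokens": ..., "max_context_tokens": ...}
{"type": "context_compressed", "messages_before": ..., "messages_after": ...}
{"type": "error", "message": "..."}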
"""
import asyncio
import dataclasses
import json
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from navi.api.deps import get_session_store
from navi.core import Agent, ContextCompressed, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent
from navi.core.events import ToolStarted, TurnThinking
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound
# Router that carries the WebSocket route declared in this module.
router = APIRouter(tags=["websocket"])
# Module-level structlog logger used for connection and agent-error events.
log = structlog.get_logger()
# ── Per-session run state ──────────────────────────────────────────────────────
@dataclasses.dataclass
class _AgentRun:
"""Holds the running agent task and all active subscriber queues."""
task: asyncio.Task | None = None
subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)
def subscribe(self) -> asyncio.Queue:
q: asyncio.Queue = asyncio.Queue()
self.subscribers.append(q)
return q
def unsubscribe(self, q: asyncio.Queue) -> None:
try:
self.subscribers.remove(q)
except ValueError:
pass
async def broadcast(self, item) -> None:
for q in list(self.subscribers):
await q.put(item)
# Registry of in-flight agent runs, keyed by session_id. The WebSocket endpoint
# adds an entry when it starts a run and _run_agent removes it when the run
# finishes, so a key is present only while the agent is executing.
_runs: dict[str, _AgentRun] = {}
# ── Helpers ───────────────────────────────────────────────────────────────────
def _event_to_dict(event) -> dict | None:
if isinstance(event, ThinkingDelta):
return {"type": "thinking_delta", "delta": event.delta}
if isinstance(event, ThinkingEnd):
return {"type": "thinking_end"}
if isinstance(event, TextDelta):
return {"type": "stream_delta", "delta": event.delta}
if isinstance(event, ToolStarted):
return {
"type": "tool_started",
"tool": event.tool_name,
"args": event.arguments,
"is_subagent": event.is_subagent,
}
if isinstance(event, ToolEvent):
return {
"type": "tool_call",
"tool": event.tool_name,
"args": event.arguments,
"result": event.result,
"success": event.success,
"is_subagent": event.is_subagent,
}
if isinstance(event, StreamEnd):
return {
"type": "stream_end",
"content": event.full_content,
"context_tokens": event.context_tokens,
"max_context_tokens": event.max_context_tokens,
}
if isinstance(event, ContextCompressed):
return {
"type": "context_compressed",
"messages_before": event.messages_before,
"messages_after": event.messages_after,
}
if isinstance(event, TurnThinking):
return {"type": "turn_thinking", "thinking": event.thinking, "is_subagent": event.is_subagent}
return None
async def _run_agent(
    run: _AgentRun,
    agent: Agent,
    session_id: str,
    user_content: str,
    raw_images: list[str] | None,
) -> None:
    """Execute the agent to completion, broadcasting events to all subscribers.

    Items placed on the subscriber queues are 2-tuples:
      ("event", <agent event>)  for every event yielded by ``agent.run_stream``
      ("error", <message str>)  when the run fails
      ("done", None)            always last, whatever the outcome

    Args:
        run: Run state whose subscribers receive the items.
        agent: Agent whose ``run_stream`` is iterated.
        session_id: Session the run belongs to; also the key in ``_runs``.
        user_content: Text of the user's message.
        raw_images: Optional images forwarded to ``run_stream`` as ``images``.

    NOTE(review): the original docstring states that run_stream saves the
    session before StreamEnd, even on disconnect. That cannot be confirmed
    from this module; verify against Agent.run_stream.
    """
    try:
        async for event in agent.run_stream(session_id, user_content, images=raw_images):
            await run.broadcast(("event", event))
    except SessionNotFound:
        await run.broadcast(("error", "Session not found"))
    except (MaxIterationsReached, NaviError) as e:
        # Both are reported to the client verbatim.
        await run.broadcast(("error", str(e)))
    except Exception as e:
        log.exception("ws.agent_error", session_id=session_id)
        await run.broadcast(("error", f"Internal error: {e}"))
    finally:
        # Deregister before announcing completion, so a client that connects
        # afterwards cannot subscribe to a run that will emit nothing more.
        # Only remove the entry if it is still ours: a newer run for the same
        # session may have replaced it, and must stay reachable for re-attach.
        if _runs.get(session_id) is run:
            del _runs[session_id]
        await run.broadcast(("done", None))
async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
    """Relay queued run items to ``websocket`` until the ("done", None) marker.

    Returns:
        True if every send succeeded, False if a send failed mid-stream.

    After the first failed send the client is treated as gone. The queue is
    still consumed up to the "done" marker, but nothing more is sent.
    """
    still_connected = True
    while True:
        tag, body = await queue.get()
        if tag == "done":
            break
        if not still_connected:
            # Client is gone: discard the item and keep consuming.
            continue
        try:
            if tag == "error":
                outgoing = {"type": "error", "message": body}
            elif tag == "event":
                outgoing = _event_to_dict(body)
            else:
                outgoing = None
            if outgoing:
                await websocket.send_json(outgoing)
        except Exception:
            # Any failure while building or sending counts as a disconnect.
            still_connected = False
    return still_connected
# ── Endpoint ──────────────────────────────────────────────────────────────────
def _parse_client_message(raw: str) -> tuple[str, list[str] | None]:
    """Validate one client frame and return ``(user_content, images)``.

    The frame comes from the client and is untrusted, so every field is
    type-checked before use.

    Returns:
        The message text (with an uploaded-files note appended when "files"
        is present) and the image list with any "data:...," prefix stripped,
        or ``None`` when no images were sent.

    Raises:
        ValueError: the frame is malformed; the exception text is the error
            message to send back to the client.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        raise ValueError("Invalid JSON") from None

    # json.loads may return any JSON value; only an object has fields to read.
    content = data.get("content") if isinstance(data, dict) else None
    if (
        not isinstance(data, dict)
        or data.get("type") != "message"
        or not isinstance(content, str)
        or not content
    ):
        raise ValueError("Expected {type: 'message', content: '...'}")

    images = data.get("images") or None
    if images is not None:
        if not isinstance(images, list) or not all(isinstance(i, str) for i in images):
            raise ValueError("'images' must be a list of strings")
        # Keep only the part after the comma of a "data:<mime>;base64,<payload>" URL.
        images = [
            img.split(",", 1)[1] if img.startswith("data:") and "," in img else img
            for img in images
        ]

    # Append uploaded file paths to the user content so Navi knows about them.
    files = data.get("files") or []
    if files:
        if not isinstance(files, list) or not all(
            isinstance(f, dict) and "name" in f and "path" in f for f in files
        ):
            raise ValueError("'files' must be a list of {name, path} objects")
        file_lines = "\n".join(f"- {f['name']} → {f['path']}" for f in files)
        content = content + f"\n\n[Uploaded files on disk:\n{file_lines}]"

    return content, images


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Serve one WebSocket connection for ``session_id``.

    Flow:
      1. Close with code 4004 if the session store has no such session.
      2. If a run for this session is already in progress, attach to it and
         relay its remaining events (e.g. the client reloaded mid-stream).
      3. Loop: read a message frame, start an agent run as a background task,
         and relay its events until the run reports completion.

    The agent task is not tied to the socket: if the client disconnects, the
    run keeps going and stays registered in ``_runs`` until it finishes.
    Malformed frames are answered with an ``error`` message and the
    connection stays open.
    """
    session_store = get_session_store()
    session = await session_store.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return
    await websocket.accept()
    log.info("ws.connected", session_id=session_id)

    from navi.api.deps import get_memory_store, get_registries, get_workers

    tools, profiles, backends = get_registries()
    agent = Agent(
        session_store, profiles, tools, backends,
        workers=get_workers(), memory_store=get_memory_store(),
    )

    # Tracked so the ``finally`` clause can detach the queue on abrupt exit.
    queue: asyncio.Queue | None = None
    current_run: _AgentRun | None = None
    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = _runs.get(session_id)
        if existing is not None:
            current_run = existing
            queue = existing.subscribe()
            log.info("ws.reattached", session_id=session_id)
            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            existing.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                return  # client disconnected again — stop here

        while True:
            raw = await websocket.receive_text()
            try:
                user_content, raw_images = _parse_client_message(raw)
            except ValueError as exc:
                await websocket.send_json({"type": "error", "message": str(exc)})
                continue

            # NOTE(review): if another connection already has a run in flight
            # for this session, this starts a second one and replaces its
            # registry entry. Confirm whether concurrent runs are intended.

            # Register run and subscribe before starting the task so we never
            # miss events even if the task is very fast.
            run = _AgentRun()
            queue = run.subscribe()
            current_run = run
            _runs[session_id] = run
            run.task = asyncio.create_task(
                _run_agent(run, agent, session_id, user_content, raw_images)
            )
            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            run.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                break  # avoid calling receive_text() on a dead socket
    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Ensure queue is removed from subscribers on any abrupt exit.
        if queue is not None and current_run is not None:
            current_run.unsubscribe(queue)