"""
WebSocket endpoint for streaming agent responses.
Protocol (client -> server):
{"type": "message", "content": "..."}
Protocol (server -> client):
{"type": "stream_start"}
{"type": "thinking_delta", "delta": "..."}
{"type": "thinking_end"}
{"type": "stream_delta", "delta": "..."}
{"type": "tool_started", "tool": "...", "args": {...}, "is_subagent": bool}
{"type": "tool_call", "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
{"type": "stream_end", "content": "..."}
{"type": "context_compressed"}
{"type": "error", "message": "..."}
"""
import asyncio
import dataclasses
import json
import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from navi.api.deps import get_session_store
from navi.core import Agent, ContextCompressed, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent
from navi.core.events import PlanningStatus, PlanReady, ProfileSwitched, StreamStopped, ToolStarted, TurnThinking
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound
# Router carrying the HTTP stop endpoint and the WebSocket endpoint defined
# in this module; both are tagged "websocket".
router = APIRouter(tags=["websocket"])
# Module-wide structlog logger.
log = structlog.get_logger()
# Maximum number of serialised events kept in an _AgentRun replay buffer.
# When the buffer exceeds this size the oldest event is dropped, which caps
# both memory use and the cost of replaying to a reconnecting client.
_MAX_REPLAY_EVENTS = 500
# ── Per-session run state ──────────────────────────────────────────────────────
@dataclasses.dataclass
class _AgentRun:
    """State of one in-flight agent turn for a single session.

    Bundles the asyncio task executing the agent, the event used to request
    a cooperative stop, the queues of every connected listener, and a
    bounded buffer of already-serialised events for late joiners.
    """

    task: asyncio.Task | None = None
    stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event)
    subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)
    # Replay buffer: serialised event dicts emitted so far this turn (capped at
    # _MAX_REPLAY_EVENTS). Lets a client that reconnects mid-stream rebuild its UI.
    events: list[dict] = dataclasses.field(default_factory=list)

    def subscribe(self) -> asyncio.Queue:
        """Create a fresh queue, register it as a listener and return it."""
        listener: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(listener)
        return listener

    def unsubscribe(self, q: asyncio.Queue) -> None:
        """Detach *q* from the listeners; a queue that is not registered is ignored."""
        if q in self.subscribers:
            self.subscribers.remove(q)

    async def broadcast(self, item) -> None:
        """Deliver *item* (a ``(kind, payload)`` pair) to every listener.

        Items of kind ``"event"`` are additionally serialised and appended to
        the replay buffer, provided they serialise to a non-empty dict.
        """
        kind, payload = item
        if kind == "event":
            serialised = _event_to_dict(payload)
            if serialised:
                buffer = self.events
                buffer.append(serialised)
                # Drop the oldest entry once the cap is exceeded.
                if len(buffer) > _MAX_REPLAY_EVENTS:
                    del buffer[0]
        # Iterate over a snapshot: listeners may (un)subscribe while we await.
        for listener in tuple(self.subscribers):
            await listener.put(item)
# session_id → active run. websocket_session() registers an entry right before
# it creates the agent task, and _run_agent() removes it in its ``finally``
# clause, so a key is present only while an agent is executing for that session.
_runs: dict[str, _AgentRun] = {}
# ── Helpers ───────────────────────────────────────────────────────────────────
def _event_to_dict(event) -> dict | None:
if isinstance(event, ThinkingDelta):
return {"type": "thinking_delta", "delta": event.delta}
if isinstance(event, ThinkingEnd):
return {"type": "thinking_end"}
if isinstance(event, TextDelta):
return {"type": "stream_delta", "delta": event.delta}
if isinstance(event, ToolStarted):
return {
"type": "tool_started",
"tool": event.tool_name,
"args": event.arguments,
"is_subagent": event.is_subagent,
}
if isinstance(event, ToolEvent):
return {
"type": "tool_call",
"tool": event.tool_name,
"args": event.arguments,
"result": event.result,
"success": event.success,
"is_subagent": event.is_subagent,
"metadata": event.metadata,
}
if isinstance(event, StreamEnd):
return {
"type": "stream_end",
"content": event.full_content,
"context_tokens": event.context_tokens,
"max_context_tokens": event.max_context_tokens,
"elapsed_seconds": event.elapsed_seconds,
"tool_call_count": event.tool_call_count,
"token_count": event.token_count,
}
if isinstance(event, ContextCompressed):
return {
"type": "context_compressed",
"messages_before": event.messages_before,
"messages_after": event.messages_after,
"summary": event.summary,
}
if isinstance(event, TurnThinking):
return {"type": "turn_thinking", "thinking": event.thinking, "is_subagent": event.is_subagent}
if isinstance(event, ProfileSwitched):
return {"type": "profile_switched", "profile_id": event.profile_id, "profile_name": event.profile_name}
if isinstance(event, StreamStopped):
return {"type": "stream_stopped"}
if isinstance(event, PlanningStatus):
return {"type": "planning_status", "phase": event.phase, "label": event.label, "is_subagent": event.is_subagent}
if isinstance(event, PlanReady):
return {"type": "plan_ready", "plan": event.plan, "is_subagent": event.is_subagent}
return None
async def _run_agent(
    run: _AgentRun,
    agent: Agent,
    session_id: str,
    user_content: str,
    raw_images: list[str] | None,
    display_content: str | None = None,
) -> None:
    """
    Drive one agent turn and fan its events out to every subscriber of *run*.

    Each event yielded by ``agent.run_stream`` is broadcast as ``("event", ev)``.
    Failures are reported as ``("error", message)``, a cancellation as
    ``("stopped", None)`` (and then re-raised). Whatever happens, a final
    ``("done", None)`` is broadcast and the session's entry is removed from
    ``_runs``.

    NOTE(review): the previous docstring stated that ``run_stream`` saves the
    session before emitting StreamEnd, even on disconnect — that guarantee
    lives in the agent and cannot be confirmed from this module.
    """
    from navi.tools.base import current_stop_event

    # Expose this run's stop flag through current_stop_event.
    current_stop_event.set(run.stop_event)
    emit = run.broadcast
    try:
        async for agent_event in agent.run_stream(
            session_id,
            user_content,
            images=raw_images,
            display_message=display_content,
        ):
            await emit(("event", agent_event))
    except asyncio.CancelledError:
        log.info("ws.agent_stopped", session_id=session_id)
        await emit(("stopped", None))
        # Propagate so the task ends up in the cancelled state.
        raise
    except SessionNotFound:
        await emit(("error", "Session not found"))
    except (MaxIterationsReached, NaviError) as exc:
        await emit(("error", str(exc)))
    except Exception as exc:
        log.exception("ws.agent_error", session_id=session_id)
        await emit(("error", f"Internal error: {exc}"))
    finally:
        # Always release the listeners and free the session slot.
        await emit(("done", None))
        _runs.pop(session_id, None)
# Seconds _stream_to_client() waits for the next queue item before it sends a
# {"type": "heartbeat"} frame — keeps the browser from dropping idle connections.
_HEARTBEAT_INTERVAL = 20.0
async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
    """
    Pump items from *queue* to *websocket* until the ``("done", None)`` marker.

    Returns True when every send succeeded, False once any send has failed
    (the client is then treated as gone). After a failed send the queue is
    still consumed — silently — until the done marker arrives. Whenever no
    item shows up within ``_HEARTBEAT_INTERVAL`` seconds a heartbeat frame is
    sent so the connection is not idle during long silent operations.
    """
    connected = True
    while True:
        try:
            item = await asyncio.wait_for(queue.get(), timeout=_HEARTBEAT_INTERVAL)
        except asyncio.TimeoutError:
            # Quiet period: ping the client, but only if it is still there.
            if connected:
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    connected = False
            continue

        kind, payload = item
        if kind == "done":
            return connected
        if not connected:
            # Nobody to talk to any more; just keep emptying the queue.
            continue

        try:
            # Work out the frame for this item, then send it (if any).
            if kind == "error":
                frame = {"type": "error", "message": payload}
            elif kind == "stopped":
                frame = {"type": "stream_stopped"}
            elif kind == "event":
                frame = _event_to_dict(payload)
            else:
                frame = None
            if frame:
                await websocket.send_json(frame)
        except Exception:
            connected = False
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/sessions/{session_id}/stop")
async def stop_session(session_id: str) -> dict:
    """Ask the agent currently running for *session_id* to stop.

    Sets the run's stop event and returns ``{"ok": True}``; when no run is
    registered for the session, returns ``{"ok": False, "reason": ...}``.
    """
    active = _runs.get(session_id)
    if active is None:
        return {"ok": False, "reason": "no active run"}
    active.stop_event.set()
    return {"ok": True}
# Limits applied to the optional "images" field of a client message.
_MAX_IMAGES = 10
# The size check is performed on the base64 text as received (after any
# "data:...," prefix is stripped), i.e. the cap is 5 MiB of *encoded* payload.
_MAX_IMAGE_BYTES = 5 * 1024 * 1024


def _clean_images(images: object) -> tuple[list[str] | None, str | None]:
    """Validate the ``images`` field and strip data-URL prefixes.

    Returns ``(cleaned, error)``. ``error`` is a message for the client when
    the payload is rejected; otherwise it is ``None`` and ``cleaned`` is the
    list of bare base64 strings (``None`` when no images were supplied).
    """
    if not images:
        return None, None
    if not isinstance(images, list):
        return None, "'images' must be a list of base64 strings."
    if len(images) > _MAX_IMAGES:
        return None, f"Too many images ({len(images)}). Max {_MAX_IMAGES} allowed."
    cleaned: list[str] = []
    for img in images:
        if not isinstance(img, str):
            return None, "'images' must be a list of base64 strings."
        # Accept "data:<mime>;base64,<payload>" and keep only the payload.
        if img.startswith("data:") and "," in img:
            img = img.split(",", 1)[1]
        if len(img.encode("utf-8")) > _MAX_IMAGE_BYTES:
            return None, "Image exceeds 5 MB limit."
        cleaned.append(img)
    return cleaned, None


def _format_uploaded_files(files: object) -> tuple[str | None, str | None]:
    """Render the ``files`` field as a bullet list of ``name → path`` lines.

    Returns ``(text, error)``. ``text`` is ``None`` when no files were given;
    ``error`` is a client-facing message when an entry is malformed.
    """
    if not files:
        return None, None
    if not isinstance(files, list):
        return None, "'files' must be a list of {name, path} objects."
    lines: list[str] = []
    for entry in files:
        if not isinstance(entry, dict) or "name" not in entry or "path" not in entry:
            return None, "'files' must be a list of {name, path} objects."
        lines.append(f"- {entry['name']} → {entry['path']}")
    return "\n".join(lines), None


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Serve one WebSocket connection for *session_id*.

    Flow:
      1. Close with code 4004 if the session does not exist; otherwise accept.
      2. If an agent run is already in progress for the session, re-attach:
         replay the buffered events, stream the rest, then send
         ``session_sync``. If no run is active, send ``session_sync`` at once.
      3. Loop: read a client message, validate it, start an agent task and
         stream its events until it is done.

    Malformed client messages are answered with an ``error`` frame and do not
    terminate the connection. The subscriber queue is always detached from
    its run on exit.
    """
    session_store = get_session_store()
    session = await session_store.get(session_id)
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return
    await websocket.accept()
    log.info("ws.connected", session_id=session_id)

    from navi.api.deps import get_memory_store, get_registries, get_workers

    tools, profiles, backends, cp_registry = get_registries()
    agent = Agent(
        session_store, profiles, tools, backends,
        workers=get_workers(), memory_store=get_memory_store(), cp_registry=cp_registry,
    )
    queue: asyncio.Queue | None = None
    current_run: _AgentRun | None = None
    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = _runs.get(session_id)
        if existing is not None:
            current_run = existing
            # Subscribe and snapshot the replay buffer back-to-back, with no
            # await in between: everything broadcast from here on lands in
            # our queue, everything before is in the snapshot. The snapshot
            # must be a copy taken NOW — the live buffer evicts from the front
            # once full, so slicing it after an await could pick up events
            # that are also queued and deliver them twice.
            queue = existing.subscribe()
            replay = list(existing.events)
            log.info("ws.reattached", session_id=session_id, replay_count=len(replay))
            await websocket.send_json({"type": "stream_start"})
            if replay:
                await websocket.send_json({"type": "replay_start", "count": len(replay)})
                try:
                    for ev_dict in replay:
                        await websocket.send_json(ev_dict)
                except Exception:
                    # Client vanished during replay; ``finally`` unsubscribes.
                    return
                await websocket.send_json({"type": "replay_end"})
            connected = await _stream_to_client(websocket, queue)
            existing.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                return  # client disconnected again — stop here
            # Stream finished — tell the client to sync session history so it sees
            # the full saved response (handles any events missed during disconnect).
            await websocket.send_json({"type": "session_sync"})
        else:
            # No active run — if this is a reconnect after the agent already finished,
            # the client needs to reload session history to see the saved response.
            await websocket.send_json({"type": "session_sync"})

        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue

            # json.loads may yield any JSON value; only an object carrying a
            # non-empty string "content" is a valid message.
            content = data.get("content") if isinstance(data, dict) else None
            if not isinstance(content, str) or not content or data.get("type") != "message":
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected {type: 'message', content: '...'}",
                })
                continue
            original_content = content
            user_content = original_content

            # Guard against abuse: limit image count and per-image payload size.
            raw_images, image_error = _clean_images(data.get("images"))
            if image_error is not None:
                await websocket.send_json({"type": "error", "message": image_error})
                continue
            # Tell the model the inline images are already in its multimodal context,
            # so it doesn't hallucinate a path/URL and call image_view to "load" them.
            if raw_images:
                n = len(raw_images)
                noun = "image" if n == 1 else "images"
                user_content = (
                    user_content
                    + f"\n\n[{n} {noun} attached inline — already in your context, no extra loading needed.]"
                )

            # Append uploaded file paths to user content so Navi knows about them.
            file_lines, files_error = _format_uploaded_files(data.get("files"))
            if files_error is not None:
                await websocket.send_json({"type": "error", "message": files_error})
                continue
            if file_lines is not None:
                user_content = (
                    user_content + f"\n\n[Uploaded files on disk:\n{file_lines}]"
                )

            # Guard against concurrent runs for the same session.
            if session_id in _runs:
                await websocket.send_json({
                    "type": "error",
                    "message": "Agent is already running for this session.",
                })
                continue

            # Register run and subscribe before starting the task so we never
            # miss events even if the task is very fast.
            run = _AgentRun()
            queue = run.subscribe()
            current_run = run
            _runs[session_id] = run
            run.task = asyncio.create_task(
                _run_agent(run, agent, session_id, user_content, raw_images, original_content)
            )
            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            run.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                break  # avoid calling receive_text() on a dead socket
    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Ensure queue is removed from subscribers on any abrupt exit.
        if queue is not None and current_run is not None:
            current_run.unsubscribe(queue)