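"""WebSocket endpoint for streaming agent responses.

Protocol (client -> server):
    {"type": "message", "content": "...", "images": ["..."], "files": [{"name": "...", "path": "..."}]}
    ("images" and "files" are optional; a "data:...," prefix on an image is stripped.)

Protocol (server -> client):
    {"type": "stream_start"}
    {"type": "thinking_delta", "delta": "..."}
    {"type": "thinking_end"}
    {"type": "stream_delta", "delta": "..."}
    {"type": "tool_started", "tool": "...", "args": {...}, "is_subagent": bool}
    {"type": "tool_call", "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
    {"type": "stream_end", "content": "..."}
    {"type": "stream_stopped"}
    {"type": "context_compressed"}
    {"type": "error", "message": "..."}
    {"type": "heartbeat"}                  -- keep-alive sent while no events arrive
    {"type": "replay_start", "count": N}   -- on re-attach, N buffered events follow
    {"type": "replay_end"}
    {"type": "session_sync"}               -- client should reload session history
"""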
import asyncio
import json
from typing import Annotated

import structlog
from fastapi import APIRouter, Depends, HTTPException, WebSocket, WebSocketDisconnect

from navi.api.deps import get_orchestrator, get_session_store
from navi.auth import User
from navi.auth.deps import check_session_access
from navi.auth.deps import get_current_user, get_current_user_ws
from navi.core import SessionStore
from navi.core.event_bus import get_event_bus
from navi.core.events import AgentEvent, RecallUpdate
# Module-level structlog logger used by every handler below.
log = structlog.get_logger()
# Router for this module's endpoints; every route registered on it carries the "websocket" tag.
router = APIRouter(tags=["websocket"])
# ── Helpers ───────────────────────────────────────────────────────────────────
def _event_to_dict(event) -> dict | None:
if hasattr(event, "to_wire"):
return event.to_wire()
return None
_HEARTBEAT_INTERVAL = 20.0 # seconds — keeps the browser from dropping idle connections
async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
"""
Forward queue items to the WebSocket until ("done", None).
Returns True if client stayed connected, False if it disconnected mid-stream.
On disconnect, continues draining silently so the agent task can finish cleanly.
Sends periodic heartbeat pings while waiting so the browser doesn't time out
during long-running silent operations (e.g. AIHelper LLM calls).
"""
client_alive = True
while True:
try:
kind, payload = await asyncio.wait_for(queue.get(), timeout=_HEARTBEAT_INTERVAL)
except asyncio.TimeoutError:
# No event arrived — send a heartbeat to keep the WS alive
if client_alive:
try:
await websocket.send_json({"type": "heartbeat"})
except Exception:
client_alive = False
continue
if kind == "done":
return client_alive
if not client_alive:
continue # keep draining so run_agent can broadcast without blocking
try:
if kind == "error":
await websocket.send_json({"type": "error", "message": payload})
elif kind == "stopped":
await websocket.send_json({"type": "stream_stopped"})
elif kind == "event":
msg = _event_to_dict(payload)
if msg:
await websocket.send_json(msg)
except Exception:
client_alive = False
# ── Endpoints ─────────────────────────────────────────────────────────────────
@router.post("/sessions/{session_id}/stop")
async def stop_session(
    session_id: str,
    store: Annotated[SessionStore, Depends(get_session_store)],
    user: Annotated[User | None, Depends(get_current_user)] = None,
) -> dict:
    """Ask the agent currently running for *session_id* to stop cooperatively.

    Anonymous callers may only stop sessions that have no owner.
    Authenticated callers must pass ``check_session_access``.

    Returns:
        ``{"ok": <result of orchestrator.stop>}`` when a stop was signalled,
        otherwise ``{"ok": False, "reason": "no active run"}``.

    Raises:
        HTTPException: 404 if the session does not exist; 401 if an anonymous
            caller targets a session that has an owner.
    """
    session = await store.get(session_id)
    if session is None:
        raise HTTPException(status_code=404, detail="Session not found")

    if user is not None:
        check_session_access(session, user)
    elif session.user_id is not None:
        raise HTTPException(status_code=401, detail="Authentication required")

    stopped = get_orchestrator().stop(session_id)
    if not stopped:
        return {"ok": False, "reason": "no active run"}
    return {"ok": stopped}
# Limits on the inline images carried by a single client message.
_MAX_IMAGES = 8
_MAX_IMAGE_BYTES_TOTAL = 50 * 1024 * 1024  # 50 MB, measured on the encoded strings as received

_MESSAGE_SHAPE_ERROR = "Expected {type: 'message', content: '...'}"


class _InvalidClientMessage(ValueError):
    """A client frame failed validation; ``str(exc)`` is sent back as the error message."""


def _clean_images(images: list) -> list[str]:
    """Strip any ``data:...,`` prefix from each image and enforce the size limits.

    Raises:
        _InvalidClientMessage: too many images, a non-string entry, or a
            combined size above ``_MAX_IMAGE_BYTES_TOTAL``.
    """
    if len(images) > _MAX_IMAGES:
        raise _InvalidClientMessage(
            f"Too many images ({len(images)}). Max {_MAX_IMAGES} allowed."
        )
    cleaned: list[str] = []
    total_bytes = 0
    for img in images:
        if not isinstance(img, str):
            raise _InvalidClientMessage("Each image must be a string.")
        if "," in img and img.startswith("data:"):
            img = img.split(",", 1)[1]
        total_bytes += len(img.encode("utf-8"))
        if total_bytes > _MAX_IMAGE_BYTES_TOTAL:
            raise _InvalidClientMessage("Total image payload exceeds 50 MB limit.")
        cleaned.append(img)
    return cleaned


def _parse_client_message(data: object) -> tuple[str, list[str] | None, str, list[dict]]:
    """Validate a decoded client frame and build the text handed to the agent.

    Returns:
        ``(user_content, images, original_content, uploaded_files)``:

        * *user_content* is the client's text plus notes about inline images
          and uploaded files.
        * *images* is the cleaned image list, or None when none were sent.
        * *original_content* is the text exactly as the client sent it.

    Raises:
        _InvalidClientMessage: the frame is not a ``message`` with non-empty
            string content, or its ``images`` / ``files`` fields are malformed.
    """
    if not isinstance(data, dict) or data.get("type") != "message":
        raise _InvalidClientMessage(_MESSAGE_SHAPE_ERROR)
    original_content = data.get("content")
    if not original_content or not isinstance(original_content, str):
        raise _InvalidClientMessage(_MESSAGE_SHAPE_ERROR)
    user_content = original_content

    images = data.get("images") or None
    if images is not None:
        # A bare string would otherwise be iterated character by character.
        if not isinstance(images, list):
            raise _InvalidClientMessage("'images' must be a list.")
        images = _clean_images(images)
        # Tell the model the inline images are already in its multimodal context,
        # so it doesn't hallucinate a path/URL and call image_view to "load" them.
        n = len(images)
        noun = "image" if n == 1 else "images"
        user_content += (
            f"\n\n[{n} {noun} attached inline — already in your context, no extra loading needed.]"
        )

    uploaded_files = data.get("files") or []
    if not isinstance(uploaded_files, list) or not all(
        isinstance(f, dict) and "name" in f and "path" in f for f in uploaded_files
    ):
        raise _InvalidClientMessage("'files' must be a list of {name, path} objects.")
    if uploaded_files:
        # Append uploaded file paths to user content so Navi knows about them.
        file_lines = "\n".join(f"- {f['name']} → {f['path']}" for f in uploaded_files)
        user_content += f"\n\n[Uploaded files on disk:\n{file_lines}]"

    return user_content, images, original_content, uploaded_files


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(
    session_id: str,
    websocket: WebSocket,
) -> None:
    """Stream agent runs for *session_id* over a WebSocket.

    Flow:
      1. Resolve the user and load the session (close 4004 if it is missing).
      2. Accept the socket, then enforce access (close 4003 on failure).
      3. If a run is already in progress, re-attach: replay its buffered
         events, stream the rest, then send ``session_sync``. Otherwise send
         ``session_sync`` straight away.
      4. Loop on client ``message`` frames, starting one agent run per message
         and streaming its output back. Malformed frames get an ``error``
         frame and the loop continues.
    """
    log.info("ws.handler_enter", session_id=session_id)
    # Resolve user manually — bypass FastAPI Depends() to avoid HTTP 403
    # on the upgrade request when auth resolution fails.
    try:
        user: User | None = await get_current_user_ws(websocket)
    except Exception as exc:
        log.warning("ws.resolve_user_exc", session_id=session_id, exc_type=type(exc).__name__, error=str(exc))
        user = None
    log.info("ws.user_resolved", session_id=session_id, user_id=user.id if user else None)

    session_store = get_session_store()
    orchestrator = get_orchestrator()
    session = await session_store.get(session_id)
    if session is None:
        log.warning("ws.session_not_found", session_id=session_id)
        await websocket.close(code=4004, reason="Session not found")
        return
    log.info("ws.session_found", session_id=session_id, session_user_id=session.user_id)

    # Accept the WebSocket before checking access so that auth failures can be
    # sent as WebSocket close codes rather than HTTP 403 on the upgrade request.
    await websocket.accept()
    log.info("ws.accepted", session_id=session_id)
    if user is None:
        # Anonymous users may only connect to legacy sessions (no owner).
        if session.user_id is not None:
            log.warning("ws.anonymous_denied", session_id=session_id)
            await websocket.close(code=4003, reason="Authentication required")
            return
    else:
        try:
            check_session_access(session, user)
            log.info("ws.access_granted", session_id=session_id, user_id=user.id)
        except Exception:
            log.warning("ws.access_denied", session_id=session_id, user_id=user.id)
            await websocket.close(code=4003, reason="Access denied")
            return

    # Register the socket only after access is confirmed, right before the
    # try/finally that unregisters it. The early returns above never reach
    # that cleanup, so a denied socket must never be tracked.
    orchestrator.add_websocket(session_id, websocket)

    queue: asyncio.Queue | None = None
    current_run = None
    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = orchestrator.get_run(session_id)
        if existing is not None:
            current_run = existing
            # Subscribe BEFORE noting replay_count — single-threaded async, no race:
            # any broadcast() that happens after subscribe() goes into our queue,
            # and we replay only events[0:replay_count] from the buffer.
            queue = existing.subscribe()
            replay_count = len(existing.events)
            log.info("ws.reattached", session_id=session_id, replay_count=replay_count)
            await websocket.send_json({"type": "stream_start"})
            # Replay all events emitted before we subscribed so the client can
            # reconstruct the full in-progress UI state without any gaps.
            if replay_count > 0:
                await websocket.send_json({"type": "replay_start", "count": replay_count})
                for ev_dict in existing.events[:replay_count]:
                    try:
                        await websocket.send_json(ev_dict)
                    except Exception:
                        # Client vanished mid-replay. The finally block
                        # unsubscribes `queue` exactly once; doing it here
                        # as well would unsubscribe it twice.
                        return
                await websocket.send_json({"type": "replay_end"})
            connected = await _stream_to_client(websocket, queue)
            existing.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                return  # client disconnected again — stop here
            # Stream finished — tell the client to sync session history so it sees
            # the full saved response (handles any events missed during disconnect).
            await websocket.send_json({"type": "session_sync"})
        else:
            # No active run — if this is a reconnect after the agent already finished,
            # the client needs to reload session history to see the saved response.
            await websocket.send_json({"type": "session_sync"})

        while True:
            raw = await websocket.receive_text()
            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue
            try:
                user_content, raw_images, original_content, uploaded_files = (
                    _parse_client_message(data)
                )
            except _InvalidClientMessage as exc:
                await websocket.send_json({"type": "error", "message": str(exc)})
                continue

            # Guard against concurrent runs for the same session (atomically).
            async with orchestrator.session_lock(session_id):
                if orchestrator.is_running(session_id):
                    await websocket.send_json({
                        "type": "error",
                        "message": "Agent is already running for this session.",
                    })
                    continue
                # Register run and subscribe before starting the task so we never
                # miss events even if the task is very fast.
                run = orchestrator.create_run(session_id)
                queue = run.subscribe()
                current_run = run
                # Set user context for tool sandboxing. create_task copies the
                # current contextvars, so the agent task inherits these values;
                # they are reset for this handler right after.
                from navi.tools._internal.base import (
                    current_user_id as _uid_var,
                    current_user_info as _uinfo_var,
                    current_user_role as _role_var,
                )
                if user is not None:
                    uid_token = _uid_var.set(user.id)
                    role_token = _role_var.set(user.role)
                    uinfo_token = _uinfo_var.set(user.model_dump(mode="json"))
                else:
                    uid_token = _uid_var.set(None)
                    role_token = _role_var.set("user")
                    uinfo_token = _uinfo_var.set(None)
                try:
                    run.task = asyncio.create_task(
                        orchestrator.run_agent(
                            session_id, user_content, raw_images, original_content, uploaded_files, session_store
                        )
                    )
                except Exception:
                    # If create_task fails, ensure the run is cleaned up so the session
                    # isn't stuck as "running" forever.
                    orchestrator.clear_run(session_id)
                    raise
                finally:
                    _uid_var.reset(uid_token)
                    _role_var.reset(role_token)
                    _uinfo_var.reset(uinfo_token)

            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            run.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                break  # avoid calling receive_text() on a dead socket
    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Ensure queue is removed from subscribers on any abrupt exit.
        if queue is not None and current_run is not None:
            current_run.unsubscribe(queue)
        # Remove this socket from the session tracking set.
        await orchestrator.remove_websocket(session_id, websocket)