# Newer
# Older
# navi-1 / navi / api / websocket.py
"""
WebSocket endpoint for streaming agent responses.

Protocol (client -> server):
  {"type": "message", "content": "..."}

Protocol (server -> client):
  {"type": "stream_start"}
  {"type": "thinking_delta",  "delta": "..."}
  {"type": "thinking_end"}
  {"type": "stream_delta",    "delta": "..."}
  {"type": "tool_started",    "tool": "...", "args": {...}, "is_subagent": bool}
  {"type": "tool_call",       "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
  {"type": "stream_end",      "content": "..."}
  {"type": "context_compressed"}
  {"type": "error",           "message": "..."}
"""

import asyncio
import dataclasses
import json

import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from navi.api.deps import get_session_store
from navi.core import Agent, ContextCompressed, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent
from navi.core.events import PlanningStatus, PlanReady, ProfileSwitched, StreamStopped, ToolStarted, TurnThinking
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound

# Router that the HTTP stop endpoint and the WebSocket endpoint in this module
# register themselves on.
router = APIRouter(tags=["websocket"])
# Structured logger shared by every function in this module.
log = structlog.get_logger()


# Maximum number of serialised events kept per run for replay to clients that
# re-attach mid-turn. Once the buffer exceeds this size the oldest event is
# evicted, so a late re-attach replays only the most recent events.
_MAX_REPLAY_EVENTS = 500  # cap replay buffer to avoid unbounded growth

# ── Per-session run state ──────────────────────────────────────────────────────

@dataclasses.dataclass
class _AgentRun:
    """Holds the running agent task and all active subscriber queues."""
    task: asyncio.Task | None = None
    stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event)
    subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)
    # Replay buffer: all serialised event dicts emitted so far this turn.
    # Used to reconstruct the UI for clients that reconnect mid-stream.
    events: list[dict] = dataclasses.field(default_factory=list)

    def subscribe(self) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(q)
        return q

    def unsubscribe(self, q: asyncio.Queue) -> None:
        try:
            self.subscribers.remove(q)
        except ValueError:
            pass

    async def broadcast(self, item) -> None:
        kind, payload = item
        # Serialise and buffer every agent event so reconnecting clients can replay.
        if kind == "event":
            ev_dict = _event_to_dict(payload)
            if ev_dict:
                self.events.append(ev_dict)
                # Evict oldest events to keep memory and replay cost bounded.
                if len(self.events) > _MAX_REPLAY_EVENTS:
                    self.events.pop(0)
        for q in list(self.subscribers):
            await q.put(item)


# session_id → active run (present only while agent is executing).
# websocket_session registers an entry when it starts a turn, _run_agent removes
# it in its finally block, and stop_session / re-attaching sockets look runs up here.
_runs: dict[str, _AgentRun] = {}


# ── Helpers ───────────────────────────────────────────────────────────────────

def _event_to_dict(event) -> dict | None:
    if isinstance(event, ThinkingDelta):
        return {"type": "thinking_delta", "delta": event.delta}
    if isinstance(event, ThinkingEnd):
        return {"type": "thinking_end"}
    if isinstance(event, TextDelta):
        return {"type": "stream_delta", "delta": event.delta}
    if isinstance(event, ToolStarted):
        return {
            "type": "tool_started",
            "tool": event.tool_name,
            "args": event.arguments,
            "is_subagent": event.is_subagent,
        }
    if isinstance(event, ToolEvent):
        return {
            "type": "tool_call",
            "tool": event.tool_name,
            "args": event.arguments,
            "result": event.result,
            "success": event.success,
            "is_subagent": event.is_subagent,
            "metadata": event.metadata,
        }
    if isinstance(event, StreamEnd):
        return {
            "type": "stream_end",
            "content": event.full_content,
            "context_tokens": event.context_tokens,
            "max_context_tokens": event.max_context_tokens,
            "elapsed_seconds": event.elapsed_seconds,
            "tool_call_count": event.tool_call_count,
            "token_count": event.token_count,
        }
    if isinstance(event, ContextCompressed):
        return {
            "type": "context_compressed",
            "messages_before": event.messages_before,
            "messages_after": event.messages_after,
            "summary": event.summary,
        }
    if isinstance(event, TurnThinking):
        return {"type": "turn_thinking", "thinking": event.thinking, "is_subagent": event.is_subagent}
    if isinstance(event, ProfileSwitched):
        return {"type": "profile_switched", "profile_id": event.profile_id, "profile_name": event.profile_name}
    if isinstance(event, StreamStopped):
        return {"type": "stream_stopped"}
    if isinstance(event, PlanningStatus):
        return {"type": "planning_status", "phase": event.phase, "label": event.label, "is_subagent": event.is_subagent}
    if isinstance(event, PlanReady):
        return {"type": "plan_ready", "plan": event.plan, "is_subagent": event.is_subagent}
    return None


async def _run_agent(
    run: _AgentRun,
    agent: Agent,
    session_id: str,
    user_content: str,
    raw_images: list[str] | None,
    display_content: str | None = None,
) -> None:
    """
    Drive one agent turn and publish everything it produces to *run*.

    Each streamed event is broadcast as ("event", event). Failures are
    broadcast as ("error", message) and a cancellation as ("stopped", None).
    Whatever happens, ("done", None) is broadcast last and the session's entry
    is removed from ``_runs``.

    NOTE(review): the original author states that run_stream saves the session
    before StreamEnd, even on disconnect — not verifiable from this module.
    """
    from navi.tools.base import current_stop_event

    # Publish this run's stop flag via current_stop_event
    # (presumably a context variable read by tools — confirm in navi.tools.base).
    current_stop_event.set(run.stop_event)

    async def emit(kind: str, payload=None) -> None:
        await run.broadcast((kind, payload))

    try:
        stream = agent.run_stream(
            session_id, user_content, images=raw_images, display_message=display_content
        )
        async for event in stream:
            await emit("event", event)
    except asyncio.CancelledError:
        log.info("ws.agent_stopped", session_id=session_id)
        await emit("stopped")
        raise  # propagate so the task ends up in the cancelled state
    except SessionNotFound:
        await emit("error", "Session not found")
    except (MaxIterationsReached, NaviError) as exc:
        await emit("error", str(exc))
    except Exception as exc:
        log.exception("ws.agent_error", session_id=session_id)
        await emit("error", f"Internal error: {exc}")
    finally:
        # Listeners stop forwarding when they see "done"; afterwards the
        # session is free to start a new run.
        await emit("done")
        _runs.pop(session_id, None)


_HEARTBEAT_INTERVAL = 20.0  # seconds — keeps the browser from dropping idle connections


async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
    """
    Forward queue items to the WebSocket until ("done", None).
    Returns True if client stayed connected, False if it disconnected mid-stream.
    On disconnect, continues draining silently so the agent task can finish cleanly.
    Sends periodic heartbeat pings while waiting so the browser doesn't time out
    during long-running silent operations (e.g. AIHelper LLM calls).
    """
    client_alive = True
    while True:
        try:
            kind, payload = await asyncio.wait_for(queue.get(), timeout=_HEARTBEAT_INTERVAL)
        except asyncio.TimeoutError:
            # No event arrived — send a heartbeat to keep the WS alive
            if client_alive:
                try:
                    await websocket.send_json({"type": "heartbeat"})
                except Exception:
                    client_alive = False
            continue

        if kind == "done":
            return client_alive
        if not client_alive:
            continue  # keep draining so run_agent can broadcast without blocking
        try:
            if kind == "error":
                await websocket.send_json({"type": "error", "message": payload})
            elif kind == "stopped":
                await websocket.send_json({"type": "stream_stopped"})
            elif kind == "event":
                msg = _event_to_dict(payload)
                if msg:
                    await websocket.send_json(msg)
        except Exception:
            client_alive = False


# ── Endpoints ─────────────────────────────────────────────────────────────────

@router.post("/sessions/{session_id}/stop")
async def stop_session(session_id: str) -> dict:
    """Ask the agent currently running for *session_id* to stop cooperatively.

    Only sets the run's stop flag; the task itself is not cancelled here.
    Returns {"ok": True} when a run was found, otherwise
    {"ok": False, "reason": "no active run"}.
    """
    active = _runs.get(session_id)
    if active is None:
        return {"ok": False, "reason": "no active run"}
    active.stop_event.set()
    return {"ok": True}


# Abuse guards for inline images supplied by the client.
_MAX_IMAGES = 10
# Upper bound, in bytes, on each image string as received (i.e. on the base64
# text after any data-URL prefix is stripped, not on the decoded image).
_MAX_IMAGE_BYTES = 5 * 1024 * 1024


def _clean_images(raw_images) -> tuple[list[str] | None, str | None]:
    """Validate client-supplied images and strip ``data:`` URL prefixes.

    Returns ``(images, None)`` on success, where ``images`` is ``None`` when
    nothing was attached, or ``(None, error_message)`` when the payload is
    malformed or exceeds the limits.
    """
    if not raw_images:
        return None, None
    if not isinstance(raw_images, list) or not all(isinstance(i, str) for i in raw_images):
        return None, "Invalid 'images': expected a list of base64 strings."
    if len(raw_images) > _MAX_IMAGES:
        return None, f"Too many images ({len(raw_images)}). Max {_MAX_IMAGES} allowed."

    cleaned: list[str] = []
    for img in raw_images:
        if "," in img and img.startswith("data:"):
            img = img.split(",", 1)[1]
        # surrogatepass: JSON may contain lone surrogates, which must not make
        # the size check itself raise.
        if len(img.encode("utf-8", "surrogatepass")) > _MAX_IMAGE_BYTES:
            return None, "Image exceeds 5 MB limit."
        cleaned.append(img)
    return cleaned, None


def _format_uploaded_files(uploaded_files) -> tuple[str | None, str | None]:
    """Render client-supplied upload descriptors as one ``- name → path`` line each.

    Returns ``(lines, None)`` on success, where ``lines`` is ``None`` when no
    files were supplied, or ``(None, error_message)`` for a malformed payload.
    """
    if not uploaded_files:
        return None, None
    if not isinstance(uploaded_files, list):
        return None, "Invalid 'files': expected a list of {name, path} objects."

    lines: list[str] = []
    for entry in uploaded_files:
        if not isinstance(entry, dict) or "name" not in entry or "path" not in entry:
            return None, "Invalid 'files': expected a list of {name, path} objects."
        lines.append(f"- {entry['name']} → {entry['path']}")
    return "\n".join(lines), None


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Serve one WebSocket connection for *session_id*.

    Flow:
      1. An unknown session is closed with application close code 4004.
      2. If an agent turn is already running for the session, the socket
         re-attaches to it: buffered events are replayed, live events are
         streamed until the turn ends, then "session_sync" is sent.
         Otherwise "session_sync" is sent straight away.
      3. Each valid {"type": "message", ...} frame starts an agent task and
         streams its events back. Malformed frames get an "error" reply and
         the connection stays open.

    Nothing here cancels the agent task when the socket goes away; the task
    runs to completion and a later connection can re-attach to it.
    """
    session_store = get_session_store()

    session = await session_store.get(session_id)

    # The handshake must be accepted before a close frame can carry a code:
    # closing an un-accepted socket makes the server answer HTTP 403 and the
    # 4004 code / reason never reach the client.
    await websocket.accept()
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return
    log.info("ws.connected", session_id=session_id)

    from navi.api.deps import get_memory_store, get_registries, get_workers
    tools, profiles, backends, cp_registry = get_registries()
    agent = Agent(
        session_store, profiles, tools, backends,
        workers=get_workers(), memory_store=get_memory_store(), cp_registry=cp_registry,
    )

    # Tracked so the finally block can detach us after any abrupt exit.
    queue: asyncio.Queue | None = None
    current_run: _AgentRun | None = None

    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = _runs.get(session_id)
        if existing is not None:
            current_run = existing
            # Subscribe BEFORE noting replay_count — single-threaded async, no race:
            # any broadcast() that happens after subscribe() goes into our queue,
            # and we replay only events[0:replay_count] from the buffer.
            queue = existing.subscribe()
            replay_count = len(existing.events)
            log.info("ws.reattached", session_id=session_id, replay_count=replay_count)
            await websocket.send_json({"type": "stream_start"})
            # Replay the events emitted before we subscribed so the client can
            # rebuild the in-progress UI state.
            if replay_count > 0:
                await websocket.send_json({"type": "replay_start", "count": replay_count})
                # The slice is a snapshot, so evictions during replay cannot shift it.
                for ev_dict in existing.events[:replay_count]:
                    try:
                        await websocket.send_json(ev_dict)
                    except Exception:
                        existing.unsubscribe(queue)
                        return
                await websocket.send_json({"type": "replay_end"})
            connected = await _stream_to_client(websocket, queue)
            existing.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                return  # client disconnected again — stop here
            # Stream finished — tell the client to sync session history so it sees
            # the full saved response (handles any events missed during disconnect).
            await websocket.send_json({"type": "session_sync"})
        else:
            # No active run — if this is a reconnect after the agent already finished,
            # the client needs to reload session history to see the saved response.
            await websocket.send_json({"type": "session_sync"})

        while True:
            raw = await websocket.receive_text()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue

            # json.loads can return any JSON value; only an object carrying a
            # non-empty string "content" is a valid chat message.
            content = data.get("content") if isinstance(data, dict) else None
            if (
                not isinstance(content, str)
                or not content
                or data.get("type") != "message"
            ):
                await websocket.send_json({
                    "type": "error",
                    "message": "Expected {type: 'message', content: '...'}",
                })
                continue

            original_content = content
            user_content = original_content

            raw_images, image_error = _clean_images(data.get("images"))
            if image_error is not None:
                await websocket.send_json({"type": "error", "message": image_error})
                continue

            file_lines, files_error = _format_uploaded_files(data.get("files"))
            if files_error is not None:
                await websocket.send_json({"type": "error", "message": files_error})
                continue

            # Tell the model the inline images are already in its multimodal context,
            # so it doesn't hallucinate a path/URL and call image_view to "load" them.
            if raw_images:
                n = len(raw_images)
                noun = "image" if n == 1 else "images"
                user_content = (
                    user_content
                    + f"\n\n[{n} {noun} attached inline — already in your context, no extra loading needed.]"
                )

            # Append uploaded file paths to user content so Navi knows about them
            if file_lines:
                user_content = (
                    user_content + f"\n\n[Uploaded files on disk:\n{file_lines}]"
                )

            # Guard against concurrent runs for the same session.
            if session_id in _runs:
                await websocket.send_json({
                    "type": "error",
                    "message": "Agent is already running for this session.",
                })
                continue

            # Register run and subscribe before starting the task so we never
            # miss events even if the task is very fast.
            run = _AgentRun()
            queue = run.subscribe()
            current_run = run
            _runs[session_id] = run

            run.task = asyncio.create_task(
                _run_agent(run, agent, session_id, user_content, raw_images, original_content)
            )

            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            run.unsubscribe(queue)
            queue = None
            current_run = None

            if not connected:
                break  # avoid calling receive_text() on a dead socket

    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Ensure queue is removed from subscribers on any abrupt exit.
        if queue is not None and current_run is not None:
            current_run.unsubscribe(queue)