# Newer
# Older
# navi-1 / navi / api / websocket.py
"""
WebSocket endpoint for streaming agent responses.

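Protocol (client -> server):
  {"type": "message", "content": "...", "images": ["..."], "files": [{"name": "...", "path": "..."}]}
  ("images" and "files" are optional; a leading "data:...," prefix on an image is stripped.)

Protocol (server -> client):
  {"type": "stream_start"}
  {"type": "thinking_delta",  "delta": "..."}
  {"type": "thinking_end"}
  {"type": "turn_thinking",   "thinking": "...", "is_subagent": bool}
  {"type": "stream_delta",    "delta": "..."}
  {"type": "tool_started",    "tool": "...", "args": {...}, "is_subagent": bool}
  {"type": "tool_call",       "tool": "...", "args": {...}, "result": "...", "success": bool, "is_subagent": bool}
  {"type": "stream_end",      "content": "...", "context_tokens": ..., "max_context_tokens": ...}
  {"type": "context_compressed", "messages_before": ..., "messages_after": ...}
  {"type": "error",           "message": "..."}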
"""

import asyncio
import dataclasses
import json

import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from navi.api.deps import get_session_store
from navi.core import Agent, ContextCompressed, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent
from navi.core.events import ToolStarted, TurnThinking
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound

# Router that the WebSocket endpoint in this module registers itself on.
router = APIRouter(tags=["websocket"])
# Module-wide structlog logger; used here for connection lifecycle and agent-error events.
log = structlog.get_logger()


# ── Per-session run state ──────────────────────────────────────────────────────

@dataclasses.dataclass
class _AgentRun:
    """Holds the running agent task and all active subscriber queues."""
    task: asyncio.Task | None = None
    subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)

    def subscribe(self) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(q)
        return q

    def unsubscribe(self, q: asyncio.Queue) -> None:
        try:
            self.subscribers.remove(q)
        except ValueError:
            pass

    async def broadcast(self, item) -> None:
        for q in list(self.subscribers):
            await q.put(item)


# session_id → active run (present only while agent is executing).
# websocket_session registers an entry just before it creates the agent task,
# and _run_agent removes it in its `finally`. A newly connected socket looks its
# session up here to re-attach to a run that is still streaming.
_runs: dict[str, _AgentRun] = {}


# ── Helpers ───────────────────────────────────────────────────────────────────

def _event_to_dict(event) -> dict | None:
    if isinstance(event, ThinkingDelta):
        return {"type": "thinking_delta", "delta": event.delta}
    if isinstance(event, ThinkingEnd):
        return {"type": "thinking_end"}
    if isinstance(event, TextDelta):
        return {"type": "stream_delta", "delta": event.delta}
    if isinstance(event, ToolStarted):
        return {
            "type": "tool_started",
            "tool": event.tool_name,
            "args": event.arguments,
            "is_subagent": event.is_subagent,
        }
    if isinstance(event, ToolEvent):
        return {
            "type": "tool_call",
            "tool": event.tool_name,
            "args": event.arguments,
            "result": event.result,
            "success": event.success,
            "is_subagent": event.is_subagent,
        }
    if isinstance(event, StreamEnd):
        return {
            "type": "stream_end",
            "content": event.full_content,
            "context_tokens": event.context_tokens,
            "max_context_tokens": event.max_context_tokens,
        }
    if isinstance(event, ContextCompressed):
        return {
            "type": "context_compressed",
            "messages_before": event.messages_before,
            "messages_after": event.messages_after,
        }
    if isinstance(event, TurnThinking):
        return {"type": "turn_thinking", "thinking": event.thinking, "is_subagent": event.is_subagent}
    return None


async def _run_agent(
    run: _AgentRun,
    agent: Agent,
    session_id: str,
    user_content: str,
    raw_images: list[str] | None,
) -> None:
    """
    Execute the agent to completion, broadcasting events to all subscribers.

    Each subscriber queue receives ("event", <event>) for every streamed event,
    ("error", <message>) if the run fails, and always a final ("done", None).

    Per the original note, run_stream is expected to save the session before it
    yields StreamEnd, so the result survives a client disconnect; that guarantee
    lives in Agent.run_stream and is not enforced by this function.

    Args:
        run: Run state whose subscribers receive the broadcast items.
        agent: Agent whose run_stream() produces the events.
        session_id: Session to run against; also the key of this run in _runs.
        user_content: Text of the user turn.
        raw_images: Optional images forwarded to run_stream as ``images``.
    """
    try:
        async for event in agent.run_stream(session_id, user_content, images=raw_images):
            await run.broadcast(("event", event))
    except SessionNotFound:
        await run.broadcast(("error", "Session not found"))
    except (MaxIterationsReached, NaviError) as e:
        # Domain errors carry a message that is forwarded to the client as-is.
        await run.broadcast(("error", str(e)))
    except Exception as e:
        log.exception("ws.agent_error", session_id=session_id)
        # NOTE(review): the raw exception text is sent to the client here.
        await run.broadcast(("error", f"Internal error: {e}"))
    finally:
        # Deregister only our own entry: another connection may have registered
        # a newer run for this session in the meantime, and removing that one
        # would make the newer run impossible to re-attach to.
        if _runs.get(session_id) is run:
            del _runs[session_id]
        # Always release the listeners, whatever happened above.
        await run.broadcast(("done", None))


async def _stream_to_client(websocket: WebSocket, queue: asyncio.Queue) -> bool:
    """
    Forward queue items to the WebSocket until ("done", None) arrives.

    Items are (kind, payload) tuples: "event" payloads are converted with
    _event_to_dict (events it maps to None are skipped), "error" payloads are
    sent as {"type": "error", "message": payload}, and "done" ends the loop.

    Once a send fails the client is considered gone: remaining items are still
    consumed, but nothing more is sent, so the loop always ends on "done".

    Returns:
        True if every send succeeded, False if the client was lost mid-stream.
    """
    client_alive = True
    while True:
        kind, payload = await queue.get()
        if kind == "done":
            return client_alive
        if not client_alive:
            continue  # keep consuming until "done" without touching the socket
        try:
            if kind == "error":
                await websocket.send_json({"type": "error", "message": payload})
            elif kind == "event":
                msg = _event_to_dict(payload)
                if msg:
                    await websocket.send_json(msg)
        except (WebSocketDisconnect, RuntimeError):
            # The same pair the endpoint treats as "client went away".
            client_alive = False
        except Exception:
            # Still best-effort (the run must not be disturbed), but an
            # unexpected failure such as a payload that cannot be serialised
            # should leave a trace instead of looking like a plain disconnect.
            log.warning("ws.send_failed", kind=kind, exc_info=True)
            client_alive = False


# ── Endpoint ──────────────────────────────────────────────────────────────────

def _parse_client_message(data) -> tuple[str, list[str] | None]:
    """
    Validate one decoded client frame and build the agent inputs from it.

    Args:
        data: Result of json.loads() on the received text; may be any JSON value.

    Returns:
        (user_content, raw_images). user_content is the message text with a
        listing of uploaded files appended when "files" is present. raw_images
        is None when no images were sent, otherwise the images with any leading
        "data:...," prefix removed.

    Raises:
        ValueError: If the frame is malformed; the exception text is meant to be
            sent back to the client as the error message.
    """
    # Anything that is not a JSON object cannot carry "type"/"content" at all.
    if not isinstance(data, dict):
        raise ValueError("Expected {type: 'message', content: '...'}")

    content = data.get("content")
    if data.get("type") != "message" or not isinstance(content, str) or not content:
        raise ValueError("Expected {type: 'message', content: '...'}")

    images = data.get("images") or None
    if images is not None:
        if not isinstance(images, list) or not all(isinstance(img, str) for img in images):
            raise ValueError("Expected 'images' to be a list of strings")
        # Keep only the part after the first comma of a data: URL.
        images = [
            img.split(",", 1)[1] if img.startswith("data:") and "," in img else img
            for img in images
        ]

    files = data.get("files") or []
    if files:
        well_formed = isinstance(files, list) and all(
            isinstance(f, dict) and "name" in f and "path" in f for f in files
        )
        if not well_formed:
            raise ValueError("Expected 'files' to be a list of {name, path} objects")
        # Append uploaded file paths to user content so Navi knows about them
        file_lines = "\n".join(f"- {f['name']} → {f['path']}" for f in files)
        content = content + f"\n\n[Uploaded files on disk:\n{file_lines}]"

    return content, images


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """
    Serve one client connection for a session, relaying agent output per message.

    Flow:
      1. Look the session up; if it does not exist, close with code 4004.
      2. Accept the socket and build an Agent from the shared registries.
      3. If a run for this session is already registered, subscribe to it and
         stream the remainder (items broadcast before subscribing are not
         replayed).
      4. Loop: read a text frame, validate it, start the agent as a background
         task and stream its events until the run reports "done".

    Malformed frames are answered with an {"type": "error"} message and the
    connection stays open. The agent task is not cancelled when the client
    disconnects; only this connection's queue is detached.
    """
    session_store = get_session_store()

    session = await session_store.get(session_id)
    if session is None:
        # NOTE(review): this closes before accept(); an ASGI server may answer
        # that by rejecting the handshake, in which case the 4004 code never
        # reaches the client — confirm this is the intended behaviour.
        await websocket.close(code=4004, reason="Session not found")
        return

    await websocket.accept()
    log.info("ws.connected", session_id=session_id)

    # Kept as a function-level import, as in the original (presumably to avoid
    # an import cycle with navi.api.deps — TODO confirm).
    from navi.api.deps import get_memory_store, get_registries, get_workers
    tools, profiles, backends = get_registries()
    agent = Agent(
        session_store, profiles, tools, backends,
        workers=get_workers(), memory_store=get_memory_store(),
    )

    # Tracked outside the try so the finally block can detach the queue on any
    # abrupt exit; both are reset to None after a normal unsubscribe.
    queue: asyncio.Queue | None = None
    current_run: _AgentRun | None = None

    try:
        # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream).
        existing = _runs.get(session_id)
        if existing is not None:
            current_run = existing
            queue = existing.subscribe()
            log.info("ws.reattached", session_id=session_id)
            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            existing.unsubscribe(queue)
            queue = None
            current_run = None
            if not connected:
                return  # client disconnected again — stop here

        while True:
            raw = await websocket.receive_text()

            try:
                data = json.loads(raw)
            except json.JSONDecodeError:
                await websocket.send_json({"type": "error", "message": "Invalid JSON"})
                continue

            try:
                user_content, raw_images = _parse_client_message(data)
            except ValueError as exc:
                # Reject the frame but keep the connection usable.
                await websocket.send_json({"type": "error", "message": str(exc)})
                continue

            # Register run and subscribe before starting the task so we never
            # miss events even if the task is very fast.
            # NOTE(review): if another connection already has a run registered
            # for this session, this replaces its _runs entry and starts a second
            # agent on the same session — confirm whether it should be rejected.
            run = _AgentRun()
            queue = run.subscribe()
            current_run = run
            _runs[session_id] = run

            run.task = asyncio.create_task(
                _run_agent(run, agent, session_id, user_content, raw_images)
            )

            await websocket.send_json({"type": "stream_start"})
            connected = await _stream_to_client(websocket, queue)
            run.unsubscribe(queue)
            queue = None
            current_run = None

            if not connected:
                break  # avoid calling receive_text() on a dead socket

    except (WebSocketDisconnect, RuntimeError):
        log.info("ws.disconnected", session_id=session_id)
    finally:
        # Ensure queue is removed from subscribers on any abrupt exit.
        if queue is not None and current_run is not None:
            current_run.unsubscribe(queue)