Newer
Older
navi-1 / navi / api / websocket.py
"""
WebSocket endpoint for streaming agent responses.

Protocol (client -> server):
  {"type": "message", "content": "..."}

Protocol (server -> client):
  {"type": "stream_start"}
  {"type": "stream_delta",    "delta": "..."}       # text chunk
  {"type": "tool_call",       "tool": "...", "args": {...}, "result": "...", "success": bool}
  {"type": "stream_end",      "content": "..."}     # full assembled response
  {"type": "error",           "message": "..."}
"""

import json

import structlog
from fastapi import APIRouter, WebSocket, WebSocketDisconnect

from navi.api.deps import get_agent, get_session_store
from navi.core import Agent, InMemorySessionStore, StreamEnd, TextDelta, ToolEvent
from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound

# Structured logger shared by the handlers in this module.
log = structlog.get_logger()
# Router holding this module's WebSocket endpoints, tagged "websocket".
router: APIRouter = APIRouter(tags=["websocket"])


def _parse_client_message(raw: str) -> "tuple[str | None, str | None]":
    """Extract the user's message text from one raw client frame.

    Returns ``(content, None)`` when the frame is a JSON object of the form
    ``{"type": "message", "content": "<non-empty string>"}``. Otherwise it
    returns ``(None, error_message)``, where ``error_message`` is the text to
    send back to the client in an ``error`` event.
    """
    try:
        data = json.loads(raw)
    except json.JSONDecodeError:
        return None, "Invalid JSON"

    # Valid JSON is not necessarily an object (e.g. ``[1, 2]`` or ``"hi"``).
    # Check the type before calling ``.get`` so such frames get a protocol error
    # instead of an AttributeError that would tear down the connection. The
    # content must be a non-empty string; any other truthy value is rejected.
    if (
        not isinstance(data, dict)
        or data.get("type") != "message"
        or not isinstance(data.get("content"), str)
        or not data["content"]
    ):
        return None, "Expected {type: 'message', content: '...'}"
    return data["content"], None


def _event_to_payload(event: object) -> "dict | None":
    """Translate one agent stream event into its server -> client JSON message.

    Returns ``None`` for event types this protocol does not forward.
    """
    if isinstance(event, TextDelta):
        return {"type": "stream_delta", "delta": event.delta}
    if isinstance(event, ToolEvent):
        return {
            "type": "tool_call",
            "tool": event.tool_name,
            "args": event.arguments,
            "result": event.result,
            "success": event.success,
        }
    if isinstance(event, StreamEnd):
        return {"type": "stream_end", "content": event.full_content}
    return None


@router.websocket("/ws/sessions/{session_id}")
async def websocket_session(session_id: str, websocket: WebSocket) -> None:
    """Stream agent responses for ``session_id`` over a WebSocket.

    If the session does not exist, the socket is closed with code 4004.
    Otherwise the handler reads client messages until the client disconnects.
    For each valid message it sends ``stream_start``, then one event per
    agent stream event (see the module docstring). Failures for a single
    message are reported as ``error`` events, and the connection stays open.
    """
    session_store = get_session_store()
    session = await session_store.get(session_id)

    # Accept before closing. Per the ASGI spec, closing a socket that was never
    # accepted rejects the handshake with HTTP 403. The custom 4004 code and its
    # reason would then never reach the client.
    await websocket.accept()
    if session is None:
        await websocket.close(code=4004, reason="Session not found")
        return

    log.info("ws.connected", session_id=session_id)

    # Build the agent by hand: FastAPI Depends is not used inside this handler.
    from navi.api.deps import _registries
    tools, profiles, backends = _registries()
    agent = Agent(session_store, profiles, tools, backends)

    try:
        while True:
            raw = await websocket.receive_text()
            user_content, error = _parse_client_message(raw)
            if error is not None:
                await websocket.send_json({"type": "error", "message": error})
                continue

            await websocket.send_json({"type": "stream_start"})

            try:
                async for event in agent.run_stream(session_id, user_content):
                    payload = _event_to_payload(event)
                    if payload is not None:
                        await websocket.send_json(payload)
            except WebSocketDisconnect:
                # The client left mid-stream; this is not an agent failure.
                # Re-raise so the outer handler records the disconnect instead
                # of logging an internal error and writing to a closed socket.
                raise
            except SessionNotFound:
                await websocket.send_json({"type": "error", "message": "Session not found"})
            except (MaxIterationsReached, NaviError) as e:
                await websocket.send_json({"type": "error", "message": str(e)})
            except Exception:
                log.exception("ws.agent_error", session_id=session_id)
                # Do not echo the exception text. It can expose internals such
                # as paths or backend responses to the client; details are in
                # the server log above.
                await websocket.send_json({"type": "error", "message": "Internal error"})

    except WebSocketDisconnect:
        log.info("ws.disconnected", session_id=session_id)