Newer
Older
navi-1 / clients / terminal / ws_client.py
"""WebSocket client for streaming Navi responses."""

from __future__ import annotations

import asyncio
import json

import websockets

from clients.terminal.config import settings
from clients.terminal.render import Renderer


class NaviWebSocketClient:
    """Connect to /ws/sessions/<id> and render events.

    Two modes are provided:

    * ``run_one_shot`` sends a single message and returns once the server
      reports ``stream_end``/``error``, the connection ends, or the timeout
      expires.
    * ``run_interactive`` runs a receive loop alongside an input loop fed by
      ``enqueue``; it returns as soon as either loop finishes.
    """

    def __init__(
        self,
        session_id: str,
        renderer: Renderer | None = None,
    ) -> None:
        """Create an unconnected client for ``session_id``.

        Args:
            session_id: Session identifier used to build the WebSocket URL
                via ``settings.websocket_url``.
            renderer: Renderer for incoming events. When omitted (or falsy),
                one is built from ``settings.show_thinking`` and
                ``settings.show_events``.
        """
        self.session_id = session_id
        self.renderer = renderer or Renderer(
            show_thinking=settings.show_thinking,
            show_events=settings.show_events,
        )
        self.url = settings.websocket_url(session_id)
        self._ws: websockets.ClientConnection | None = None
        # Set on a "stream_end"/"error" message and whenever receive_loop exits.
        self._stop_event = asyncio.Event()
        # Outgoing messages consumed by input_loop; None is the stop sentinel.
        self._input_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def connect(self) -> None:
        """Open the WebSocket connection to ``self.url``.

        A connection this client already holds is closed first so that it is
        not leaked by being overwritten.
        """
        if self._ws is not None:
            await self.close()
        self._ws = await websockets.connect(self.url)

    async def close(self) -> None:
        """Close the current connection, if any. Safe to call repeatedly."""
        # Detach before awaiting close() so a failing close() cannot leave the
        # client holding a dead connection, and a second caller sees None.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def send(self, content: str) -> None:
        """Send ``content`` as a ``{"type": "message", ...}`` JSON frame.

        Raises:
            RuntimeError: If the client is not connected.
        """
        if not self._ws:
            raise RuntimeError("WebSocket is not connected")
        await self._ws.send(json.dumps({"type": "message", "content": content}))

    async def receive_loop(self) -> None:
        """Render every JSON object received until the connection ends.

        Frames that cannot be decoded as JSON (including undecodable bytes)
        or that decode to something other than a JSON object are skipped.
        The stop event is set on a ``stream_end``/``error`` message and again
        when the loop exits for any reason, so ``wait_for_stream_end`` does
        not sit out its full timeout once the connection is gone.

        Raises:
            RuntimeError: If the client is not connected.
        """
        if not self._ws:
            raise RuntimeError("WebSocket is not connected")
        try:
            async for raw in self._ws:
                try:
                    msg = json.loads(raw)
                except ValueError:
                    # JSONDecodeError and UnicodeDecodeError (bad bytes frame)
                    # are both ValueError subclasses.
                    continue
                if not isinstance(msg, dict):
                    # Only JSON objects carry a "type"; anything else would
                    # crash on .get() below.
                    continue
                self.renderer.render(msg)
                if msg.get("type") in ("stream_end", "error"):
                    self._stop_event.set()
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            # Wake any waiter: no further messages can arrive.
            self._stop_event.set()

    async def input_loop(self) -> None:
        """Send queued messages until the ``None`` sentinel is dequeued."""
        while True:
            content = await self._input_queue.get()
            if content is None:
                break
            await self.send(content)

    def enqueue(self, content: str) -> None:
        """Queue ``content`` to be sent by ``input_loop``."""
        self._input_queue.put_nowait(content)

    def stop_input(self) -> None:
        """Ask ``input_loop`` to stop after the messages queued so far."""
        self._input_queue.put_nowait(None)

    async def wait_for_stream_end(self, timeout: float = 600.0) -> bool:
        """Wait until the stop event is set or ``timeout`` seconds pass.

        Returns:
            True if the stop event was set (``stream_end``/``error`` received
            or the receive loop exited), False if the wait timed out.
        """
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        return True

    def reset_stop(self) -> None:
        """Clear the stop event before waiting for the next response."""
        self._stop_event.clear()

    async def run_one_shot(self, content: str, timeout: float = 600.0) -> None:
        """Connect, send ``content``, render the response, then disconnect.

        Args:
            content: Message to send.
            timeout: Maximum seconds to wait for the response to finish.
        """
        await self.connect()
        # Drop any stop signal left over from an earlier run on this client.
        self.reset_stop()
        receive_task = asyncio.create_task(self.receive_loop())
        try:
            await self.send(content)
            await self.wait_for_stream_end(timeout=timeout)
        finally:
            # Always stop the receiver, even if send() failed, so the task is
            # not left pending once the connection is closed.
            receive_task.cancel()
            try:
                await receive_task
            except asyncio.CancelledError:
                pass
            await self.close()

    async def run_interactive(self) -> None:
        """Connect and run the receive and input loops until either finishes.

        The session ends when the input loop stops (``stop_input`` was
        called) or the receive loop stops (the connection ended). The other
        loop is then cancelled and the connection closed, so neither side
        can keep the session alive on its own. ``ConnectionClosed`` from
        either loop is treated as a normal end; other exceptions propagate.
        """
        await self.connect()
        receive_task = asyncio.create_task(self.receive_loop())
        input_task = asyncio.create_task(self.input_loop())
        try:
            done, _pending = await asyncio.wait(
                {receive_task, input_task},
                return_when=asyncio.FIRST_COMPLETED,
            )
            for task in done:
                try:
                    task.result()
                except websockets.exceptions.ConnectionClosed:
                    pass
        finally:
            for task in (receive_task, input_task):
                task.cancel()
            # Let the cancellations land; return_exceptions keeps a failure in
            # one task from skipping close() below.
            await asyncio.gather(receive_task, input_task, return_exceptions=True)
            await self.close()