"""WebSocket client for streaming Navi responses."""
from __future__ import annotations
import asyncio
import json
import websockets
from clients.terminal.config import settings
from clients.terminal.render import Renderer
class NaviWebSocketClient:
    """Stream Navi responses over a session WebSocket and render each event.

    The endpoint URL is built by ``settings.websocket_url(session_id)``
    (presumably ``/ws/sessions/<id>``). Incoming frames are decoded as JSON
    objects and passed to ``self.renderer.render``; a ``stream_end`` or
    ``error`` event marks the current stream as finished.
    """

    def __init__(
        self,
        session_id: str,
        renderer: Renderer | None = None,
    ) -> None:
        """Create an unconnected client for ``session_id``.

        Args:
            session_id: Session identifier used to build the WebSocket URL.
            renderer: Event renderer; when omitted, one is built from the
                ``show_thinking`` / ``show_events`` settings.
        """
        self.session_id = session_id
        self.renderer = renderer or Renderer(
            show_thinking=settings.show_thinking,
            show_events=settings.show_events,
        )
        self.url = settings.websocket_url(session_id)
        self._ws: websockets.ClientConnection | None = None
        # Set when a stream finishes (stream_end/error event) or when the
        # receive loop exits, so waiters never outlive the connection.
        self._stop_event = asyncio.Event()
        # Outgoing messages consumed by input_loop; ``None`` is the shutdown sentinel.
        self._input_queue: asyncio.Queue[str | None] = asyncio.Queue()

    async def connect(self) -> None:
        """Open the WebSocket connection, closing any previous one first."""
        if self._ws is not None:
            # Reconnecting must not leak the old socket.
            await self.close()
        self._ws = await websockets.connect(self.url)

    async def close(self) -> None:
        """Close the connection if one is open; safe to call repeatedly."""
        # Detach before closing so the client reads as disconnected even if
        # close() itself raises.
        ws, self._ws = self._ws, None
        if ws is not None:
            await ws.close()

    async def send(self, content: str) -> None:
        """Send ``content`` as a ``{"type": "message"}`` JSON frame.

        Raises:
            RuntimeError: If the client is not connected.
        """
        if self._ws is None:
            raise RuntimeError("WebSocket is not connected")
        await self._ws.send(json.dumps({"type": "message", "content": content}))

    async def receive_loop(self) -> None:
        """Render every incoming event until the connection goes away.

        Frames that are not valid JSON objects are skipped. A ``stream_end``
        or ``error`` event sets the stop event; leaving the loop for any
        reason (close, crash, cancellation) sets it too, because no further
        events can arrive and waiters would otherwise block until timeout.

        Raises:
            RuntimeError: If the client is not connected.
        """
        if self._ws is None:
            raise RuntimeError("WebSocket is not connected")
        try:
            async for raw in self._ws:
                try:
                    msg = json.loads(raw)
                except ValueError:
                    # JSONDecodeError, or UnicodeDecodeError for bad bytes frames.
                    continue
                if not isinstance(msg, dict):
                    # Only JSON objects are events; anything else cannot be dispatched.
                    continue
                self.renderer.render(msg)
                if msg.get("type") in ("stream_end", "error"):
                    self._stop_event.set()
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            self._stop_event.set()

    async def input_loop(self) -> None:
        """Send queued messages until the ``None`` sentinel is dequeued."""
        while True:
            content = await self._input_queue.get()
            if content is None:
                break
            await self.send(content)

    def enqueue(self, content: str) -> None:
        """Queue ``content`` for delivery by :meth:`input_loop`."""
        self._input_queue.put_nowait(content)

    def stop_input(self) -> None:
        """Ask :meth:`input_loop` to exit once earlier messages are sent."""
        self._input_queue.put_nowait(None)

    async def wait_for_stream_end(self, timeout: float = 600.0) -> bool:
        """Wait until the current stream finishes or the receive loop exits.

        Args:
            timeout: Maximum number of seconds to wait.

        Returns:
            True if the stop event was set, False if ``timeout`` elapsed first.
        """
        try:
            await asyncio.wait_for(self._stop_event.wait(), timeout=timeout)
        except asyncio.TimeoutError:
            return False
        return True

    def reset_stop(self) -> None:
        """Clear the stop event before waiting on a new stream."""
        self._stop_event.clear()

    async def run_one_shot(self, content: str) -> None:
        """Connect, send one message, render its response stream, disconnect."""
        await self.connect()
        # A stop event left over from an earlier stream would end the wait at once.
        self.reset_stop()
        receive_task = asyncio.create_task(self.receive_loop())
        try:
            await self.send(content)
            await self.wait_for_stream_end()
            receive_task.cancel()
            try:
                # Re-raises any error the receive loop hit before cancellation.
                await receive_task
            except asyncio.CancelledError:
                pass
        finally:
            # No-op when already finished; stops the loop on error paths.
            receive_task.cancel()
            await self.close()

    async def run_interactive(self) -> None:
        """Connect and run the receive and input loops until the session ends.

        The session ends when the receive loop finishes (the connection is
        gone, so queued input is abandoned) or, after :meth:`stop_input`, once
        the receive loop also finishes. ``ConnectionClosed`` is treated as a
        normal end; other errors propagate. The connection is always closed.
        """
        await self.connect()
        receive_task = asyncio.create_task(self.receive_loop())
        input_task = asyncio.create_task(self.input_loop())
        try:
            done, _pending = await asyncio.wait(
                {receive_task, input_task}, return_when=asyncio.FIRST_COMPLETED
            )
            if receive_task in done and not input_task.done():
                # Nothing more can be delivered; without this, input_loop
                # would block on queue.get() forever.
                input_task.cancel()
                await receive_task
            else:
                await asyncio.gather(receive_task, input_task)
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            for task in (receive_task, input_task):
                task.cancel()  # no-op for tasks that already finished
            await self.close()