"""Unit tests for WebSocket handler internals and reconnect logic."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import WebSocketDisconnect
from navi.api import websocket as ws_mod
from navi.core.orchestrator import AgentSessionOrchestrator, SessionRun
@pytest.fixture(autouse=True)
def _clear_state(monkeypatch):
"""Clear global state before every WS test."""
yield
@pytest.fixture
def mock_websocket():
    """Provide an ``AsyncMock`` WebSocket whose handshake/IO methods are awaitable.

    ``accept``, ``close`` and ``send_json`` are each given their own explicit
    ``AsyncMock`` so tests can inspect their calls.
    """
    websocket = AsyncMock()
    for method_name in ("accept", "close", "send_json"):
        setattr(websocket, method_name, AsyncMock())
    return websocket
@pytest.fixture
def mock_session():
    """Return a ``MagicMock`` session record owned by ``test-user-id``."""
    return MagicMock(user_id="test-user-id")
@pytest.fixture
def mock_user():
    """Return a ``MagicMock`` admin user whose id matches ``mock_session``'s owner."""
    return MagicMock(id="test-user-id", role="admin")
@pytest.fixture
def fake_orchestrator():
    """Build an ``AgentSessionOrchestrator`` over a mostly-empty mock container.

    Every registry/store/manager attribute the fixture knows about is set to
    ``None``, and ``workers`` is an empty list, so no real backend is wired in.
    """
    stub_container = MagicMock()
    for attr_name in (
        "profile_registry",
        "tool_registry",
        "backend_registry",
        "cp_registry",
        "memory_store",
        "mcp_manager",
    ):
        setattr(stub_container, attr_name, None)
    stub_container.workers = []
    return AgentSessionOrchestrator(stub_container)
# ── SessionRun buffer tests ─────────────────────────────────────────────────
@pytest.mark.anyio
async def test_event_buffer_appended_and_replayed():
    """``SessionRun.broadcast`` buffers wire-serialised events with FIFO eviction.

    Each ``("event", obj)`` tuple passed to ``broadcast`` should appear in
    ``run.events`` as ``obj.to_wire()``. Once more than ``cap`` events have
    been broadcast, the oldest entries are dropped, so only the newest ``cap``
    remain, still in arrival order. Replay itself is covered by the reconnect
    tests; this test only covers buffering.
    """
    # NOTE(review): 500 mirrors the buffer cap inside SessionRun; if that cap
    # is ever exposed as a constant, read it from there instead.
    cap = 500
    overflow = 5

    class FakeEvent:
        """Minimal event exposing the ``to_wire`` hook that broadcast serialises."""

        def __init__(self, idx: int) -> None:
            self.idx = idx

        def to_wire(self) -> dict:
            return {"type": "stream_delta", "delta": str(self.idx)}

    # Below the cap: every event is kept, in order.
    run = SessionRun()
    for i in range(3):
        await run.broadcast(("event", FakeEvent(i)))
    assert list(run.events) == [
        {"type": "stream_delta", "delta": str(i)} for i in range(3)
    ]

    # Past the cap: the first `overflow` events are evicted and the remaining
    # `cap` events are kept in order. Checking the whole buffer, not just its
    # endpoints, catches reordering or gaps in the middle.
    run = SessionRun()
    for i in range(cap + overflow):
        await run.broadcast(("event", FakeEvent(i)))
    assert len(run.events) == cap
    assert [event["delta"] for event in run.events] == [
        str(i) for i in range(overflow, cap + overflow)
    ]
    assert run.events[0] == {"type": "stream_delta", "delta": str(overflow)}
    assert run.events[-1] == {"type": "stream_delta", "delta": str(cap + overflow - 1)}
# ── Reconnect / replay tests ─────────────────────────────────────────────────
@pytest.mark.anyio
async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch):
    """Attaching to a live run replays its buffer between replay markers.

    Expected outbound frame order: ``stream_start``, ``replay_start`` (with a
    ``count`` of buffered events), each buffered event in order,
    ``replay_end``, then ``session_sync``.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))

    session_store = MagicMock()
    session_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: session_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock()
    container.profile_registry = None
    container.tool_registry = None
    container.backend_registry = None
    container.cp_registry = None
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orch = container.orchestrator
    live_run = orch.create_run("s1")
    # Pre-seed the replay buffer as if two events had already been broadcast.
    live_run.events = [
        {"type": "stream_delta", "delta": "hello"},
        {"type": "thinking_delta", "delta": "hmm"},
    ]

    # The client disconnects on its first read, so only server-initiated frames are sent.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())
    await ws_mod.websocket_session("s1", mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    frame_types = [frame["type"] for frame in sent_frames]
    expected_order = [
        "stream_start",
        "replay_start",
        "stream_delta",
        "thinking_delta",
        "replay_end",
        "session_sync",
    ]
    assert frame_types == expected_order
    assert sent_frames[1]["count"] == 2

    orch._sessions.pop("s1", None)
@pytest.mark.anyio
async def test_session_sync_after_reconnect_when_done(mock_websocket, mock_session, mock_user, monkeypatch):
    """With no active run, reconnecting sends exactly one ``session_sync`` and no replay."""
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))

    session_store = MagicMock()
    session_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: session_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock()
    container.profile_registry = None
    container.tool_registry = None
    container.backend_registry = None
    container.cp_registry = None
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # The client disconnects immediately; no run was ever created for "s1".
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())
    await ws_mod.websocket_session("s1", mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert len(sent_frames) == 1
    (only_frame,) = sent_frames
    assert only_frame["type"] == "session_sync"
@pytest.mark.anyio
async def test_session_sync_after_recall_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """A busy session with no streaming run only gets ``session_sync`` on reconnect.

    The session is marked busy, as a headless recall run would leave it, but no
    run is created. The client must see neither a replay nor an error frame.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))

    session_store = MagicMock()
    session_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: session_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock()
    container.profile_registry = None
    container.tool_registry = None
    container.backend_registry = None
    container.cp_registry = None
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orch = container.orchestrator
    orch.mark_busy("s1")

    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())
    await ws_mod.websocket_session("s1", mock_websocket)

    frame_types = [call.args[0]["type"] for call in mock_websocket.send_json.call_args_list]
    assert frame_types == ["session_sync"]

    await orch.clear_busy("s1")
# ── Concurrent run guard ─────────────────────────────────────────────────────
@pytest.mark.anyio
async def test_concurrent_run_guard_rejects_second_message(mock_websocket, mock_session, mock_user, monkeypatch):
    """A second ``message`` frame during an active run yields one error frame.

    The scripted client sends two ``message`` frames and then disconnects.
    ``run_agent`` is replaced by a coroutine that sleeps for an hour, so the
    first run should still be active when the second frame arrives. Exactly
    one ``error`` frame mentioning "already running" is expected.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))
    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)
    orchestrator = fake_container.orchestrator

    # Replace run_agent with a coroutine that sleeps for an hour, so the
    # first run stays registered while the second message is handled.
    async def fake_run_agent(*a, **kw):
        await asyncio.sleep(3600)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    # Scripted client: two message frames, then a disconnect on every later read.
    scripted_frames = iter(
        [
            json.dumps({"type": "message", "content": "first"}),
            json.dumps({"type": "message", "content": "second"}),
        ]
    )

    async def fake_receive_text():
        frame = next(scripted_frames, None)
        if frame is None:
            raise WebSocketDisconnect()
        return frame

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)
        calls = [c.args[0] for c in mock_websocket.send_json.call_args_list]
        error_calls = [c for c in calls if c.get("type") == "error"]
        assert len(error_calls) == 1
        assert "already running" in error_calls[0]["message"]
    finally:
        # Cancel the hour-long background run even if an assertion above
        # failed; otherwise the pending task outlives the test.
        state = orchestrator._sessions.get("s1")
        if state and state.run and state.run.task:
            state.run.task.cancel()
            try:
                await state.run.task
            except asyncio.CancelledError:
                pass
        orchestrator._sessions.pop("s1", None)