diff --git a/tests/unit/api/test_websocket.py b/tests/unit/api/test_websocket.py new file mode 100644 index 0000000..1c0764f --- /dev/null +++ b/tests/unit/api/test_websocket.py @@ -0,0 +1,232 @@ +"""Unit tests for WebSocket handler internals and reconnect logic.""" + +import asyncio +import json +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from fastapi import WebSocketDisconnect + +from navi.api import websocket as ws_mod + + +@pytest.fixture(autouse=True) +def _clear_state(monkeypatch): + """Clear global state before every WS test.""" + ws_mod._runs.clear() + ws_mod._busy_sessions.clear() + ws_mod._session_sockets.clear() + yield + + +@pytest.fixture +def mock_websocket(): + ws = AsyncMock() + ws.accept = AsyncMock() + ws.close = AsyncMock() + ws.send_json = AsyncMock() + return ws + + +@pytest.fixture +def mock_session(): + session = MagicMock() + session.user_id = "test-user-id" + return session + + +@pytest.fixture +def mock_user(): + user = MagicMock() + user.id = "test-user-id" + user.role = "admin" + return user + + +# ── _AgentRun buffer tests ─────────────────────────────────────────────────── + +@pytest.mark.anyio +async def test_event_buffer_appended_and_replayed(): + """Broadcast stores serialised events; oldest evicted when cap exceeded.""" + run = ws_mod._AgentRun() + + class FakeEvent: + def __init__(self, idx: int) -> None: + self.idx = idx + + def to_wire(self) -> dict: + return {"type": "stream_delta", "delta": str(self.idx)} + + for i in range(3): + await run.broadcast(("event", FakeEvent(i))) + + assert len(run.events) == 3 + assert run.events[0] == {"type": "stream_delta", "delta": "0"} + assert run.events[2] == {"type": "stream_delta", "delta": "2"} + + # Fill buffer past limit + run.events.clear() + for i in range(ws_mod._MAX_REPLAY_EVENTS + 5): + await run.broadcast(("event", FakeEvent(i))) + + assert len(run.events) == ws_mod._MAX_REPLAY_EVENTS + assert run.events[0] == {"type": "stream_delta", "delta": "5"} + assert run.events[-1] == {"type": "stream_delta", "delta": str(ws_mod._MAX_REPLAY_EVENTS + 4)} + + +# ── Reconnect / replay tests ─────────────────────────────────────────────────── + +@pytest.mark.anyio +async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch): + """Re-attach to active run yields replay_start, buffered events, replay_end, then session_sync.""" + monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user)) + mock_store = MagicMock() + mock_store.get = AsyncMock(return_value=mock_session) + monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) + monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) + monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) + # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() + fake_container.profile_registry = None + fake_container.tool_registry = None + fake_container.backend_registry = None + fake_container.cp_registry = None + monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) + + run = ws_mod._AgentRun() + run.events = [ + {"type": "stream_delta", "delta": "hello"}, + {"type": "thinking_delta", "delta": "hmm"}, + ] + ws_mod._runs["s1"] = run + + mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) + + await ws_mod.websocket_session("s1", mock_websocket) + + calls = [c.args[0] for c in mock_websocket.send_json.call_args_list] + types = [c["type"] for c in calls] + + assert types == [ + "stream_start", + "replay_start", + "stream_delta", + "thinking_delta", + "replay_end", + "session_sync", + ] + assert calls[1]["count"] == 2 + + ws_mod._runs.pop("s1", None) + + +@pytest.mark.anyio +async def test_session_sync_after_reconnect_when_done(mock_websocket, mock_session, mock_user, monkeypatch): + """Reconnect when no run is active → only session_sync, no replay.""" + monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user)) + mock_store = MagicMock() + mock_store.get = AsyncMock(return_value=mock_session) + monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) + monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) + monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) + # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() + fake_container.profile_registry = None + fake_container.tool_registry = None + fake_container.backend_registry = None + fake_container.cp_registry = None + monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) + + mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) + + await ws_mod.websocket_session("s1", mock_websocket) + + calls = [c.args[0] for c in mock_websocket.send_json.call_args_list] + assert len(calls) == 1 + assert calls[0]["type"] == "session_sync" + + +@pytest.mark.anyio +async def test_session_sync_after_recall_run(mock_websocket, mock_session, mock_user, monkeypatch): + """Reconnect while a headless recall run is active → session_sync (no replay, no error).""" + monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user)) + mock_store = MagicMock() + mock_store.get = AsyncMock(return_value=mock_session) + monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) + monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) + monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) + # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() + fake_container.profile_registry = None + fake_container.tool_registry = None + fake_container.backend_registry = None + fake_container.cp_registry = None + monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) + + ws_mod._busy_sessions["s1"] = asyncio.Event() + + mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) + + await ws_mod.websocket_session("s1", mock_websocket) + + calls = [c.args[0] for c in mock_websocket.send_json.call_args_list] + types = [c["type"] for c in calls] + assert types == ["session_sync"] + + ws_mod._busy_sessions.pop("s1", None) + + +# ── Concurrent run guard ───────────────────────────────────────────────────── + +@pytest.mark.anyio +async def test_concurrent_run_guard_rejects_second_message(mock_websocket, mock_session, mock_user, monkeypatch): + """Sending a second message while a run is active yields a WebSocket error.""" + monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user)) + mock_store = MagicMock() + mock_store.get = AsyncMock(return_value=mock_session) + monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) + monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) + monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) + # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() + fake_container.profile_registry = None + fake_container.tool_registry = None + fake_container.backend_registry = None + fake_container.cp_registry = None + monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) + + # _run_agent sleeps so the run stays registered + async def fake_run_agent(*a, **kw): + await asyncio.sleep(3600) + + monkeypatch.setattr(ws_mod, "_run_agent", fake_run_agent) + + message_count = 0 + + async def fake_receive_text(): + nonlocal message_count + message_count += 1 + if message_count == 1: + return json.dumps({"type": "message", "content": "first"}) + if message_count == 2: + return json.dumps({"type": "message", "content": "second"}) + raise WebSocketDisconnect() + + mock_websocket.receive_text = fake_receive_text + + await ws_mod.websocket_session("s1", mock_websocket) + + calls = [c.args[0] for c in mock_websocket.send_json.call_args_list] + error_calls = [c for c in calls if c["type"] == "error"] + assert len(error_calls) == 1 + assert "already running" in error_calls[0]["message"] + + # Cleanup background task + run = ws_mod._runs.get("s1") + if run and run.task: + run.task.cancel() + try: + await run.task + except asyncio.CancelledError: + pass + ws_mod._runs.pop("s1", None)