# Newer
# Older
# navi-1 / tests / unit / api / test_websocket.py
# @Eugene Sukhodolskiy Eugene Sukhodolskiy on 25 May 8 KB Fix 19 issues found in full codebase review
"""Unit tests for WebSocket handler internals and reconnect logic."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import WebSocketDisconnect

from navi.api import websocket as ws_mod
from navi.core.orchestrator import AgentSessionOrchestrator, SessionRun


@pytest.fixture(autouse=True)
def _clear_state(monkeypatch):
    """Autouse per-test hook for every test in this module.

    Despite its name, this fixture does not currently reset anything: it
    only yields, so there is no setup or teardown. The WebSocket tests in
    this module each construct their own ``AgentSessionOrchestrator``, so
    orchestrator session state is not shared between them.

    NOTE(review): if module-level state in ``navi.api.websocket`` ever needs
    resetting between tests, do it here (for example with
    ``monkeypatch.setattr``). Confirm whether any such global state exists.
    """
    yield


@pytest.fixture
def mock_websocket():
    """Provide an AsyncMock standing in for a FastAPI WebSocket.

    ``accept``, ``close`` and ``send_json`` are each set to a dedicated
    AsyncMock so tests can inspect the awaited calls (e.g. the frames sent
    via ``send_json.call_args_list``).
    """
    socket = AsyncMock()
    for method_name in ("accept", "close", "send_json"):
        setattr(socket, method_name, AsyncMock())
    return socket


@pytest.fixture
def mock_session():
    """Provide a stand-in session object owned by ``"test-user-id"``."""
    return MagicMock(user_id="test-user-id")


@pytest.fixture
def mock_user():
    """Provide a stand-in admin user whose id matches ``mock_session``'s owner."""
    return MagicMock(id="test-user-id", role="admin")


@pytest.fixture
def fake_orchestrator():
    """Build an ``AgentSessionOrchestrator`` over a stub container.

    The stub's registries, memory store and MCP manager are all ``None`` and
    its worker list is empty.
    """
    stub_container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
        workers=[],
        memory_store=None,
        mcp_manager=None,
    )
    return AgentSessionOrchestrator(stub_container)


# ── SessionRun buffer tests ─────────────────────────────────────────────────

@pytest.mark.anyio
async def test_event_buffer_appended_and_replayed():
    """Broadcast stores serialised events; oldest evicted when cap exceeded.

    Events are stored as the dict returned by their ``to_wire()`` method.
    After more than ``buffer_cap`` broadcasts, only the newest ``buffer_cap``
    entries remain.
    """
    buffer_cap = 500  # expected SessionRun buffer limit
    overflow = 5      # how many events to push past the limit

    class FakeEvent:
        """Minimal event exposing the ``to_wire()`` hook used by broadcast."""

        def __init__(self, idx: int) -> None:
            self.idx = idx

        def to_wire(self) -> dict:
            return {"type": "stream_delta", "delta": str(self.idx)}

    def wire(idx):
        # Serialised form FakeEvent(idx) is expected to take in the buffer.
        return {"type": "stream_delta", "delta": str(idx)}

    run = SessionRun()

    for idx in range(3):
        await run.broadcast(("event", FakeEvent(idx)))

    assert len(run.events) == 3
    assert run.events[0] == wire(0)
    assert run.events[2] == wire(2)

    # Overfill the buffer: the first `overflow` events must be evicted.
    run.events.clear()
    for idx in range(buffer_cap + overflow):
        await run.broadcast(("event", FakeEvent(idx)))

    assert len(run.events) == buffer_cap
    assert run.events[0] == wire(overflow)
    assert run.events[-1] == wire(buffer_cap + overflow - 1)


# ── Reconnect / replay tests ─────────────────────────────────────────────────

@pytest.mark.anyio
async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch):
    """Attaching to a session whose run is still active replays its buffer.

    Expected frame order: stream_start, replay_start (whose ``count`` is the
    number of buffered events), each buffered event in order, replay_end,
    then session_sync.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orchestrator = container.orchestrator
    active_run = orchestrator.create_run("s1")
    active_run.events = [
        {"type": "stream_delta", "delta": "hello"},
        {"type": "thinking_delta", "delta": "hmm"},
    ]

    # The very first receive_text() call simulates the client going away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    expected_types = [
        "stream_start",
        "replay_start",
        "stream_delta",
        "thinking_delta",
        "replay_end",
        "session_sync",
    ]
    assert [frame["type"] for frame in sent] == expected_types
    assert sent[1]["count"] == 2

    orchestrator._sessions.pop("s1", None)


@pytest.mark.anyio
async def test_session_sync_after_reconnect_when_done(mock_websocket, mock_session, mock_user, monkeypatch):
    """With no run registered for the session, reconnecting sends one frame.

    That single frame is session_sync; nothing is replayed.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # The very first receive_text() call simulates the client going away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert len(sent) == 1
    assert sent[0]["type"] == "session_sync"


@pytest.mark.anyio
async def test_session_sync_after_recall_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """Reconnecting while the session is only marked busy sends just session_sync.

    ``mark_busy`` stands in for a headless recall run. The client must get
    neither a replay nor an error frame.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orchestrator = container.orchestrator
    orchestrator.mark_busy("s1")

    # The very first receive_text() call simulates the client going away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    frame_types = [call.args[0]["type"] for call in mock_websocket.send_json.call_args_list]
    assert frame_types == ["session_sync"]

    await orchestrator.clear_busy("s1")


# ── Concurrent run guard ─────────────────────────────────────────────────────

@pytest.mark.anyio
async def test_concurrent_run_guard_rejects_second_message(mock_websocket, mock_session, mock_user, monkeypatch):
    """Sending a second message while a run is active yields a WebSocket error.

    The first message starts a run whose agent never finishes on its own, so
    the second message must produce exactly one error frame mentioning
    "already running". The background run task is always cancelled at the
    end, even if an assertion fails.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)

    orchestrator = fake_container.orchestrator

    # run_agent sleeps so the run stays registered while the second message arrives.
    async def fake_run_agent(*a, **kw):
        await asyncio.sleep(3600)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    # Scripted client: two messages, then every further read disconnects.
    scripted_frames = iter([
        json.dumps({"type": "message", "content": "first"}),
        json.dumps({"type": "message", "content": "second"}),
    ])

    async def fake_receive_text():
        try:
            return next(scripted_frames)
        except StopIteration:
            raise WebSocketDisconnect() from None

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)

        calls = [c.args[0] for c in mock_websocket.send_json.call_args_list]
        error_calls = [c for c in calls if c["type"] == "error"]
        assert len(error_calls) == 1
        assert "already running" in error_calls[0]["message"]
    finally:
        # Cleanup runs even on assertion failure. Otherwise the 3600s sleep
        # task would be left pending on the event loop.
        state = orchestrator._sessions.get("s1")
        task = state.run.task if state and state.run else None
        if task:
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
        orchestrator._sessions.pop("s1", None)