# navi-1 / tests / unit / api / test_websocket.py
"""Unit tests for WebSocket handler internals and reconnect logic."""

import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from fastapi import WebSocketDisconnect

from navi.api import websocket as ws_mod
from navi.core.orchestrator import AgentSessionOrchestrator, SessionRun


@pytest.fixture(autouse=True)
def _clear_state(monkeypatch):
    """Autouse hook wrapped around every WS test in this module.

    It currently performs no setup and no teardown; ``monkeypatch`` is
    requested but not used in the body.
    """
    yield None


@pytest.fixture
def mock_websocket():
    """Provide an ``AsyncMock`` WebSocket with explicit awaitable stubs.

    ``accept``, ``close`` and ``send_json`` are each replaced by their own
    ``AsyncMock`` so tests can inspect the recorded calls.
    """
    socket = AsyncMock()
    for method_name in ("accept", "close", "send_json"):
        setattr(socket, method_name, AsyncMock())
    return socket


@pytest.fixture
def mock_session():
    """Provide a ``MagicMock`` session whose ``user_id`` is ``"test-user-id"``."""
    return MagicMock(user_id="test-user-id")


@pytest.fixture
def mock_user():
    """Provide a ``MagicMock`` user with id ``"test-user-id"`` and role ``"admin"``."""
    return MagicMock(id="test-user-id", role="admin")


@pytest.fixture
def fake_orchestrator():
    """Build an ``AgentSessionOrchestrator`` around a stubbed container.

    The container is a ``MagicMock`` whose registries, memory store and MCP
    manager are ``None`` and whose ``workers`` list is empty.
    """
    stub_container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
        workers=[],
        memory_store=None,
        mcp_manager=None,
    )
    return AgentSessionOrchestrator(stub_container)


# ── SessionRun buffer tests ─────────────────────────────────────────────────

@pytest.mark.anyio
async def test_event_buffer_appended_and_replayed():
    """Broadcast keeps each event's wire form; past the cap the oldest entries are dropped."""

    class FakeEvent:
        """Minimal event stand-in that only knows how to serialise itself."""

        def __init__(self, idx: int) -> None:
            self.idx = idx

        def to_wire(self) -> dict:
            return {"type": "stream_delta", "delta": str(self.idx)}

    def expected_wire(idx: int) -> dict:
        # Wire payload the buffer should hold for the event numbered ``idx``.
        return {"type": "stream_delta", "delta": str(idx)}

    run = SessionRun()

    # A handful of events are stored in order, already serialised.
    for idx in range(3):
        await run.broadcast(("event", FakeEvent(idx)))

    assert len(run.events) == 3
    assert run.events[0] == expected_wire(0)
    assert run.events[2] == expected_wire(2)

    # Overfill the buffer: the test expects it to hold at most 500 entries.
    cap = 500
    overflow = 5
    run.events.clear()
    for idx in range(cap + overflow):
        await run.broadcast(("event", FakeEvent(idx)))

    assert len(run.events) == cap
    # The first ``overflow`` events were evicted, the newest one is last.
    assert run.events[0] == expected_wire(overflow)
    assert run.events[-1] == expected_wire(cap + overflow - 1)


# ── Reconnect / replay tests ─────────────────────────────────────────────────

@pytest.mark.anyio
async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch):
    """Attaching to a session with an active run replays its buffered events.

    Expected frame order: stream_start, replay_start, the buffered events,
    replay_end, and finally session_sync.
    """
    session_id = "s1"
    buffered_events = [
        {"type": "stream_delta", "delta": "hello"},
        {"type": "thinking_delta", "delta": "hmm"},
    ]

    # Stub the collaborators the handler module looks up by name.
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    handler_stubs = {
        "get_current_user_ws": AsyncMock(return_value=mock_user),
        "get_session_store": lambda: store,
        "_stream_to_client": AsyncMock(return_value=True),
    }
    for attr_name, stub in handler_stubs.items():
        monkeypatch.setattr(ws_mod, attr_name, stub)

    # A real orchestrator wired to a container whose registries are all None.
    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # Register a run and pre-load its buffer before the client connects.
    active_run = orchestrator.create_run(session_id)
    active_run.events = buffered_events

    # The first receive raises, so the handler leaves its receive loop right away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session(session_id, mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    frame_types = [frame["type"] for frame in sent_frames]

    assert frame_types == [
        "stream_start",
        "replay_start",
        "stream_delta",
        "thinking_delta",
        "replay_end",
        "session_sync",
    ]
    replay_start_frame = sent_frames[1]
    assert replay_start_frame["count"] == 2

    orchestrator._sessions.pop(session_id, None)


@pytest.mark.anyio
async def test_session_sync_after_reconnect_when_done(mock_websocket, mock_session, mock_user, monkeypatch):
    """With no run registered for the session, connecting sends exactly one session_sync frame."""
    # Stub the collaborators the handler module looks up by name.
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    handler_stubs = {
        "get_current_user_ws": AsyncMock(return_value=mock_user),
        "get_session_store": lambda: store,
        "_stream_to_client": AsyncMock(return_value=True),
    }
    for attr_name, stub in handler_stubs.items():
        monkeypatch.setattr(ws_mod, attr_name, stub)

    # Fresh orchestrator: no run is ever created for "s1" in this test.
    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # The first receive raises, so the handler leaves its receive loop right away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert len(sent_frames) == 1
    assert sent_frames[0]["type"] == "session_sync"


@pytest.mark.anyio
async def test_session_sync_after_recall_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """Connecting while the session is only marked busy yields session_sync alone.

    The session is flagged via ``mark_busy`` (the headless recall-run case)
    without a run being created, so no replay or error frame is expected.
    """
    session_id = "s1"

    # Stub the collaborators the handler module looks up by name.
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    handler_stubs = {
        "get_current_user_ws": AsyncMock(return_value=mock_user),
        "get_session_store": lambda: store,
        "_stream_to_client": AsyncMock(return_value=True),
    }
    for attr_name, stub in handler_stubs.items():
        monkeypatch.setattr(ws_mod, attr_name, stub)

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orchestrator.mark_busy(session_id)

    # The first receive raises, so the handler leaves its receive loop right away.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session(session_id, mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert [frame["type"] for frame in sent_frames] == ["session_sync"]

    await orchestrator.clear_busy(session_id)


# ── Concurrent run guard ─────────────────────────────────────────────────────

@pytest.mark.anyio
async def test_concurrent_run_guard_rejects_second_message(mock_websocket, mock_session, mock_user, monkeypatch):
    """Sending a second message while a run is active yields a WebSocket error.

    The client sends two ``message`` frames and then disconnects.  The first
    frame starts a run that never finishes on its own; the second must be
    answered with exactly one ``error`` frame mentioning "already running".
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)

    orchestrator = fake_container.orchestrator

    # _start_agent_run will create a background task; fake run_agent sleeps so
    # the run stays registered long enough for the second message to be rejected.
    async def fake_run_agent(*a, **kw):
        await asyncio.sleep(3600)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    message_count = 0

    async def fake_receive_text():
        # Two user messages, then the client disconnects.
        nonlocal message_count
        message_count += 1
        if message_count == 1:
            return json.dumps({"type": "message", "content": "first"})
        if message_count == 2:
            return json.dumps({"type": "message", "content": "second"})
        raise WebSocketDisconnect()

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)

        calls = [c.args[0] for c in mock_websocket.send_json.call_args_list]
        error_calls = [c for c in calls if c["type"] == "error"]
        assert len(error_calls) == 1
        assert "already running" in error_calls[0]["message"]
    finally:
        # Cleanup runs even when the handler raises or an assertion fails;
        # otherwise the hour-long sleeping task would be left pending on the
        # event loop.
        state = orchestrator._sessions.get("s1")
        if state and state.run and state.run.task:
            state.run.task.cancel()
            try:
                await state.run.task
            except asyncio.CancelledError:
                # Expected: we cancelled the task ourselves just above.
                pass
        orchestrator._sessions.pop("s1", None)


# ── Form submit tests ──────────────────────────────────────────────────────

@pytest.mark.anyio
async def test_form_submit_starts_hidden_agent_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """A complete form_submit frame starts the agent with ``hidden=True``.

    The text handed to ``run_agent`` (its second positional argument) must
    mention both the form id and the submitted value.
    """
    session_id = "s1"

    # Stub the collaborators the handler module looks up by name.
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    handler_stubs = {
        "get_current_user_ws": AsyncMock(return_value=mock_user),
        "get_session_store": lambda: store,
        "_stream_to_client": AsyncMock(return_value=True),
    }
    for attr_name, stub in handler_stubs.items():
        monkeypatch.setattr(ws_mod, attr_name, stub)

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    recorded = {}

    async def fake_run_agent(*args, **kwargs):
        # Remember how the agent was invoked, then yield once to the loop.
        recorded["args"] = list(args)
        recorded["kwargs"] = dict(kwargs)
        await asyncio.sleep(0)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    # One valid form_submit frame; every later receive is a disconnect.
    incoming = iter([
        json.dumps({
            "type": "form_submit",
            "form_id": "form_1",
            "values": {"district": "Pechersk", "budget": 80000},
        }),
    ])

    async def fake_receive_text():
        frame = next(incoming, None)
        if frame is None:
            raise WebSocketDisconnect()
        return frame

    mock_websocket.receive_text = fake_receive_text

    await ws_mod.websocket_session(session_id, mock_websocket)

    # _start_agent_run launches run_agent as a background task; give it one
    # event-loop tick so the fake coroutine populates `recorded` before we assert.
    await asyncio.sleep(0)

    assert recorded["kwargs"].get("hidden") is True
    prompt_text = recorded["args"][1]
    assert "form_1" in prompt_text
    assert "Pechersk" in prompt_text

    orchestrator._sessions.pop(session_id, None)


@pytest.mark.anyio
async def test_form_submit_rejects_missing_form_id_or_values(mock_websocket, mock_session, mock_user, monkeypatch):
    """form_submit frames lacking ``values`` or ``form_id`` each get an error and start no run."""
    session_id = "s1"

    # Stub the collaborators the handler module looks up by name.
    store = MagicMock(get=AsyncMock(return_value=mock_session))
    handler_stubs = {
        "get_current_user_ws": AsyncMock(return_value=mock_user),
        "get_session_store": lambda: store,
        "_stream_to_client": AsyncMock(return_value=True),
    }
    for attr_name, stub in handler_stubs.items():
        monkeypatch.setattr(ws_mod, attr_name, stub)

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # Any invocation of run_agent leaves a marker here.
    agent_invocations = []

    async def fake_run_agent(*args, **kwargs):
        agent_invocations.append(True)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    # First frame has no values, second has no form_id; afterwards disconnect.
    incoming = iter([
        json.dumps({"type": "form_submit", "form_id": "form_1"}),
        json.dumps({"type": "form_submit", "values": {"x": 1}}),
    ])

    async def fake_receive_text():
        frame = next(incoming, None)
        if frame is None:
            raise WebSocketDisconnect()
        return frame

    mock_websocket.receive_text = fake_receive_text

    await ws_mod.websocket_session(session_id, mock_websocket)

    sent_frames = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    error_frames = [frame for frame in sent_frames if frame["type"] == "error"]
    assert len(error_frames) == 2
    assert not agent_invocations

    orchestrator._sessions.pop(session_id, None)