"""Unit tests for WebSocket handler internals and reconnect logic."""
import asyncio
import json
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from fastapi import WebSocketDisconnect
from navi.api import websocket as ws_mod
from navi.core.orchestrator import AgentSessionOrchestrator, SessionRun
@pytest.fixture(autouse=True)
def _clear_state(monkeypatch):
    """Autouse per-test hook intended for resetting global WebSocket state.

    NOTE(review): despite the intent above, this fixture currently performs
    no setup or teardown — it only yields control to the test. The tests in
    this module each build their own ``AgentSessionOrchestrator``, so there
    may be nothing global left to reset; confirm, then either add the reset
    logic (e.g. via ``monkeypatch``) or remove the fixture.
    """
    yield None
@pytest.fixture
def mock_websocket():
    """Provide an async-capable stand-in for a FastAPI ``WebSocket``.

    ``accept``, ``close`` and ``send_json`` are installed as explicit
    ``AsyncMock`` children so tests can await them and later inspect
    ``send_json.call_args_list`` to see which frames the handler sent.
    """
    socket = AsyncMock()
    for method_name in ("accept", "close", "send_json"):
        setattr(socket, method_name, AsyncMock())
    return socket
@pytest.fixture
def mock_session():
    """Return a fake session object owned by ``"test-user-id"``."""
    return MagicMock(user_id="test-user-id")
@pytest.fixture
def mock_user():
    """Return a fake authenticated admin user whose id matches ``mock_session``."""
    return MagicMock(id="test-user-id", role="admin")
@pytest.fixture
def fake_orchestrator():
    """Build a real ``AgentSessionOrchestrator`` over a stub container.

    The container is a ``MagicMock`` whose registries, memory store and MCP
    manager are ``None`` and whose worker list is empty.
    """
    container = MagicMock()
    container.configure_mock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
        workers=[],
        memory_store=None,
        mcp_manager=None,
    )
    return AgentSessionOrchestrator(container)
# ── SessionRun buffer tests ─────────────────────────────────────────────────
@pytest.mark.anyio
async def test_event_buffer_appended_and_replayed():
    """``SessionRun.broadcast`` stores each event's ``to_wire()`` dict in ``run.events``.

    The buffer holds at most 500 entries; once that cap is exceeded the
    oldest entries are dropped first.
    """
    cap = 500
    overflow = 5

    class FakeEvent:
        """Minimal event exposing the ``to_wire`` hook that broadcast serialises."""

        def __init__(self, idx: int) -> None:
            self.idx = idx

        def to_wire(self) -> dict:
            return {"type": "stream_delta", "delta": str(self.idx)}

    def wire(idx: int) -> dict:
        # Expected serialised form of FakeEvent(idx).
        return {"type": "stream_delta", "delta": str(idx)}

    run = SessionRun()

    # Below the cap: every event is kept, in broadcast order.
    for idx in range(3):
        await run.broadcast(("event", FakeEvent(idx)))
    assert len(run.events) == 3
    assert run.events[0] == wire(0)
    assert run.events[2] == wire(2)

    # Above the cap: only the newest `cap` events survive.
    run.events.clear()
    for idx in range(cap + overflow):
        await run.broadcast(("event", FakeEvent(idx)))
    assert len(run.events) == cap
    assert run.events[0] == wire(overflow)
    assert run.events[-1] == wire(cap + overflow - 1)
# ── Reconnect / replay tests ─────────────────────────────────────────────────
@pytest.mark.anyio
async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch):
    """Re-attaching to an active run replays its buffer, then syncs.

    Expected frame order: ``stream_start``, ``replay_start`` (carrying the
    buffered-event count), each buffered event, ``replay_end``, and finally
    ``session_sync``.
    """
    store = MagicMock()
    store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # Seed an active run with two already-buffered events to be replayed.
    active_run = orchestrator.create_run("s1")
    active_run.events = [
        {"type": "stream_delta", "delta": "hello"},
        {"type": "thinking_delta", "delta": "hmm"},
    ]
    # The client disconnects immediately after the handshake.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    expected_types = [
        "stream_start",
        "replay_start",
        "stream_delta",
        "thinking_delta",
        "replay_end",
        "session_sync",
    ]
    assert [frame["type"] for frame in sent] == expected_types
    assert sent[1]["count"] == 2

    orchestrator._sessions.pop("s1", None)
@pytest.mark.anyio
async def test_session_sync_after_reconnect_when_done(mock_websocket, mock_session, mock_user, monkeypatch):
    """Reconnecting with no active run sends exactly one ``session_sync`` and no replay."""
    store = MagicMock()
    store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    container.orchestrator = AgentSessionOrchestrator(container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    # The client disconnects immediately after the handshake.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert len(sent) == 1
    assert sent[0]["type"] == "session_sync"
@pytest.mark.anyio
async def test_session_sync_after_recall_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """Reconnecting while a headless recall run holds the session sends only ``session_sync``.

    The session is marked busy without creating a streaming run, so there is
    nothing to replay. The handler must neither replay nor report an error.
    """
    store = MagicMock()
    store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))

    container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
    )
    orchestrator = AgentSessionOrchestrator(container)
    container.orchestrator = orchestrator
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: container)

    orchestrator.mark_busy("s1")
    # The client disconnects immediately after the handshake.
    mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect())

    await ws_mod.websocket_session("s1", mock_websocket)

    sent = [call.args[0] for call in mock_websocket.send_json.call_args_list]
    assert [frame["type"] for frame in sent] == ["session_sync"]

    await orchestrator.clear_busy("s1")
# ── Concurrent run guard ─────────────────────────────────────────────────────
@pytest.mark.anyio
async def test_concurrent_run_guard_rejects_second_message(mock_websocket, mock_session, mock_user, monkeypatch):
    """A second message sent while a run is active yields exactly one WebSocket error.

    The error frame's message must mention "already running". The
    long-sleeping background run is always cancelled in ``finally``, even
    when an assertion fails.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))
    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)
    orchestrator = fake_container.orchestrator

    # _start_agent_run will create a background task; the fake run_agent sleeps
    # so the run stays registered long enough for the second message to be
    # rejected.
    async def fake_run_agent(*a, **kw):
        await asyncio.sleep(3600)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    message_count = 0

    async def fake_receive_text():
        # Two user messages back to back, then a client disconnect.
        nonlocal message_count
        message_count += 1
        if message_count == 1:
            return json.dumps({"type": "message", "content": "first"})
        if message_count == 2:
            return json.dumps({"type": "message", "content": "second"})
        raise WebSocketDisconnect()

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)

        calls = [c.args[0] for c in mock_websocket.send_json.call_args_list]
        error_calls = [c for c in calls if c["type"] == "error"]
        assert len(error_calls) == 1
        assert "already running" in error_calls[0]["message"]
    finally:
        # Cancel the sleeping background run even if an assertion above
        # failed; otherwise the hour-long task could be left pending on the
        # event loop.
        state = orchestrator._sessions.get("s1")
        if state and state.run and state.run.task:
            state.run.task.cancel()
            try:
                await state.run.task
            except asyncio.CancelledError:
                pass
        orchestrator._sessions.pop("s1", None)
# ── Form submit tests ──────────────────────────────────────────────────────
@pytest.mark.anyio
async def test_form_submit_starts_hidden_agent_run(mock_websocket, mock_session, mock_user, monkeypatch):
    """A valid ``form_submit`` starts the agent with a hidden user message.

    ``run_agent`` must be called with ``hidden=True``. Its second positional
    argument must contain both the form id and the submitted values.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))
    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)
    orchestrator = fake_container.orchestrator

    # Upper bound on event-loop ticks to wait for the background run to start.
    max_ticks = 10
    captured = {}

    async def fake_run_agent(*a, **kw):
        captured["args"] = list(a)
        captured["kwargs"] = dict(kw)
        await asyncio.sleep(0)

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    message_count = 0

    async def fake_receive_text():
        # One complete form submission, then a client disconnect.
        nonlocal message_count
        message_count += 1
        if message_count == 1:
            return json.dumps({
                "type": "form_submit",
                "form_id": "form_1",
                "values": {"district": "Pechersk", "budget": 80000},
            })
        raise WebSocketDisconnect()

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)

        # _start_agent_run launches run_agent as a background task, so it may
        # not have executed yet. Yield to the event loop a bounded number of
        # times instead of depending on exactly one tick.
        for _ in range(max_ticks):
            if captured:
                break
            await asyncio.sleep(0)

        assert captured, "run_agent was never invoked for the form_submit frame"
        assert captured["kwargs"].get("hidden") is True
        assert "form_1" in captured["args"][1]
        assert "Pechersk" in captured["args"][1]
    finally:
        orchestrator._sessions.pop("s1", None)
@pytest.mark.anyio
async def test_form_submit_rejects_missing_form_id_or_values(mock_websocket, mock_session, mock_user, monkeypatch):
    """Incomplete ``form_submit`` frames each produce an error and start no run.

    One frame lacks ``values`` and the other lacks ``form_id``. The test
    expects exactly two error frames and no call to ``run_agent``.
    """
    monkeypatch.setattr(ws_mod, "get_current_user_ws", AsyncMock(return_value=mock_user))
    mock_store = MagicMock()
    mock_store.get = AsyncMock(return_value=mock_session)
    monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store)
    monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True))
    fake_container = MagicMock()
    fake_container.profile_registry = None
    fake_container.tool_registry = None
    fake_container.backend_registry = None
    fake_container.cp_registry = None
    fake_container.orchestrator = AgentSessionOrchestrator(fake_container)
    monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container)
    orchestrator = fake_container.orchestrator

    # Event-loop ticks to yield so a wrongly scheduled background run gets a
    # chance to execute before we assert it never started.
    settle_ticks = 5
    run_started = False

    async def fake_run_agent(*a, **kw):
        nonlocal run_started
        run_started = True

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)

    message_count = 0

    async def fake_receive_text():
        # Frame 1 lacks "values", frame 2 lacks "form_id"; then disconnect.
        nonlocal message_count
        message_count += 1
        if message_count == 1:
            return json.dumps({"type": "form_submit", "form_id": "form_1"})
        if message_count == 2:
            return json.dumps({"type": "form_submit", "values": {"x": 1}})
        raise WebSocketDisconnect()

    mock_websocket.receive_text = fake_receive_text

    try:
        await ws_mod.websocket_session("s1", mock_websocket)

        # run_agent would be launched as a background task. Without yielding,
        # `run_started` could still be False even if a run had been scheduled.
        for _ in range(settle_ticks):
            await asyncio.sleep(0)

        calls = [c.args[0] for c in mock_websocket.send_json.call_args_list]
        error_calls = [c for c in calls if c["type"] == "error"]
        assert len(error_calls) == 2
        assert not run_started
    finally:
        orchestrator._sessions.pop("s1", None)