"""Integration tests for WebSocket endpoint."""
import asyncio
import json
import pytest
from fastapi.testclient import TestClient
from navi.core.events import StreamEnd, TextDelta
from navi.llm.base import Message
class FakeAgent:
    """Scripted stand-in for the real agent, used by the WebSocket tests.

    Both entry points ignore their inputs and replay whatever was handed to
    the constructor, so handler output is fully predictable.
    """

    def __init__(self, stream_events=None, run_response="Hello") -> None:
        """Store the canned stream events and the canned ``run`` reply.

        Args:
            stream_events: Events that ``run_stream`` yields, in order. A
                falsy value (``None`` or an empty sequence) means "no events".
            run_response: String returned by ``run``.
        """
        self._stream_events = stream_events if stream_events else []
        self._run_response = run_response

    async def run(self, session_id: str, user_message: str, images=None) -> str:
        """Return the canned response; every argument is ignored."""
        canned = self._run_response
        return canned

    async def run_stream(self, session_id, user_message, images=None, display_message=None):
        """Yield the canned events one by one; every argument is ignored."""
        scripted = self._stream_events
        for event in scripted:
            yield event
@pytest.fixture(autouse=True)
def _clear_runs(monkeypatch):
    """Empty the websocket module's ``_runs`` registry before each test here.

    Autouse, so every test in this module starts without run entries left
    behind by an earlier test. Nothing is done after the ``yield``.
    """
    import navi.api.websocket as ws_mod

    registry = ws_mod._runs
    registry.clear()
    yield
@pytest.fixture
def fake_agent_ws(monkeypatch, mock_deps):
    """Swap ``Agent`` in the websocket module for a scripted ``FakeAgent``.

    The fake streams a single ``TextDelta("Hello")`` followed by
    ``StreamEnd("Hello")``. Any construction of ``Agent`` inside the websocket
    module returns the same instance, which is also the fixture value.
    ``mock_deps`` is requested only so that fixture is active (it is defined
    outside this file).
    """
    import navi.api.websocket as ws_mod

    agent = FakeAgent(
        stream_events=[
            TextDelta(delta="Hello"),
            StreamEnd(full_content="Hello"),
        ]
    )

    def _agent_factory(*args, **kwargs):
        # Constructor arguments are irrelevant: always hand back the shared fake.
        return agent

    monkeypatch.setattr(ws_mod, "Agent", _agent_factory)
    return agent
class TestWebSocketConnect:
    """Tests for the ``/ws/sessions/{id}`` WebSocket endpoint."""

    def test_invalid_session(self, client):
        """Connecting with an unknown session id must raise a disconnect."""
        from starlette.testclient import WebSocketDisconnect

        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws/sessions/nonexistent"):
                pass

    @pytest.mark.anyio
    async def test_send_message(self, client, make_session, fake_agent_ws):
        """A user message produces stream_start, stream_delta and stream_end."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_json({"type": "message", "content": "hi"})
            # FakeAgent emits: stream_start (handler) → stream_delta → stream_end
            msgs: list[dict] = [ws.receive_json() for _ in range(3)]
            types = [m["type"] for m in msgs]
            assert "stream_start" in types
            assert "stream_delta" in types
            assert "stream_end" in types

    @pytest.mark.anyio
    async def test_reconnect_replay(self, client, make_session, monkeypatch):
        """Reconnect while a run is active — replay buffer should emit past events."""
        import navi.api.websocket as ws_mod

        session = await make_session("secretary")
        # Inject an active run with buffered events
        run = ws_mod._AgentRun()
        run.events = [
            {"type": "stream_start"},
            {"type": "stream_delta", "delta": "hello"},
        ]
        ws_mod._runs[session.id] = run
        try:
            with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
                msgs = _collect_until_done(ws, max_messages=5)
                types = [m["type"] for m in msgs]
                assert "stream_start" in types
                assert "replay_start" in types
                assert "stream_delta" in types
                assert "replay_end" in types
        finally:
            # Clean up the injected run even when an assertion above fails, so
            # the module-level registry never keeps a run this test created.
            ws_mod._runs.pop(session.id, None)
            if run.task:
                run.task.cancel()

    @pytest.mark.anyio
    async def test_invalid_json(self, client, make_session):
        """A non-JSON text frame is answered with an error message."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_text("not json")
            msg = ws.receive_json()
            assert msg["type"] == "error"

    @pytest.mark.anyio
    async def test_missing_content(self, client, make_session):
        """A "message" frame without a content field is answered with an error."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_json({"type": "message"})
            msg = ws.receive_json()
            assert msg["type"] == "error"
class TestStopSession:
    """Tests for the ``POST /sessions/{id}/stop`` endpoint."""

    @pytest.mark.anyio
    async def test_stop_no_active_run(self, client, make_session):
        """Stopping a session with no registered run reports ``ok: false``."""
        session = await make_session("secretary")
        response = client.post(f"/sessions/{session.id}/stop")
        assert response.status_code == 200
        data = response.json()
        assert data["ok"] is False

    @pytest.mark.anyio
    async def test_stop_active_run(self, client, make_session, monkeypatch):
        """Stopping a session with a registered run sets its stop event."""
        import navi.api.websocket as ws_mod

        session = await make_session("secretary")
        # Register a run whose task is a long sleep standing in for agent work
        run = ws_mod._AgentRun()
        run.task = asyncio.create_task(asyncio.sleep(10))
        ws_mod._runs[session.id] = run
        try:
            response = client.post(f"/sessions/{session.id}/stop")
            assert response.status_code == 200
            data = response.json()
            assert data["ok"] is True
            assert run.stop_event.is_set()
        finally:
            # Always undo the injection: drop the registry entry and reap the
            # sleep task, even when an assertion above fails, so no pending
            # task or stale run outlives this test.
            ws_mod._runs.pop(session.id, None)
            run.task.cancel()
            try:
                await run.task
            except asyncio.CancelledError:
                pass
# ── Helpers ──────────────────────────────────────────────────────────────────
def _collect_until_done(ws, max_messages: int = 10) -> list[dict]:
    """Collect JSON-object messages from ``ws`` until a terminal message.

    Reads at most ``max_messages`` text frames via ``ws.receive_text()``.
    Collection stops early when:

    * a message's ``"type"`` is ``stream_end``, ``error``, ``stream_stopped``
      or ``session_sync`` (that message is included in the result), or
    * ``receive_text`` raises any exception.

    Frames that are not valid JSON, or that decode to something other than a
    JSON object, are skipped but still count towards ``max_messages``.

    Args:
        ws: Object exposing ``receive_text() -> str``.
        max_messages: Upper bound on the number of frames read.

    Returns:
        The decoded messages, in arrival order.
    """
    # Kept as a tuple: membership then works even for an unhashable "type".
    terminal_types = ("stream_end", "error", "stream_stopped", "session_sync")
    msgs: list[dict] = []
    for _ in range(max_messages):
        try:
            raw = ws.receive_text()
        except Exception:
            # Deliberately best-effort: any receive failure ends collection
            # and the messages gathered so far are returned.
            break
        try:
            msg = json.loads(raw)
        except json.JSONDecodeError:
            continue
        if not isinstance(msg, dict):
            # Valid JSON but not an object (list, string, null, ...): it has
            # no "type" to inspect and does not fit the return type, so skip.
            continue
        msgs.append(msg)
        if msg.get("type") in terminal_types:
            break
    return msgs