"""Integration tests for WebSocket endpoint."""
import asyncio
import json
import pytest
from navi.core.events import StreamEnd, TextDelta
def _get_orchestrator(mock_deps):
    """Return the orchestrator held by the application's container.

    ``mock_deps`` is not read here. Callers pass it so that the fixture has
    already run before ``app.state.container`` is accessed (presumably that
    fixture wires the container — confirm in conftest).
    """
    from navi.main import app

    container = app.state.container
    return container.orchestrator
@pytest.fixture(autouse=True)
def _clear_runs(mock_deps):
    """Reset orchestrator session state before every test in this module.

    For each session whose run has a task, cancellation is requested (not
    awaited). Then all sessions and their locks are dropped so the test
    starts from an empty orchestrator.
    """
    orchestrator = _get_orchestrator(mock_deps)
    sessions = orchestrator._sessions
    # Snapshot the keys so the dict is not iterated while being inspected.
    for sid in list(sessions):
        state = sessions.get(sid)
        task = state.run.task if state and state.run else None
        if task:
            task.cancel()
    sessions.clear()
    orchestrator._session_locks.clear()
    yield
@pytest.fixture
def fake_agent_ws(monkeypatch, mock_deps):
    """Replace ``orchestrator.run_agent`` with a scripted coroutine.

    The replacement ignores every input except ``session_id``. When
    ``orchestrator.get_run(session_id)`` returns a run, it broadcasts
    ``TextDelta(delta="Hello")`` and then ``StreamEnd(full_content="Hello")``,
    each wrapped as an ``("event", ...)`` tuple. Otherwise it does nothing.

    Returns:
        The replacement coroutine function, which monkeypatch restores
        after the test.
    """
    orchestrator = _get_orchestrator(mock_deps)

    async def fake_run_agent(session_id, user_content, raw_images, display_content, files, session_store):
        active = orchestrator.get_run(session_id)
        if active is not None:
            await active.broadcast(("event", TextDelta(delta="Hello")))
            await active.broadcast(("event", StreamEnd(full_content="Hello")))

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)
    return fake_run_agent
class TestWebSocketConnect:
    """Connection, messaging, replay and input-validation tests for /ws/sessions/{id}."""

    @staticmethod
    def _expect_session_sync(ws):
        """Assert that the first frame on a fresh connection is ``session_sync``."""
        first = ws.receive_json()
        assert first["type"] == "session_sync"

    def test_invalid_session(self, client):
        """Connecting to an unknown session id ends in a WebSocket disconnect."""
        from starlette.testclient import WebSocketDisconnect

        with pytest.raises(WebSocketDisconnect):
            with client.websocket_connect("/ws/sessions/nonexistent"):
                pass

    @pytest.mark.anyio
    async def test_send_message(self, client, make_session, fake_agent_ws):
        """A user message produces stream_start, stream_delta and stream_end frames."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            self._expect_session_sync(ws)
            ws.send_json({"type": "message", "content": "hi"})
            # Expected frames: stream_start (sent by the handler), then
            # stream_delta / stream_end from fake_agent_ws's TextDelta / StreamEnd.
            msgs = [ws.receive_json() for _ in range(3)]
            types = [m["type"] for m in msgs]
            assert "stream_start" in types
            assert "stream_delta" in types
            assert "stream_end" in types

    @pytest.mark.anyio
    async def test_reconnect_replay(self, client, make_session, mock_deps):
        """Reconnect while a run is active — replay buffer should emit past events."""
        orchestrator = _get_orchestrator(mock_deps)
        session = await make_session("secretary")
        try:
            # Inject an active run with already-buffered events to replay.
            run = orchestrator.create_run(session.id)
            run.events = [
                {"type": "stream_start"},
                {"type": "stream_delta", "delta": "hello"},
            ]
            with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
                msgs = _collect_until_done(ws, max_messages=5)
                types = [m["type"] for m in msgs]
                assert "replay_start" in types
                assert "stream_start" in types
                assert "stream_delta" in types
                assert "replay_end" in types
        finally:
            # Drop the injected run even when an assertion above fails.
            orchestrator._sessions.pop(session.id, None)

    @pytest.mark.anyio
    async def test_invalid_json(self, client, make_session):
        """A non-JSON text frame is answered with an ``error`` frame."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            self._expect_session_sync(ws)
            ws.send_text("not json")
            msg = ws.receive_json()
            assert msg["type"] == "error"

    @pytest.mark.anyio
    async def test_missing_content(self, client, make_session):
        """A ``message`` frame without ``content`` is answered with an ``error`` frame."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            self._expect_session_sync(ws)
            ws.send_json({"type": "message"})
            msg = ws.receive_json()
            assert msg["type"] == "error"
class TestStopSession:
    """Tests for ``POST /sessions/{id}/stop``."""

    @pytest.mark.anyio
    async def test_stop_no_active_run(self, client, make_session, mock_deps):
        """Stopping a session that has no run returns 200 with ``ok: False``."""
        session = await make_session("secretary")
        response = client.post(f"/sessions/{session.id}/stop")
        assert response.status_code == 200
        data = response.json()
        assert data["ok"] is False

    @pytest.mark.anyio
    async def test_stop_active_run(self, client, make_session, mock_deps):
        """Stopping an active run returns ``ok: True`` and sets its stop event."""
        orchestrator = _get_orchestrator(mock_deps)
        session = await make_session("secretary")
        run = orchestrator.create_run(session.id)
        # Stand-in for a long-running agent task. Keep a local handle so that
        # teardown does not depend on what the endpoint leaves in run.task.
        task = asyncio.create_task(asyncio.sleep(10))
        run.task = task
        try:
            response = client.post(f"/sessions/{session.id}/stop")
            assert response.status_code == 200
            data = response.json()
            assert data["ok"] is True
            assert run.stop_event.is_set()
        finally:
            # Always reap the background task and drop the injected session,
            # so a failing assertion does not leave a pending task behind.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            orchestrator._sessions.pop(session.id, None)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _collect_until_done(ws, max_messages: int = 10) -> list[dict]:
"""Collect websocket messages until stream_end, error, or max messages."""
msgs: list[dict] = []
for _ in range(max_messages):
try:
raw = ws.receive_text()
except Exception:
break
try:
msg = json.loads(raw)
except json.JSONDecodeError:
continue
msgs.append(msg)
if msg.get("type") in ("stream_end", "error", "stream_stopped", "session_sync"):
break
return msgs