Newer
Older
navi-1 / tests / integration / test_websocket.py
"""Integration tests for WebSocket endpoint."""

import asyncio
import json

import pytest

from navi.core.events import StreamEnd, TextDelta


def _get_orchestrator(mock_deps):
    """Return the orchestrator held by ``navi.main.app``'s container.

    ``mock_deps`` is not read here. Taking it as a parameter ties every caller
    to that fixture, which presumably installs the mocked container on the
    app (TODO confirm against conftest).
    """
    from navi.main import app as navi_app

    container = navi_app.state.container
    return container.orchestrator


@pytest.fixture(autouse=True)
def _clear_runs(mock_deps):
    """Reset orchestrator session state before each test in this module.

    Any run task still attached to a session is cancelled (not awaited).
    Both the session registry and the per-session lock registry are then
    emptied.
    """
    orchestrator = _get_orchestrator(mock_deps)
    # Work on a snapshot of the session states taken before anything is cleared.
    for state in list(orchestrator._sessions.values()):
        run = state.run if state else None
        task = run.task if run else None
        if task:
            task.cancel()
    orchestrator._sessions.clear()
    orchestrator._session_locks.clear()
    yield


@pytest.fixture
def fake_agent_ws(monkeypatch, mock_deps):
    """Swap ``orchestrator.run_agent`` for a scripted, deterministic stand-in.

    The stand-in looks up the session's run. If there is none it returns
    immediately. Otherwise it broadcasts ``TextDelta(delta="Hello")`` and then
    ``StreamEnd(full_content="Hello")``, each wrapped as an ``("event", ...)``
    tuple. The fixture returns the patched callable.
    """
    orchestrator = _get_orchestrator(mock_deps)

    async def fake_run_agent(session_id, user_content, raw_images, display_content, files, session_store):
        active_run = orchestrator.get_run(session_id)
        if active_run is not None:

            async def emit(event):
                await active_run.broadcast(("event", event))

            # Each event is built only right before it is sent.
            await emit(TextDelta(delta="Hello"))
            await emit(StreamEnd(full_content="Hello"))

    monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent)
    return fake_run_agent


class TestWebSocketConnect:
    """WebSocket connection tests: bad session, messaging, replay, bad input."""

    @staticmethod
    def _expect_session_sync(ws) -> None:
        """Consume the first frame of a fresh connection and check it is session_sync."""
        first = ws.receive_json()
        assert first["type"] == "session_sync"

    def test_invalid_session(self, client):
        from starlette.testclient import WebSocketDisconnect

        with pytest.raises(WebSocketDisconnect), client.websocket_connect("/ws/sessions/nonexistent"):
            pass

    @pytest.mark.anyio
    async def test_send_message(self, client, make_session, fake_agent_ws):
        session = await make_session("secretary")
        url = f"/ws/sessions/{session.id}"
        with client.websocket_connect(url) as ws:
            self._expect_session_sync(ws)
            ws.send_json({"type": "message", "content": "hi"})
            # The handler emits stream_start, then the fake agent emits
            # stream_delta followed by stream_end.
            received: list[dict] = [ws.receive_json() for _ in range(3)]

        kinds = [frame["type"] for frame in received]
        assert "stream_start" in kinds
        assert "stream_delta" in kinds
        assert "stream_end" in kinds

    @pytest.mark.anyio
    async def test_reconnect_replay(self, client, make_session, mock_deps):
        """Reconnect while a run is active — replay buffer should emit past events."""
        orchestrator = _get_orchestrator(mock_deps)
        session = await make_session("secretary")

        # Seed an active run whose replay buffer already holds two events.
        active = orchestrator.create_run(session.id)
        active.events = [
            {"type": "stream_start"},
            {"type": "stream_delta", "delta": "hello"},
        ]

        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            received = _collect_until_done(ws, max_messages=5)

        kinds = [frame["type"] for frame in received]
        for expected in ("stream_start", "replay_start", "stream_delta", "replay_end"):
            assert expected in kinds

        # Drop the injected run so it does not linger in the registry.
        orchestrator._sessions.pop(session.id, None)

    @pytest.mark.anyio
    async def test_invalid_json(self, client, make_session):
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            self._expect_session_sync(ws)
            ws.send_text("not json")
            assert ws.receive_json()["type"] == "error"

    @pytest.mark.anyio
    async def test_missing_content(self, client, make_session):
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            self._expect_session_sync(ws)
            ws.send_json({"type": "message"})
            assert ws.receive_json()["type"] == "error"


class TestStopSession:
    """Tests for ``POST /sessions/{id}/stop`` with and without an active run."""

    @pytest.mark.anyio
    async def test_stop_no_active_run(self, client, make_session, mock_deps):
        """Stopping a session that has no run responds 200 with ``ok`` False."""
        session = await make_session("secretary")
        response = client.post(f"/sessions/{session.id}/stop")
        assert response.status_code == 200
        data = response.json()
        assert data["ok"] is False

    @pytest.mark.anyio
    async def test_stop_active_run(self, client, make_session, mock_deps):
        """Stopping a session with a running task responds ``ok`` True and sets its stop event."""
        orchestrator = _get_orchestrator(mock_deps)
        session = await make_session("secretary")

        # Stand-in for a long-running agent task, created on this test's loop.
        run = orchestrator.create_run(session.id)
        task = asyncio.create_task(asyncio.sleep(10))
        run.task = task

        try:
            response = client.post(f"/sessions/{session.id}/stop")
            assert response.status_code == 200
            data = response.json()
            assert data["ok"] is True
            assert run.stop_event.is_set()
        finally:
            # Cancel and reap the task on this loop even if an assertion
            # failed. Otherwise it stays pending after the test's loop is
            # done with it.
            task.cancel()
            try:
                await task
            except asyncio.CancelledError:
                pass
            orchestrator._sessions.pop(session.id, None)


# ── Helpers ──────────────────────────────────────────────────────────────────

def _collect_until_done(ws, max_messages: int = 10) -> list[dict]:
    """Collect websocket messages until stream_end, error, or max messages."""
    msgs: list[dict] = []
    for _ in range(max_messages):
        try:
            raw = ws.receive_text()
        except Exception:
            break
        try:
            msg = json.loads(raw)
        except json.JSONDecodeError:
            continue
        msgs.append(msg)
        if msg.get("type") in ("stream_end", "error", "stream_stopped", "session_sync"):
            break
    return msgs