# navi-1/tests/integration/test_websocket.py
"""Integration tests for WebSocket endpoint."""

import asyncio
import json

import pytest
from fastapi.testclient import TestClient

from navi.core.events import StreamEnd, TextDelta
from navi.llm.base import Message


class FakeAgent:
    """Scripted stand-in for an agent, used by the WebSocket tests.

    ``run`` always returns a fixed string and ``run_stream`` replays a fixed
    sequence of events. The session id, message and image arguments are
    accepted so the call signatures line up, but they are never read.
    """

    def __init__(self, stream_events=None, run_response="Hello") -> None:
        self._run_response = run_response
        # A falsy value (None or an empty sequence) falls back to a new list.
        self._stream_events = stream_events if stream_events else []

    async def run(self, session_id: str, user_message: str, images=None) -> str:
        """Return the canned response, ignoring every argument."""
        canned = self._run_response
        return canned

    async def run_stream(self, session_id, user_message, images=None, display_message=None):
        """Yield the configured events one at a time, in their given order."""
        for event in self._stream_events:
            yield event


@pytest.fixture(autouse=True)
def _clear_runs(monkeypatch):
    """Empty the module-level ``_runs`` dict before and after every WS test.

    Several tests register runs in ``navi.api.websocket._runs`` by hand.
    Clearing on setup gives each test a clean registry; clearing again on
    teardown keeps entries from the last test out of whatever runs next.

    ``monkeypatch`` is requested but not used in the body; it is kept so the
    fixture signature stays unchanged.
    """
    import navi.api.websocket as ws_mod

    ws_mod._runs.clear()
    yield
    # Teardown: pytest resumes here after the test, whether it passed or failed.
    ws_mod._runs.clear()


@pytest.fixture
def fake_agent_ws(monkeypatch, mock_deps):
    """Replace ``Agent`` in the websocket module with a shared ``FakeAgent``.

    The fake streams one ``TextDelta("Hello")`` followed by a
    ``StreamEnd("Hello")``. Every call to the patched ``Agent`` returns the
    same instance, which is also the fixture's value.

    ``mock_deps`` is not referenced in the body; presumably it is requested
    for its setup side effects (confirm against conftest).
    """
    import navi.api.websocket as ws_mod

    fake = FakeAgent(
        stream_events=[TextDelta(delta="Hello"), StreamEnd(full_content="Hello")]
    )

    def _agent_factory(*args, **kwargs):
        # Constructor arguments are ignored; the one shared fake is handed back.
        return fake

    monkeypatch.setattr(ws_mod, "Agent", _agent_factory)
    return fake


class TestWebSocketConnect:
    """Tests for the ``/ws/sessions/{session_id}`` WebSocket endpoint."""

    def test_invalid_session(self, client):
        """Connecting with an unknown session id raises WebSocketDisconnect."""
        from starlette.testclient import WebSocketDisconnect

        with pytest.raises(WebSocketDisconnect), client.websocket_connect(
            "/ws/sessions/nonexistent"
        ):
            pass

    @pytest.mark.anyio
    async def test_send_message(self, client, make_session, fake_agent_ws):
        """A user message yields stream_start, stream_delta and stream_end."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_json({"type": "message", "content": "hi"})
            # FakeAgent emits: stream_start (handler) → stream_delta → stream_end
            msgs = [ws.receive_json() for _ in range(3)]

        types = [m["type"] for m in msgs]
        assert "stream_start" in types
        assert "stream_delta" in types
        assert "stream_end" in types

    @pytest.mark.anyio
    async def test_reconnect_replay(self, client, make_session, monkeypatch):
        """Reconnect while a run is active — replay buffer should emit past events."""
        import navi.api.websocket as ws_mod

        session = await make_session("secretary")

        # Inject an active run with buffered events
        run = ws_mod._AgentRun()
        run.events = [
            {"type": "stream_start"},
            {"type": "stream_delta", "delta": "hello"},
        ]
        ws_mod._runs[session.id] = run

        try:
            with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
                msgs = _collect_until_done(ws, max_messages=5)

            types = [m["type"] for m in msgs]
            assert "stream_start" in types
            assert "replay_start" in types
            assert "stream_delta" in types
            assert "replay_end" in types
        finally:
            # Clean up the injected run even when an assertion above fails,
            # so it cannot linger in the module-level registry.
            ws_mod._runs.pop(session.id, None)
            if run.task:
                run.task.cancel()

    @pytest.mark.anyio
    async def test_invalid_json(self, client, make_session):
        """A non-JSON text frame is answered with an ``error`` message."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_text("not json")
            msg = ws.receive_json()
            assert msg["type"] == "error"

    @pytest.mark.anyio
    async def test_missing_content(self, client, make_session):
        """A ``message`` frame without ``content`` is answered with ``error``."""
        session = await make_session("secretary")
        with client.websocket_connect(f"/ws/sessions/{session.id}") as ws:
            # First message on a fresh connection is session_sync
            m0 = ws.receive_json()
            assert m0["type"] == "session_sync"
            ws.send_json({"type": "message"})
            msg = ws.receive_json()
            assert msg["type"] == "error"


class TestStopSession:
    """Tests for ``POST /sessions/{session_id}/stop``."""

    @pytest.mark.anyio
    async def test_stop_no_active_run(self, client, make_session):
        """With no run registered, stop responds 200 with ``ok`` False."""
        session = await make_session("secretary")
        response = client.post(f"/sessions/{session.id}/stop")
        assert response.status_code == 200
        data = response.json()
        assert data["ok"] is False

    @pytest.mark.anyio
    async def test_stop_active_run(self, client, make_session, monkeypatch):
        """With a run registered, stop responds ``ok`` True and sets stop_event."""
        import navi.api.websocket as ws_mod

        session = await make_session("secretary")

        # Start a long-running agent task in background
        run = ws_mod._AgentRun()
        run.task = asyncio.create_task(asyncio.sleep(10))
        ws_mod._runs[session.id] = run

        try:
            response = client.post(f"/sessions/{session.id}/stop")
            assert response.status_code == 200
            data = response.json()
            assert data["ok"] is True
            assert run.stop_event.is_set()
        finally:
            # Always tear down the background task and the registry entry,
            # even when an assertion above fails; otherwise the 10-second
            # sleep task would outlive the test.
            ws_mod._runs.pop(session.id, None)
            run.task.cancel()
            try:
                await run.task
            except asyncio.CancelledError:
                pass


# ── Helpers ──────────────────────────────────────────────────────────────────

def _collect_until_done(ws, max_messages: int = 10) -> list[dict]:
    """Collect websocket messages until stream_end, error, or max messages."""
    msgs: list[dict] = []
    for _ in range(max_messages):
        try:
            raw = ws.receive_text()
        except Exception:
            break
        try:
            msg = json.loads(raw)
        except json.JSONDecodeError:
            continue
        msgs.append(msg)
        if msg.get("type") in ("stream_end", "error", "stream_stopped", "session_sync"):
            break
    return msgs