Newer
Older
voice / tests / test_server.py
"""End-to-end WebSocket tests for the TTS server."""

import asyncio
import json

import pytest
import pytest_asyncio
import websockets
from httpx import AsyncClient

from voice_tts.api.server import build_app
from voice_tts.tts.engine import DummyTTSEngine


@pytest_asyncio.fixture
async def dummy_server():
    """Run a local Uvicorn server with the dummy TTS backend.

    The server runs as a background task on the test's event loop, serving
    the app built from a ``DummyTTSEngine`` (24 kHz sample rate).

    Yields:
        A ``(ws_url, health_url)`` tuple for the running server.

    If the server task finishes during startup, or does not report
    ``started`` within the polling window (50 polls x 0.05 s = 2.5 s), the
    fixture fails the test immediately. Without this check, the tests would
    get URLs that nothing is listening on. After the test, the server is
    asked to exit (``should_exit``) and its task is awaited.
    """
    import uvicorn

    host, port = "127.0.0.1", 9876
    engine = DummyTTSEngine(sample_rate=24_000)
    app = build_app(engine)
    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    server = uvicorn.Server(config)
    task = asyncio.create_task(server.serve())

    # Wait until the server is ready, or stop early if serve() already ended.
    for _ in range(50):
        if server.started or task.done():
            break
        await asyncio.sleep(0.05)

    if not server.started:
        if task.done():
            # Retrieve the task outcome so a startup exception is reported
            # instead of being lost (or logged later as "never retrieved").
            error = None if task.cancelled() else task.exception()
            pytest.fail(f"Uvicorn server exited during startup: {error!r}")
        # Still starting up: cancel rather than rely on should_exit, which
        # serve() may not check until startup completes.
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        pytest.fail("Uvicorn server did not start within 2.5 s")

    try:
        yield f"ws://{host}:{port}/ws", f"http://{host}:{port}/health"
    finally:
        server.should_exit = True
        await task


@pytest.mark.asyncio
async def test_health(dummy_server):
    """The health endpoint answers HTTP 200 with ``{"status": "ok"}``."""
    _, health_url = dummy_server
    async with AsyncClient() as http:
        reply = await http.get(health_url)
    assert reply.status_code == 200
    body = reply.json()
    assert body["status"] == "ok"


@pytest.mark.asyncio
async def test_websocket_streaming_and_stop(dummy_server):
    """Stream text over the WebSocket, flush it, then interrupt with ``stop``.

    Sequence exercised:
      1. ``init`` -> expect a ``status`` / ``session_ready`` reply.
      2. Three ``text`` chunks, then ``flush``.
      3. Read messages until ``status`` / ``segment_finished``, counting
         ``audio`` messages. The whole phase is bounded by a single 5 s
         deadline.
      4. ``stop`` (reason ``interrupt``) -> expect ``status`` / ``stopped``,
         possibly after messages that were already in flight.
    """
    ws_url, _ = dummy_server
    received_audio = 0
    # (type, event) of every message seen before segment_finished, including
    # audio messages (whose event is None).
    received_statuses = []

    async with websockets.connect(ws_url) as ws:
        await ws.send(json.dumps({
            "type": "init",
            "session_id": "test",
            "seq": 1,
        }))
        msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5))
        assert msg["type"] == "status"
        assert msg["event"] == "session_ready"

        for i, chunk in enumerate(["Привет, ", "как ", "дела?"]):
            await ws.send(json.dumps({"type": "text", "payload": chunk, "seq": 2 + i}))
            # Give the server a moment to process each chunk sequentially,
            # mirroring how a streaming LLM would emit tokens.
            await asyncio.sleep(0.05)

        await ws.send(json.dumps({"type": "flush", "seq": 5}))

        # Collect messages until segment_finished. Each recv is bounded by
        # the time left before the deadline, so the phase takes at most 5 s
        # in total rather than 5 s per message.
        loop = asyncio.get_running_loop()
        deadline = loop.time() + 5
        while True:
            remaining = deadline - loop.time()
            if remaining <= 0:
                pytest.fail("Did not receive segment_finished status within 5 s")
            try:
                raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                pytest.fail("Did not receive segment_finished status within 5 s")
            msg = json.loads(raw)
            received_statuses.append((msg["type"], msg.get("event")))
            if msg["type"] == "audio":
                received_audio += 1
            if msg["type"] == "status" and msg.get("event") == "segment_finished":
                break

        # Interrupt
        await ws.send(json.dumps({"type": "stop", "reason": "interrupt", "seq": 6}))

        # Drain any audio/status messages already in flight before the stop is processed.
        for _ in range(20):
            msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5))
            if msg["type"] == "status" and msg.get("event") == "stopped":
                break
        else:
            pytest.fail("Did not receive stopped status")

    assert received_audio >= 1
    assert ("status", "segment_started") in received_statuses
    assert ("status", "segment_finished") in received_statuses