"""End-to-end WebSocket tests for the TTS server."""
import asyncio
import json
import pytest
import pytest_asyncio
import websockets
from httpx import AsyncClient
from voice_tts.api.server import build_app
from voice_tts.tts.engine import DummyTTSEngine
@pytest_asyncio.fixture
async def dummy_server():
    """Run a local Uvicorn server with the dummy TTS backend.

    Yields:
        A ``(ws_url, health_url)`` tuple pointing at the running server.

    The requesting test is failed during setup if the server has not set
    ``server.started`` within ~2.5 s, or if its ``serve()`` task finishes
    before reporting started. The server is always asked to shut down and
    awaited, even if setup or the test itself fails.
    """
    import uvicorn

    host, port = "127.0.0.1", 9876
    engine = DummyTTSEngine(sample_rate=24_000)
    app = build_app(engine)
    config = uvicorn.Config(app, host=host, port=port, log_level="warning")
    server = uvicorn.Server(config)
    task = asyncio.create_task(server.serve())
    try:
        # Poll for readiness: 50 * 0.05 s ~= 2.5 s. Stop early if serve()
        # already returned/raised -- waiting longer cannot help then.
        for _ in range(50):
            if server.started or task.done():
                break
            await asyncio.sleep(0.05)
        if not server.started:
            pytest.fail(f"Uvicorn test server did not start on {host}:{port}")
        yield f"ws://{host}:{port}/ws", f"http://{host}:{port}/health"
    finally:
        # Signal uvicorn's serve loop to exit, then wait for it to finish so
        # the port is released before the next test's fixture binds it.
        server.should_exit = True
        await task
@pytest.mark.asyncio
async def test_health(dummy_server):
    """GET on the health URL returns 200 and a JSON body whose ``status`` is ``"ok"``."""
    _, health_url = dummy_server
    async with AsyncClient() as http:
        resp = await http.get(health_url)
        assert resp.status_code == 200
        payload = resp.json()
        assert payload["status"] == "ok"
@pytest.mark.asyncio
async def test_websocket_streaming_and_stop(dummy_server):
    """Stream text over the WebSocket, flush it, then interrupt.

    Sequence exercised by this test:
      * ``init``                  -> expects status ``session_ready``
      * ``text`` chunks + ``flush`` -> expects ``segment_started``, at least
        one ``audio`` message, and ``segment_finished`` (within 5 s overall)
      * ``stop``                  -> expects status ``stopped`` within the
        next 20 messages
    """
    ws_url, _ = dummy_server
    loop = asyncio.get_running_loop()
    received_audio = 0
    # (type, event) for every message seen while waiting for segment_finished;
    # audio messages are recorded as ("audio", None).
    received_events = []

    async with websockets.connect(ws_url) as ws:
        await ws.send(json.dumps({
            "type": "init",
            "session_id": "test",
            "seq": 1,
        }))
        msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5))
        assert msg["type"] == "status"
        assert msg["event"] == "session_ready"

        # Non-ASCII (Cyrillic) text split into pieces, like LLM tokens.
        for i, chunk in enumerate(["Привет, ", "как ", "дела?"]):
            await ws.send(json.dumps({"type": "text", "payload": chunk, "seq": 2 + i}))
            # Give the server a moment to process each chunk sequentially,
            # mirroring how a streaming LLM would emit tokens.
            await asyncio.sleep(0.05)
        await ws.send(json.dumps({"type": "flush", "seq": 5}))

        # Collect messages until segment_finished. The whole phase shares one
        # 5 s budget: each recv only gets the time remaining before the
        # deadline, so a slow trickle of messages cannot extend the wait.
        deadline = loop.time() + 5
        while True:
            remaining = max(deadline - loop.time(), 0)
            try:
                raw = await asyncio.wait_for(ws.recv(), timeout=remaining)
            except asyncio.TimeoutError:
                pytest.fail("Did not receive segment_finished status within 5 s")
            msg = json.loads(raw)
            received_events.append((msg["type"], msg.get("event")))
            if msg["type"] == "audio":
                received_audio += 1
            if msg["type"] == "status" and msg.get("event") == "segment_finished":
                break

        # Interrupt
        await ws.send(json.dumps({"type": "stop", "reason": "interrupt", "seq": 6}))
        # Drain any audio/status messages already in flight before the stop
        # is processed; allow at most 20 of them.
        for _ in range(20):
            msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5))
            if msg["type"] == "status" and msg.get("event") == "stopped":
                break
        else:
            pytest.fail("Did not receive stopped status")

    assert received_audio >= 1
    assert ("status", "segment_started") in received_events
    assert ("status", "segment_finished") in received_events