# voice/src/voice_tts/api/server.py
"""WebSocket server and session lifecycle."""

import asyncio
from contextlib import asynccontextmanager
import json
from pathlib import Path
import threading

from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from loguru import logger

from voice_tts.api.protocol import (
    AudioMessage,
    ClientMessageUnion,
    ConfigMessage,
    EmotionMessage,
    ErrorMessage,
    FlushMessage,
    InitMessage,
    ServerMessageUnion,
    StatusMessage,
    StopMessage,
    TextMessage,
)
from voice_tts.audio.formats import float_to_pcm16, pcm16_to_base64
from voice_tts.config import settings
from voice_tts.session.state import SessionState, VoiceProfile
from voice_tts.tts.engine import DummyTTSEngine, TTSEngine
from voice_tts.tts.f5_backend import F5TTSEngine
from voice_tts.tts.segmenter import Segmenter


# Registry of supported TTS backends, keyed by the value expected in
# ``settings.tts_backend``; each value is the engine class to instantiate.
_BACKEND_MAP: dict[str, type[TTSEngine]] = dict(
    dummy=DummyTTSEngine,
    f5_tts=F5TTSEngine,
)


class SessionManager:
    """Manages the lifecycle of a single WebSocket TTS session.

    One instance handles one ``/ws`` connection. It does four things:

    - parses client JSON messages;
    - feeds incoming text through the segmenter;
    - synthesizes each segment on a worker thread;
    - streams the resulting PCM16 audio back through a background sender task.
    """

    # Process-wide lock around engine inference. ``build_app`` hands the same
    # engine instance to every session, so the per-session ``_synth_lock``
    # alone cannot keep two connections from running inference concurrently.
    _engine_lock = threading.Lock()

    def __init__(self, websocket: WebSocket, engine: TTSEngine):
        self.ws = websocket
        self.engine = engine
        self.segmenter = Segmenter(
            min_length=settings.min_segment_length,
            max_length=settings.max_segment_length,
        )
        self.state = SessionState(session_id="")
        self._running = True
        # Long-lived background workers (the audio sender).
        self._tasks: list[asyncio.Task] = []
        # In-flight synthesis tasks. Holding strong references prevents them
        # from being garbage-collected mid-flight and lets ``run`` cancel them.
        self._synth_tasks: set[asyncio.Task] = set()
        self._send_lock = asyncio.Lock()
        # Lets only one segment of this session synthesize at a time.
        self._synth_lock = asyncio.Lock()
        # Bumped on every stop. Segments scheduled under an older generation
        # are discarded even if new text has since reset the stop flag.
        self._generation = 0

    async def run(self) -> None:
        """Accept the connection and process client messages until it closes.

        On exit (disconnect or error) the session is stopped. The audio sender
        and any in-flight synthesis tasks are cancelled and awaited.
        """
        await self.ws.accept()
        logger.info("WebSocket connection accepted")

        # Start audio sender worker
        self._tasks.append(asyncio.create_task(self._audio_sender()))

        try:
            while self._running:
                raw = await self.ws.receive_text()
                try:
                    data = json.loads(raw)
                    msg = self._parse_message(data)
                except Exception as exc:
                    await self._send(ErrorMessage(message=f"Invalid message: {exc}"))
                    continue

                await self._handle_message(msg)
        except WebSocketDisconnect:
            logger.info("Client disconnected")
        finally:
            self._stop_all()
            pending = [*self._tasks, *self._synth_tasks]
            for task in pending:
                task.cancel()
            # return_exceptions=True: wait for every task even if one of them
            # ended with an error other than CancelledError.
            await asyncio.gather(*pending, return_exceptions=True)

    def _parse_message(self, data: dict) -> ClientMessageUnion:
        """Build the typed client message selected by ``data["type"]``.

        Raises:
            ValueError: if ``data`` is not a JSON object or its type is unknown.
            Errors raised by the message model constructors on invalid fields
            propagate unchanged.
        """
        if not isinstance(data, dict):
            raise ValueError("message must be a JSON object")
        msg_type = data.get("type")
        if msg_type == "init":
            return InitMessage(**data)
        if msg_type == "text":
            return TextMessage(**data)
        if msg_type == "flush":
            return FlushMessage(**data)
        if msg_type == "stop":
            return StopMessage(**data)
        if msg_type == "emotion":
            return EmotionMessage(**data)
        if msg_type == "config":
            return ConfigMessage(**data)
        raise ValueError(f"Unknown message type: {msg_type}")

    async def _handle_message(self, msg: ClientMessageUnion) -> None:
        """Dispatch a parsed message to its handler."""
        if isinstance(msg, InitMessage):
            await self._handle_init(msg)
        elif isinstance(msg, TextMessage):
            await self._handle_text(msg)
        elif isinstance(msg, FlushMessage):
            await self._handle_flush(msg)
        elif isinstance(msg, StopMessage):
            await self._handle_stop(msg)
        elif isinstance(msg, EmotionMessage):
            await self._handle_emotion(msg)
        elif isinstance(msg, ConfigMessage):
            await self._handle_config(msg)

    async def _handle_init(self, msg: InitMessage) -> None:
        """(Re)initialize session settings and voice references, then ack."""
        self.state.session_id = msg.session_id or "default"
        self.state.language = msg.language
        self.state.speed = msg.speed
        self.state.emotion = msg.emotion
        self.state.clear_buffer()
        self.state.reset_stop()

        voice = VoiceProfile()
        if msg.voice_ref:
            voice.default_ref = Path(msg.voice_ref)
        if msg.voice_refs:
            voice.emotion_refs = {
                emotion: Path(path) for emotion, path in msg.voice_refs.items()
            }
        self.state.voice = voice

        await self._send(StatusMessage(event="session_ready", seq=msg.seq))
        logger.info(
            "Session initialized: id={} language={} speed={} emotion={}",
            self.state.session_id,
            self.state.language,
            self.state.speed,
            self.state.emotion,
        )

    async def _handle_text(self, msg: TextMessage) -> None:
        """Append text to the buffer and schedule synthesis of ready segments."""
        if self.state.is_stopped():
            # After a stop, new text implicitly resets the stop flag
            self.state.reset_stop()

        self.state.text_buffer += msg.payload
        if msg.emotion:
            self.state.emotion = msg.emotion

        from voice_tts.tts.utils import preprocess_text_for_tts

        self.state.text_buffer = preprocess_text_for_tts(self.state.text_buffer)
        remaining, segments = self.segmenter.feed(self.state.text_buffer)
        self.state.text_buffer = remaining
        self._schedule_segments(segments, msg.seq)

        # If this text chunk itself ends a sentence and the buffer is too short,
        # we might still be holding it. Force a flush so short complete sentences
        # (e.g. "дела?") are not stuck waiting for more input.
        if _looks_like_complete_phrase(msg.payload) and self.state.text_buffer.strip():
            segments = self.segmenter.flush(self.state.text_buffer)
            self.state.text_buffer = ""
            self._schedule_segments(segments, msg.seq)

    async def _handle_flush(self, msg: FlushMessage) -> None:
        """Force whatever is buffered into segments and schedule them."""
        segments = self.segmenter.flush(self.state.text_buffer)
        self.state.text_buffer = ""
        self._schedule_segments(segments, msg.seq)

    def _schedule_segments(self, segments, seq: int | None) -> None:
        """Start one tracked synthesis task per segment.

        Each task is tagged with the current stop generation, so a later stop
        invalidates it. ``segments`` are the segmenter's results; only their
        ``.text`` attribute is used here.
        """
        generation = self._generation
        for segment in segments:
            task = asyncio.create_task(
                self._synthesize_segment(segment.text, seq, generation)
            )
            self._synth_tasks.add(task)
            task.add_done_callback(self._synth_tasks.discard)

    async def _handle_stop(self, msg: StopMessage) -> None:
        """Stop playback: invalidate pending segments and drop queued audio."""
        logger.info("Stop requested: {}", msg.reason)
        # Invalidate every segment scheduled so far, including ones that are
        # still waiting for the synth lock after new text resets the flag.
        self._generation += 1
        self.state.stop()
        self.state.clear_buffer()

        # Drain pending audio queue
        while not self.state.audio_queue.empty():
            try:
                self.state.audio_queue.get_nowait()
                self.state.audio_queue.task_done()
            except asyncio.QueueEmpty:
                break

        await self._send(StatusMessage(event="stopped", reason=msg.reason, seq=msg.seq))
        # Don't kill the connection: agent may continue the session after stop.

    async def _handle_emotion(self, msg: EmotionMessage) -> None:
        """Update the session emotion and acknowledge it."""
        self.state.emotion = msg.emotion
        await self._send(
            StatusMessage(
                event="config_updated",
                extra={"emotion": msg.emotion},
                seq=msg.seq,
            )
        )

    async def _handle_config(self, msg: ConfigMessage) -> None:
        """Update speed and/or language and echo the effective values."""
        if msg.speed is not None:
            self.state.speed = msg.speed
        if msg.language is not None:
            self.state.language = msg.language
        await self._send(
            StatusMessage(
                event="config_updated",
                extra={"speed": self.state.speed, "language": self.state.language},
                seq=msg.seq,
            )
        )

    def _is_stale(self, generation: int) -> bool:
        """Return True if the session is stopped or stopped after ``generation``."""
        return self.state.is_stopped() or generation != self._generation

    async def _synthesize_segment(
        self, text: str, seq: int | None, generation: int | None = None
    ) -> None:
        """Synthesize one text segment and enqueue its PCM16 audio for sending.

        Args:
            text: segment text to speak.
            seq: client sequence number of the originating message. It is not
                used in the body; outgoing messages carry the server-side
                segment number instead.
            generation: stop generation captured when the segment was
                scheduled. Defaults to the current generation.
        """
        if generation is None:
            generation = self._generation
        if self._is_stale(generation):
            return

        segment_seq = self.state.next_segment_seq()
        await self._send(StatusMessage(event="segment_started", seq=segment_seq))

        # Snapshot voice parameters together so the reference clip always
        # corresponds to the emotion passed to the engine.
        emotion = self.state.emotion
        ref_path = self.state.voice.ref_for(emotion)
        language = self.state.language
        speed = self.state.speed

        try:
            # Serialize GPU inference and run it off the main event loop.
            async with self._synth_lock:
                # Skip inference for segments invalidated while waiting.
                if self._is_stale(generation):
                    return
                audio = await asyncio.to_thread(
                    SessionManager._sync_synthesize,
                    self.engine,
                    text,
                    ref_path,
                    language,
                    speed,
                    emotion,
                )
        except Exception as exc:
            logger.exception("TTS synthesis failed")
            await self._send(ErrorMessage(message=f"TTS failed: {exc}", seq=segment_seq))
            return

        if self._is_stale(generation):
            return

        pcm = float_to_pcm16(audio)
        await self.state.audio_queue.put((segment_seq, pcm))

    async def _audio_sender(self) -> None:
        """Forward queued audio to the client as AudioMessage + segment_finished.

        Polls the queue with a 0.5 s timeout so the loop notices when
        ``_running`` becomes False. Items dequeued while stopped are discarded.
        """
        while self._running:
            try:
                segment_seq, pcm = await asyncio.wait_for(self.state.audio_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                continue

            if self.state.is_stopped():
                self.state.audio_queue.task_done()
                continue

            await self._send(
                AudioMessage(
                    sample_rate=self.engine.sample_rate,
                    data=pcm16_to_base64(pcm),
                    seq=segment_seq,
                )
            )
            await self._send(StatusMessage(event="segment_finished", seq=segment_seq))
            self.state.audio_queue.task_done()

    async def _send(self, msg: ServerMessageUnion) -> None:
        """Serialize and send a message; on failure, log it and mark the session not running."""
        async with self._send_lock:
            try:
                await self.ws.send_text(msg.model_dump_json(exclude_none=True))
            except Exception as exc:
                logger.warning("Failed to send WebSocket message: {}", exc)
                self._running = False

    @staticmethod
    def _sync_synthesize(
        engine: "TTSEngine",
        text: str,
        ref_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
    ) -> "np.ndarray":
        """Drive ``engine.synthesize`` to completion in the calling thread.

        Intended to be called through ``asyncio.to_thread``. A worker thread has
        no running event loop, so the coroutine runs on a fresh loop created by
        ``asyncio.run``. Calls are serialized process-wide by ``_engine_lock``.

        Raises:
            RuntimeError: if called from a thread that is already running an
                event loop (``asyncio.run`` cannot nest).
        """
        with SessionManager._engine_lock:
            return asyncio.run(
                engine.synthesize(
                    text=text,
                    ref_audio_path=ref_path,
                    language=language,
                    speed=speed,
                    emotion=emotion,
                )
            )

    def _stop_all(self) -> None:
        """Mark the session as finished and stopped."""
        self._running = False
        self.state.stop()


def _looks_like_complete_phrase(text: str) -> bool:
    """Heuristic: chunk ends with strong punctuation -> sentence is complete."""
    text = text.rstrip()
    return bool(text) and text[-1] in ".。!??!"


def _create_engine() -> TTSEngine:
    """Instantiate the backend named by ``settings.tts_backend``.

    The engine is built with ``settings.tts_sample_rate``. Its ``load()``
    method is called only if the engine defines one.

    Raises:
        RuntimeError: if the configured backend is not registered in ``_BACKEND_MAP``.
    """
    backend = settings.tts_backend
    engine_cls = _BACKEND_MAP.get(backend)
    if engine_cls is not None:
        instance = engine_cls(sample_rate=settings.tts_sample_rate)
        if hasattr(instance, "load"):
            instance.load()
        return instance

    known_backends = list(_BACKEND_MAP.keys())
    raise RuntimeError(
        f"Unknown TTS backend: {backend}. "
        f"Available backends: {known_backends}"
    )


def build_app(engine: TTSEngine | None = None) -> FastAPI:
    """Create the FastAPI application serving the TTS WebSocket API.

    Args:
        engine: Pre-built TTS engine shared by all sessions. When ``None``,
            one is created from ``settings`` via ``_create_engine()``.

    Returns:
        The configured app, exposing ``/ws`` (WebSocket) and ``/health``.
    """
    from collections.abc import AsyncIterator

    # Explicit None check: an injected engine object must never be replaced
    # just because it happens to evaluate as falsy.
    tts_engine = engine if engine is not None else _create_engine()

    @asynccontextmanager
    async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
        """Optionally run one warm-up synthesis at startup, then serve."""
        if settings.warmup and settings.default_voice_ref is not None:
            try:
                logger.info(
                    "Warming up {} with reference {} ...",
                    settings.tts_backend,
                    settings.default_voice_ref,
                )
                # Same thread-offloaded path the sessions use, so the event
                # loop is not blocked while inference runs.
                await asyncio.to_thread(
                    SessionManager._sync_synthesize,
                    tts_engine,
                    settings.warmup_text,
                    settings.default_voice_ref,
                    "en",  # NOTE(review): warm-up language is hard-coded
                    1.0,
                    "neutral",
                )
                logger.info("Warm-up complete.")
            except Exception as exc:
                # Best effort: a failed warm-up must not prevent startup.
                logger.warning("Warm-up failed (continuing anyway): {}", exc)
        else:
            logger.info("Warm-up skipped.")
        yield

    app = FastAPI(title="Voice TTS", version="0.1.0", lifespan=_lifespan)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        # Each connection gets its own SessionManager but shares tts_engine.
        session = SessionManager(websocket, tts_engine)
        await session.run()

    @app.get("/health")
    async def health() -> dict:
        # NOTE(review): reports the configured backend name, which may not
        # describe an engine injected via the ``engine`` argument.
        return {"status": "ok", "backend": settings.tts_backend}

    return app


# Module-level ASGI application object. Because no engine is passed, building
# it here constructs (and loads, if supported) the configured TTS engine at
# import time.
app = build_app(engine=None)