Newer
Older
voice / src / voice_tts / tts / xtts_backend.py
"""XTTS-v2 backend for local GPU inference with zero-shot voice cloning."""

from pathlib import Path

import numpy as np
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine

# Optional heavy dependencies. The module must stay importable without them so
# that the backend can still be registered; ``XTTS_AVAILABLE`` records whether
# they were found and is checked before any of these names is used.
try:
    import torch
    import torchaudio
    from TTS.api import TTS
    import TTS.utils.io as tts_io

    XTTS_AVAILABLE = True
except ImportError as exc:
    logger.warning("TTS/torch dependencies not available: {}", exc)
    XTTS_AVAILABLE = False
    # Bind every optional name to ``None`` so that an unguarded reference
    # fails with a clear attribute error instead of a ``NameError``.
    torch = None
    torchaudio = None
    TTS = None
    tts_io = None


# Default output sample rate in Hz: used as the class-level default of
# ``XTTSv2Engine.sample_rate`` and as the fallback when the loaded model does
# not report ``synthesizer.output_sample_rate``.
_XTTS_SAMPLE_RATE = 24_000


def _patch_weights_only() -> None:
    """Allow XTTS-v2 checkpoints to be loaded under PyTorch 2.6+.

    PyTorch 2.6+ defaults ``torch.load(..., weights_only=True)``. XTTS-v2
    checkpoints contain pickled Coqui config classes, so this function

    1. registers those classes in PyTorch's safe-globals allow-list (when the
       running PyTorch exposes ``torch.serialization.add_safe_globals``), and
    2. replaces ``TTS.utils.io.load_fsspec`` with a wrapper that defaults
       ``weights_only`` to ``False`` unless the caller passes it explicitly.

    The function is idempotent: once the wrapper is installed, later calls
    return immediately instead of stacking another wrapper on top. It is a
    no-op when the TTS/torch dependencies are missing or when
    ``TTS.utils.io`` has no ``load_fsspec`` attribute.

    NOTE(review): only the ``load_fsspec`` attribute of ``TTS.utils.io`` is
    replaced; modules that imported the function by name before this runs
    would keep the original -- confirm against the installed TTS version.

    SECURITY: ``weights_only=False`` permits arbitrary pickle execution while
    loading; only load checkpoints from a trusted source.
    """
    if not XTTS_AVAILABLE:
        return

    if not hasattr(tts_io, "load_fsspec"):
        return

    patched_marker = "_xtts_weights_only_patched"
    _orig = tts_io.load_fsspec
    if getattr(_orig, patched_marker, False):
        # Already patched by an earlier call; wrapping again would only add
        # another layer of indirection on every checkpoint load.
        return

    # PyTorch 2.6+ safe-global allow-list for XTTS-v2 checkpoint pickles.
    if hasattr(torch, "serialization") and hasattr(
        torch.serialization, "add_safe_globals"
    ):
        from TTS.config import shared_configs as _shared_configs
        from TTS.tts.configs.xtts_config import XttsConfig as _XttsConfig
        from TTS.tts.models.xtts import XttsArgs as _XttsArgs
        from TTS.tts.models.xtts import XttsAudioConfig as _XttsAudioConfig

        for _cls in (
            _shared_configs.BaseDatasetConfig,
            _XttsConfig,
            _XttsArgs,
            _XttsAudioConfig,
        ):
            try:
                torch.serialization.add_safe_globals([_cls])
            except Exception as exc:
                # Best effort: the patched loader below forces
                # ``weights_only=False`` anyway, so a failed registration is
                # not fatal -- but leave a trace for debugging.
                logger.debug(
                    "Could not register {} as a torch safe global: {}",
                    getattr(_cls, "__name__", _cls),
                    exc,
                )

    def _patched(model_path: str, map_location: str = "cpu", **kwargs):
        # Respect an explicit ``weights_only`` from the caller; only the
        # default is changed.
        kwargs.setdefault("weights_only", False)
        return _orig(model_path, map_location=map_location, **kwargs)

    setattr(_patched, patched_marker, True)
    tts_io.load_fsspec = _patched  # type: ignore[assignment]


@_register_backend("xtts_v2")
class XTTSv2Engine(TTSEngine):
    """XTTS-v2 backend supporting English and Russian zero-shot voice cloning.

    The Coqui model is loaded lazily: either explicitly through ``load()`` /
    ``warm_up()`` or on the first ``synthesize()`` call. After loading,
    ``sample_rate`` reflects the model's reported output rate.
    """

    # Class-level default; overwritten per instance in ``__init__`` and again
    # in ``load()`` once the model reports its real output rate.
    sample_rate: int = _XTTS_SAMPLE_RATE

    def __init__(
        self,
        model_name: str | None = None,
        sample_rate: int | None = None,
        device: str | None = None,
        gpu: bool | None = None,
    ):
        """Store the configuration; the model itself is not loaded here.

        Args:
            model_name: Coqui model identifier; falls back to
                ``settings.tts_model_name``.
            sample_rate: Desired sample rate in Hz; falls back to
                ``settings.tts_sample_rate``.
            device: Target device string (e.g. ``"cpu"``, ``"cuda"``); falls
                back to ``settings.device``.
            gpu: Whether to ask Coqui for GPU inference. Defaults to ``True``
                and is always forced to ``False`` when ``device`` is
                ``"cpu"``.
        """
        super().__init__()
        self.model_name = model_name or settings.tts_model_name
        self.sample_rate = sample_rate or settings.tts_sample_rate
        self.device = device or settings.device
        self.gpu = (gpu if gpu is not None else True) and self.device != "cpu"

        self._model: "TTS | None" = None

    def _is_supported_language(self, language: str) -> bool:
        """Return whether ``language`` (case-insensitive) is enabled here."""
        # XTTS-v2 supported languages: en, es, fr, de, it, pt, pl, tr, ru, nl, cs,
        # ar, zh-cn, hu, ko, ja, hi. We currently expose en and ru.
        return language.lower() in {"en", "ru"}

    def load(self) -> None:
        """Load the XTTS-v2 model and move it to the configured device.

        Raises:
            RuntimeError: If the Coqui TTS / torch dependencies are missing.
        """
        if not XTTS_AVAILABLE:
            raise RuntimeError(
                "coqui TTS package is not installed. Install it: pip install TTS"
            )

        logger.info("Loading XTTS-v2 model {} ...", self.model_name)
        _patch_weights_only()

        # Environment flag required by Coqui to download XTTS.
        import os

        os.environ.setdefault("COQUI_TOS_AGREED", "1")

        self._model = TTS(self.model_name, gpu=self.gpu)
        # Force the requested device if not already there. Pass the full
        # device string so an explicit index such as "cuda:1" is honoured
        # rather than collapsed to the default CUDA device.
        if self.device.startswith("cuda") or self.device == "cpu":
            self._model = self._model.to(self.device)

        self.sample_rate = self._model.synthesizer.output_sample_rate or _XTTS_SAMPLE_RATE
        logger.info("XTTS-v2 loaded. Output sample rate: {}", self.sample_rate)

    async def warm_up(self) -> None:
        """Ensure the model is loaded; no synthesis is performed.

        A real warm-up inference would need a reference audio file, which is
        not available at this point.
        """
        if self._model is None:
            self.load()
        logger.info("Warm-up skipped: provide a reference audio before warm-up.")

    def _normalize_language(self, language: str) -> str:
        """Lower-case ``language`` and collapse ``ru*`` / ``en*`` variants.

        Any tag starting with ``ru`` or ``en`` (e.g. ``"en-US"``) maps to the
        bare two-letter code; everything else is returned lower-cased.
        """
        language = language.lower()
        if language.startswith("ru"):
            return "ru"
        if language.startswith("en"):
            return "en"
        return language

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize ``text`` in the voice of the reference recording.

        Args:
            text: Text to speak.
            ref_audio_path: Reference audio used for voice cloning. A plain
                ``str`` is accepted and converted to ``Path``.
            language: Language tag; normalized via ``_normalize_language``.
            speed: Playback-speed factor. ``None`` falls back to
                ``settings.tts_speed``; ``1.0`` and non-positive values leave
                the audio untouched.
            emotion: Unused by this backend.
            ref_text: Unused by this backend.

        Returns:
            Mono ``float32`` waveform, peak-limited to ``[-1, 1]``, intended
            for playback at ``self.sample_rate``.

        Raises:
            ValueError: If no reference audio is given or the language is not
                enabled.
            FileNotFoundError: If the reference audio file does not exist.
            RuntimeError: Propagated from ``load()`` when dependencies are
                missing.
        """
        if self._model is None:
            self.load()

        # Type-narrowing only: ``load()`` either sets ``_model`` or raises.
        assert self._model is not None

        if isinstance(ref_audio_path, str):
            ref_audio_path = Path(ref_audio_path)

        if ref_audio_path is None:
            raise ValueError("XTTS-v2 requires a reference audio file (voice_ref).")
        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        lang = self._normalize_language(language)
        if not self._is_supported_language(lang):
            raise ValueError(
                f"Language '{language}' is not supported by XTTS-v2. "
                "Currently enabled languages: en, ru."
            )

        # ``tts_to_file`` is not given a speed argument here; speed is applied
        # afterwards by resampling. The model writes to a private, per-call
        # temporary directory so that concurrent requests cannot overwrite
        # each other's output and nothing is left behind on disk.
        import tempfile

        with tempfile.TemporaryDirectory(prefix="xtts_synth_") as tmp_dir:
            out_path = str(Path(tmp_dir) / "synth.wav")
            self._model.tts_to_file(
                text=text,
                speaker_wav=str(ref_audio_path),
                language=lang,
                file_path=out_path,
            )
            wav, sr = torchaudio.load(out_path)

        # Down-mix by averaging over the first axis of the loaded tensor.
        wav = wav.mean(dim=0).numpy().astype(np.float32)

        # Resample if the model returned a different rate.
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        # Scale down only when the signal exceeds [-1, 1]; quieter audio is
        # left as is. ``np.max`` raises on an empty array, hence the guard.
        peak = float(np.max(np.abs(wav))) if wav.size else 0.0
        if peak > 1.0:
            wav = wav / peak

        # Approximate a speed change by resampling to ``sample_rate / speed``
        # while the caller keeps playing at ``sample_rate``. This changes the
        # duration and also shifts the pitch; that is acceptable for small
        # adjustments around 1.0, and a pitch-preserving phase vocoder was
        # considered too expensive.
        # The default speed comes from settings.tts_speed and can be overridden
        # per-request via the WebSocket config/speak messages.
        effective_speed = speed if speed is not None else settings.tts_speed
        if effective_speed != 1.0 and effective_speed > 0 and wav.size:
            # Round instead of truncating so the target rate is not off by
            # one, and never let it collapse to 0 for very large speeds.
            new_rate = max(1, round(self.sample_rate / effective_speed))
            if new_rate != self.sample_rate:
                resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
                wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        # Always hand back a 1-D array.
        if wav.ndim == 0:
            wav = np.zeros(1, dtype=np.float32)
        elif wav.ndim > 1:
            wav = wav.squeeze()

        return wav