"""XTTS-v2 backend for local GPU inference with zero-shot voice cloning."""
from pathlib import Path
import numpy as np
from loguru import logger
from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine
# Optional heavy dependencies. The module must stay importable without them so
# the rest of the package keeps working; availability is recorded in
# XTTS_AVAILABLE and checked before any of these names is used.
try:
    import torch
    import torchaudio
    from TTS.api import TTS
    import TTS.utils.io as tts_io

    XTTS_AVAILABLE = True
except ImportError as exc:
    logger.warning("TTS/torch dependencies not available: {}", exc)
    XTTS_AVAILABLE = False
    # Bind every optional name so a later reference yields None instead of a
    # NameError when the dependencies are missing.
    torch = None
    torchaudio = None
    TTS = None
    tts_io = None

# Fallback output sample rate (Hz) used when the loaded model does not report one.
_XTTS_SAMPLE_RATE = 24_000
def _patch_weights_only() -> None:
    """Allow XTTS-v2 checkpoints (legacy pickle classes) to be fully loaded.

    PyTorch 2.6+ defaults ``torch.load(..., weights_only=True)``. XTTS-v2
    checkpoints require several Coqui classes to be in the safe-globals list
    in addition to forcing ``weights_only=False``. This function registers
    those classes globally and replaces ``TTS.utils.io.load_fsspec`` with a
    wrapper that defaults ``weights_only`` to ``False``.

    The function is a no-op when the TTS/torch dependencies are missing, when
    ``TTS.utils.io`` has no ``load_fsspec`` attribute, or when the wrapper has
    already been installed (so repeated calls do not stack wrappers).
    """
    if not XTTS_AVAILABLE:
        return
    if not hasattr(tts_io, "load_fsspec"):
        return
    # Idempotency guard: the wrapper installed below carries this marker.
    if getattr(tts_io.load_fsspec, "_weights_only_patched", False):
        return

    # PyTorch 2.6+ safe-global allow-list for XTTS-v2 checkpoint pickles.
    serialization = getattr(torch, "serialization", None)
    if serialization is not None and hasattr(serialization, "add_safe_globals"):
        from TTS.config import shared_configs as _shared_configs
        from TTS.tts.configs.xtts_config import XttsConfig as _XttsConfig
        from TTS.tts.models.xtts import XttsArgs as _XttsArgs
        from TTS.tts.models.xtts import XttsAudioConfig as _XttsAudioConfig

        for _cls in (
            _shared_configs.BaseDatasetConfig,
            _XttsConfig,
            _XttsArgs,
            _XttsAudioConfig,
        ):
            try:
                serialization.add_safe_globals([_cls])
            except Exception as exc:
                # Best effort: a failed registration must not block loading,
                # but leave a trace instead of swallowing it silently.
                logger.debug(
                    "Could not register {} as a torch safe global: {}", _cls, exc
                )

    _orig = tts_io.load_fsspec

    def _patched(model_path: str, map_location: str = "cpu", **kwargs):
        # SECURITY: weights_only=False performs full pickle deserialization,
        # which can execute arbitrary code. Only load trusted checkpoints.
        kwargs.setdefault("weights_only", False)
        return _orig(model_path, map_location=map_location, **kwargs)

    _patched._weights_only_patched = True  # type: ignore[attr-defined]
    tts_io.load_fsspec = _patched  # type: ignore[assignment]
@_register_backend("xtts_v2")
class XTTSv2Engine(TTSEngine):
    """XTTS-v2 backend supporting English and Russian zero-shot voice cloning.

    The Coqui model is loaded lazily: ``load()`` runs on the first call to
    ``warm_up()`` or ``synthesize()`` if it has not been called explicitly.

    NOTE(review): ``warm_up`` and ``synthesize`` are coroutines but run model
    loading and inference synchronously, so they block the event loop while
    they execute.
    """

    sample_rate: int = _XTTS_SAMPLE_RATE

    def __init__(
        self,
        model_name: str | None = None,
        sample_rate: int | None = None,
        device: str | None = None,
        gpu: bool | None = None,
    ):
        """Store the configuration; the model itself is not loaded here.

        Args:
            model_name: Coqui model identifier; falls back to
                ``settings.tts_model_name``.
            sample_rate: Output sample rate; falls back to
                ``settings.tts_sample_rate``. Overwritten by ``load()`` with
                the rate reported by the model.
            device: Target device string; falls back to ``settings.device``.
            gpu: Whether to request GPU loading. Defaults to ``True`` and is
                always forced to ``False`` when the device is ``"cpu"``.
        """
        super().__init__()
        self.model_name = model_name or settings.tts_model_name
        self.sample_rate = sample_rate or settings.tts_sample_rate
        self.device = device or settings.device
        self.gpu = (gpu if gpu is not None else True) and self.device != "cpu"
        self._model: "TTS | None" = None

    def _is_supported_language(self, language: str) -> bool:
        """Return True if *language* (case-insensitive) is an enabled code."""
        # XTTS-v2 supported languages: en, es, fr, de, it, pt, pl, tr, ru, nl, cs,
        # ar, zh-cn, hu, ko, ja, hi. We currently expose en and ru.
        return language.lower() in {"en", "ru"}

    def load(self) -> None:
        """Load the XTTS-v2 model and move it to the configured device.

        Raises:
            RuntimeError: If the TTS/torch dependencies could not be imported.
        """
        if not XTTS_AVAILABLE:
            raise RuntimeError(
                "coqui TTS package is not installed. Install it: pip install TTS"
            )
        logger.info("Loading XTTS-v2 model {} ...", self.model_name)
        _patch_weights_only()

        # Environment flag required by Coqui to download XTTS.
        import os

        os.environ.setdefault("COQUI_TOS_AGREED", "1")

        # Build into a local first so a failure below does not leave a
        # half-initialised model behind that later calls would treat as loaded.
        model = TTS(self.model_name, gpu=self.gpu)

        # Force the requested device if not already there. The full device
        # string is passed through so an explicit index ("cuda:1") is honoured.
        if self.device.startswith("cuda"):
            model = model.to(self.device)
        elif self.device == "cpu":
            model = model.to("cpu")

        self.sample_rate = model.synthesizer.output_sample_rate or _XTTS_SAMPLE_RATE
        self._model = model
        logger.info("XTTS-v2 loaded. Output sample rate: {}", self.sample_rate)

    async def warm_up(self) -> None:
        """Ensure the model is loaded; no synthesis is performed."""
        if self._model is None:
            self.load()
        logger.info("Warm-up skipped: provide a reference audio before warm-up.")

    def _normalize_language(self, language: str) -> str:
        """Lower-case *language* and collapse ``ru*`` / ``en*`` to ``ru`` / ``en``."""
        language = language.lower()
        if language.startswith("ru"):
            return "ru"
        if language.startswith("en"):
            return "en"
        return language

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize *text* in the voice of the reference audio.

        Args:
            text: Text to speak.
            ref_audio_path: Reference recording used for voice cloning
                (a ``str`` is accepted and converted to ``Path``).
            language: Language code; normalized with ``_normalize_language``.
            speed: Playback speed factor. ``None`` falls back to
                ``settings.tts_speed``; non-positive values are ignored.
            emotion: Not used by this backend.
            ref_text: Not used by this backend.

        Returns:
            Mono ``float32`` samples as a 1-D array.

        Raises:
            RuntimeError: If the model cannot be loaded.
            ValueError: If no reference audio is given or the language is not
                enabled.
            FileNotFoundError: If the reference audio file does not exist.
        """
        if self._model is None:
            self.load()
        model = self._model
        if model is None:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            raise RuntimeError("XTTS-v2 model is not loaded.")

        if isinstance(ref_audio_path, str):
            ref_audio_path = Path(ref_audio_path)
        if ref_audio_path is None:
            raise ValueError("XTTS-v2 requires a reference audio file (voice_ref).")
        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        lang = self._normalize_language(language)
        if not self._is_supported_language(lang):
            raise ValueError(
                f"Language '{language}' is not supported by XTTS-v2. "
                "Currently enabled languages: en, ru."
            )

        # Render into a private temporary directory: a fixed shared path would
        # let concurrent requests overwrite each other's output, would fail if
        # the directory is missing, and would never be cleaned up.
        import tempfile

        with tempfile.TemporaryDirectory(prefix="xtts_synth_") as tmp_dir:
            out_path = str(Path(tmp_dir) / "synth.wav")
            model.tts_to_file(
                text=text,
                speaker_wav=str(ref_audio_path),
                language=lang,
                file_path=out_path,
            )
            wav, sr = torchaudio.load(out_path)

        # Down-mix by averaging over the first axis of the loaded tensor.
        wav = wav.mean(dim=0).numpy().astype(np.float32)

        # Nothing to post-process; np.max would raise on an empty array.
        # Return the same one-sample silent buffer used for the degenerate
        # 0-d case at the end of this method.
        if wav.size == 0:
            return np.zeros(1, dtype=np.float32)

        # Resample if the model returned a different rate.
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        # Scale down only when the signal exceeds [-1, 1].
        peak = np.max(np.abs(wav))
        if peak > 1.0:
            wav = wav / peak

        # XTTS.tts_to_file takes no speed argument here, so speed is
        # approximated by resampling: the samples are converted to
        # sample_rate / speed and still played back at sample_rate, which
        # changes the duration. This also shifts pitch; that is acceptable for
        # factors close to 1.0, and a pitch-preserving phase vocoder was
        # considered too expensive.
        # The default speed comes from settings.tts_speed and can be overridden
        # per-request via the WebSocket config/speak messages.
        effective_speed = speed if speed is not None else settings.tts_speed
        if effective_speed != 1.0 and effective_speed > 0:
            # Clamp to 1 Hz so a very large factor cannot yield a zero rate.
            new_rate = max(1, int(self.sample_rate * (1.0 / effective_speed)))
            resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
            wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        if wav.ndim == 0:
            wav = np.zeros(1, dtype=np.float32)
        elif wav.ndim > 1:
            wav = wav.squeeze()
        return wav