# Source: voice/src/voice_tts/tts/f5_backend.py
import asyncio
import threading
from pathlib import Path

import numpy as np
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts.engine import TTSEngine


# Optional heavy dependencies.  The module must stay importable on machines
# without torch / f5-tts (e.g. when another backend is configured), so a failed
# import is only logged here; F5TTSEngine.load() raises when F5_AVAILABLE is
# False.
try:
    import torch
    import torchaudio
    from f5_tts.api import F5TTS
    from f5_tts.infer.utils_infer import preprocess_ref_audio_text

    F5_AVAILABLE = True
except ImportError as exc:
    logger.warning("f5-tts/torch dependencies not available: {}", exc)
    F5_AVAILABLE = False
    # Bind every optional name so that a stray reference yields a None check
    # failure rather than a NameError at an arbitrary call site.
    torch = None
    torchaudio = None
    F5TTS = None
    preprocess_ref_audio_text = None


class F5TTSEngine(TTSEngine):
    """F5-TTS backend for local GPU inference with voice cloning.

    The model is loaded lazily: on ``load()``, ``warm_up()`` or the first
    ``synthesize()`` call.  Model loading, reference preprocessing and
    inference are blocking calls, so the async entry points run them in a
    worker thread (``asyncio.to_thread``) instead of stalling the event loop.
    A re-entrant lock keeps that work strictly one call at a time, which is
    the ordering it had when it ran inline on the loop.
    """

    def __init__(
        self,
        model: str = "F5TTS_v1_Base",
        sample_rate: int = 24_000,
        device: str | None = None,
        nfe_step: int = 32,
        cfg_strength: float = 2.0,
        sway_sampling_coef: float = -1.0,
        speed: float = 1.0,
        target_rms: float = 0.1,
        cross_fade_duration: float = 0.0,
        remove_silence: bool = False,
    ):
        """Store the synthesis settings; the model itself is not loaded here.

        Args:
            model: Model name passed to ``F5TTS(model=...)``.
            sample_rate: Initial sample rate; overwritten by the model's
                ``target_sample_rate`` once ``load()`` has run.
            device: Device passed to ``F5TTS``; falls back to
                ``settings.device`` when falsy.
            nfe_step: Forwarded to ``F5TTS.infer``.
            cfg_strength: Forwarded to ``F5TTS.infer``.
            sway_sampling_coef: Forwarded to ``F5TTS.infer``.
            speed: Default speed, used when ``synthesize`` is given
                ``speed=None``.
            target_rms: Forwarded to ``F5TTS.infer``.
            cross_fade_duration: Forwarded to ``F5TTS.infer``.
            remove_silence: Forwarded to ``F5TTS.infer``.
        """
        super().__init__()
        self.model_name = model
        self.sample_rate = sample_rate
        self.device = device or settings.device
        self.nfe_step = nfe_step
        self.cfg_strength = cfg_strength
        self.sway_sampling_coef = sway_sampling_coef
        self.speed = speed
        self.target_rms = target_rms
        self.cross_fade_duration = cross_fade_duration
        self.remove_silence = remove_silence

        self._f5: "F5TTS | None" = None
        # cache key (see _get_key) -> (processed_audio_path, ref_text)
        self._ref_cache: dict[str, tuple[str, str]] = {}
        # Serialises model loading / reference preprocessing / inference now
        # that they run in worker threads.  Re-entrant so the synthesis worker
        # can call _ensure_loaded() while already holding it.
        self._model_lock = threading.RLock()

    def _get_key(self, ref_path: str | Path | None, emotion: str) -> str:
        """Build the reference-cache key ``"<emotion>::<path>"``."""
        path_str = str(ref_path) if ref_path else ""
        return f"{emotion}::{path_str}"

    def load(self) -> None:
        """Load the F5-TTS model synchronously (blocking).

        Also replaces ``self.sample_rate`` with the model's
        ``target_sample_rate``.

        Raises:
            RuntimeError: If the f5-tts/torch imports failed at module import.
        """
        if not F5_AVAILABLE:
            raise RuntimeError(
                "f5-tts/torch package is not installed. Install it: pip install f5-tts torch torchaudio"
            )
        logger.info(
            "Loading F5-TTS model {} on device {} ...",
            self.model_name,
            self.device,
        )
        self._f5 = F5TTS(
            model=self.model_name,
            device=self.device,
        )
        self.sample_rate = self._f5.target_sample_rate
        logger.info("F5-TTS loaded. Target sample rate: {}", self.sample_rate)

    def _ensure_loaded(self) -> "F5TTS":
        """Return the loaded model, loading it first if needed (blocking)."""
        with self._model_lock:
            if self._f5 is None:
                self.load()
            model = self._f5
            if model is None:
                # Explicit check instead of ``assert`` so it survives ``-O``.
                raise RuntimeError("F5-TTS model is not loaded")
            return model

    async def warm_up(self) -> None:
        """Make sure the model is loaded; no actual warm-up inference is run.

        Loading happens in a worker thread so the event loop stays responsive.
        """
        await asyncio.to_thread(self._ensure_loaded)
        logger.info("Warming up F5-TTS ...")
        logger.info("Warm-up skipped: provide a reference audio before warm-up.")

    def _ensure_reference(
        self,
        ref_audio_path: str | Path | None,
        emotion: str,
    ) -> tuple[str, str]:
        """Return ``(processed_audio, ref_text)`` for a reference clip.

        Results are cached per ``(emotion, path)`` key, so each combination is
        preprocessed once per engine instance.

        Raises:
            RuntimeError: If the f5-tts/torch imports failed.
            ValueError: If no reference audio path was given.
            FileNotFoundError: If the reference file does not exist (only
                checked on a cache miss).
        """
        if not F5_AVAILABLE:
            raise RuntimeError("f5-tts/torch is not installed")
        if ref_audio_path is None:
            raise ValueError("F5-TTS requires a reference audio file (voice_ref).")

        ref_audio_path = Path(ref_audio_path)

        key = self._get_key(ref_audio_path, emotion)
        cached = self._ref_cache.get(key)
        if cached is not None:
            return cached

        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        processed_audio, ref_text = preprocess_ref_audio_text(
            str(ref_audio_path),
            "",  # empty ref_text triggers automatic transcription
        )
        self._ref_cache[key] = (processed_audio, ref_text)
        logger.info(
            "Reference cached for emotion={} path={} text={}",
            emotion,
            ref_audio_path,
            ref_text,
        )
        return processed_audio, ref_text

    @staticmethod
    def _normalize_waveform(wav) -> np.ndarray:
        """Convert model output to a squeezed ``float32`` NumPy array."""
        if torch is not None and isinstance(wav, torch.Tensor):
            wav = wav.detach().cpu().numpy()
        # Squeeze *before* the scalar check: a (1, 1) result would otherwise
        # slip through as a 0-d array.
        samples = np.squeeze(np.asarray(wav, dtype=np.float32))
        if samples.ndim == 0:
            # Degenerate single-value output: returned as one silent sample,
            # as before.
            return np.zeros(1, dtype=np.float32)
        return samples

    def _synthesize_blocking(
        self,
        text: str,
        ref_audio_path: str | Path | None,
        speed: float | None,
        emotion: str,
    ) -> np.ndarray:
        """Blocking synthesis pipeline; runs in a worker thread."""
        effective_speed = self.speed if speed is None else speed
        with self._model_lock:
            # Same order as before: load the model, then resolve the reference.
            model = self._ensure_loaded()
            processed_audio, ref_text = self._ensure_reference(ref_audio_path, emotion)
            wav, _sr, _spec = model.infer(
                ref_file=processed_audio,
                ref_text=ref_text,
                gen_text=text,
                nfe_step=self.nfe_step,
                cfg_strength=self.cfg_strength,
                sway_sampling_coef=self.sway_sampling_coef,
                speed=effective_speed,
                cross_fade_duration=self.cross_fade_duration,
                target_rms=self.target_rms,
                remove_silence=self.remove_silence,
            )

        if wav is None:
            raise RuntimeError("F5-TTS produced no audio")
        return self._normalize_waveform(wav)

    async def synthesize(
        self,
        text: str,
        ref_audio_path: str | Path | None,
        language: str,
        speed: float,
        emotion: str,
    ) -> np.ndarray:
        """Synthesize ``text`` in the voice of the reference clip.

        Args:
            text: Text to speak.
            ref_audio_path: Reference audio used for voice cloning (required).
            language: Accepted for interface compatibility; not used by this
                backend.
            speed: Speed forwarded to ``F5TTS.infer``; ``None`` falls back to
                the engine's configured ``speed``.
            emotion: Used only as part of the reference-cache key.

        Returns:
            A squeezed ``float32`` NumPy array of audio samples.

        Raises:
            RuntimeError: If dependencies are missing or no audio is produced.
            ValueError: If ``ref_audio_path`` is ``None``.
            FileNotFoundError: If the reference file does not exist.
        """
        # Loading, transcription and inference are all blocking; run them off
        # the event loop so other coroutines keep being served.
        return await asyncio.to_thread(
            self._synthesize_blocking,
            text,
            ref_audio_path,
            speed,
            emotion,
        )