from pathlib import Path
import numpy as np
from loguru import logger
from voice_tts.config import settings
from voice_tts.tts.engine import TTSEngine
# Optional heavy dependencies: the module must stay importable without them so
# that other TTS backends keep working. F5_AVAILABLE records whether the
# imports succeeded; F5TTSEngine.load() checks it before touching the model.
try:
    import torch
    import torchaudio
    from f5_tts.api import F5TTS
    from f5_tts.infer.utils_infer import preprocess_ref_audio_text

    F5_AVAILABLE = True
except ImportError as exc:
    logger.warning("f5-tts/torch dependencies not available: {}", exc)
    F5_AVAILABLE = False
    # Bind every optional name so later references never hit an unbound
    # module-level name; callers must still check F5_AVAILABLE before use.
    torch = None
    torchaudio = None
    F5TTS = None
    preprocess_ref_audio_text = None
class F5TTSEngine(TTSEngine):
    """F5-TTS backend for local GPU inference with voice cloning.

    The model is loaded lazily (on ``load()``, ``warm_up()`` or the first
    ``synthesize()`` call). Preprocessed reference audio is cached per
    (emotion, reference path, reference text) so repeated requests with the
    same voice reference skip preprocessing.
    """

    def __init__(
        self,
        model: str = "F5TTS_v1_Base",
        sample_rate: int = 24_000,
        device: str | None = None,
        nfe_step: int = 32,
        cfg_strength: float = 2.0,
        sway_sampling_coef: float = -1.0,
        speed: float = 1.0,
        target_rms: float = 0.1,
        cross_fade_duration: float = 0.0,
        remove_silence: bool = False,
    ):
        """Store inference settings; the model itself is not loaded here.

        Args:
            model: Model name passed to ``F5TTS(model=...)``.
            sample_rate: Initial sample rate; overwritten by the model's
                ``target_sample_rate`` once ``load()`` has run.
            device: Device passed to ``F5TTS``; falls back to
                ``settings.device`` when falsy.
            nfe_step: Forwarded to ``F5TTS.infer``.
            cfg_strength: Forwarded to ``F5TTS.infer``.
            sway_sampling_coef: Forwarded to ``F5TTS.infer``.
            speed: Stored on the instance. NOTE(review): ``synthesize`` uses
                its own per-call ``speed`` argument and never reads this
                value.
            target_rms: Forwarded to ``F5TTS.infer``.
            cross_fade_duration: Forwarded to ``F5TTS.infer``.
            remove_silence: Forwarded to ``F5TTS.infer``.
        """
        super().__init__()
        self.model_name = model
        self.sample_rate = sample_rate
        self.device = device or settings.device
        self.nfe_step = nfe_step
        self.cfg_strength = cfg_strength
        self.sway_sampling_coef = sway_sampling_coef
        self.speed = speed
        self.target_rms = target_rms
        self.cross_fade_duration = cross_fade_duration
        self.remove_silence = remove_silence
        self._f5: "F5TTS | None" = None
        # cache key (see _get_key) -> (processed_audio_path, ref_text)
        self._ref_cache: dict[str, tuple[str, str]] = {}

    def _get_key(
        self,
        ref_path: Path | None,
        emotion: str,
        ref_text: str | None = None,
    ) -> str:
        """Build the reference-cache key.

        Args:
            ref_path: Reference audio path (may be ``None``).
            emotion: Emotion label the reference belongs to.
            ref_text: Requested reference transcript. When empty or ``None``
                the key is just ``"<emotion>::<path>"``; otherwise the text
                is appended so different transcripts for the same audio do
                not share a cache entry.

        Returns:
            The cache key string.
        """
        path_str = str(ref_path) if ref_path else ""
        key = f"{emotion}::{path_str}"
        if ref_text:
            key = f"{key}::{ref_text}"
        return key

    def load(self) -> None:
        """Instantiate the F5-TTS model and adopt its target sample rate.

        Raises:
            RuntimeError: If the f5-tts/torch dependencies failed to import.
        """
        if not F5_AVAILABLE:
            raise RuntimeError(
                "f5-tts/torch package is not installed. Install it: pip install f5-tts torch torchaudio"
            )
        logger.info(
            "Loading F5-TTS model {} on device {} ...",
            self.model_name,
            self.device,
        )
        self._f5 = F5TTS(
            model=self.model_name,
            device=self.device,
        )
        # The model dictates the output rate; the constructor value is only a
        # placeholder until the model is loaded.
        self.sample_rate = self._f5.target_sample_rate
        logger.info("F5-TTS loaded. Target sample rate: {}", self.sample_rate)

    async def warm_up(self) -> None:
        """Load the model if needed.

        No inference is run here: a real warm-up pass needs a reference
        audio, which is not available at this point, so only loading happens.
        """
        if self._f5 is None:
            self.load()
        logger.info("Warming up F5-TTS ...")
        logger.info("Warm-up skipped: provide a reference audio before warm-up.")

    def _ensure_reference(
        self,
        ref_audio_path: Path | None,
        emotion: str,
        ref_text_override: str | None = None,
    ) -> tuple[str, str]:
        """Return the preprocessed reference audio and text, using the cache.

        Args:
            ref_audio_path: Path to the reference audio (a ``str`` is also
                accepted and converted).
            emotion: Emotion label, part of the cache key.
            ref_text_override: Reference transcript; falls back to
                ``settings.default_ref_text`` and then to ``""``.

        Returns:
            ``(processed_audio, ref_text)`` as returned by
            ``preprocess_ref_audio_text``.

        Raises:
            RuntimeError: If the f5-tts/torch dependencies are missing.
            ValueError: If no reference audio path was given.
            FileNotFoundError: If the reference audio does not exist (only
                checked on a cache miss).
        """
        if not F5_AVAILABLE:
            raise RuntimeError("f5-tts/torch is not installed")
        if ref_audio_path is None:
            raise ValueError("F5-TTS requires a reference audio file (voice_ref).")
        if isinstance(ref_audio_path, str):
            ref_audio_path = Path(ref_audio_path)

        # Resolve the effective transcript BEFORE the cache lookup so that it
        # takes part in the key; otherwise a new override for an already
        # cached (emotion, path) pair would silently get the stale text.
        requested_text = ref_text_override or settings.default_ref_text or ""
        key = self._get_key(ref_audio_path, emotion, requested_text)
        cached = self._ref_cache.get(key)
        if cached is not None:
            return cached

        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        processed_audio, ref_text = preprocess_ref_audio_text(
            str(ref_audio_path),
            requested_text,
        )
        self._ref_cache[key] = (processed_audio, ref_text)
        logger.info(
            "Reference cached for emotion={} path={} text={}",
            emotion,
            ref_audio_path,
            ref_text,
        )
        return processed_audio, ref_text

    @staticmethod
    def _to_float32_waveform(wav) -> np.ndarray:
        """Convert model output to a float32 NumPy array with ndim >= 1."""
        if torch is not None and isinstance(wav, torch.Tensor):
            # detach() so tensors that still require grad can be converted.
            wav = wav.detach().cpu().numpy()
        wav = np.asarray(wav, dtype=np.float32)
        # squeeze() drops singleton axes; atleast_1d() keeps a single-sample
        # result as a 1-element array instead of a 0-d scalar.
        return np.atleast_1d(wav.squeeze())

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Generate speech for ``text`` in the voice of the reference audio.

        NOTE(review): model loading and inference are called synchronously
        inside this coroutine, so the event loop is blocked while they run.

        Args:
            text: Text to synthesize.
            ref_audio_path: Reference audio used for voice cloning (required).
            language: Not used by this backend.
            speed: Speed value forwarded to ``F5TTS.infer``.
            emotion: Emotion label, used only for the reference cache key.
            ref_text: Optional transcript of the reference audio.

        Returns:
            The generated waveform as a float32 array with singleton axes
            removed.

        Raises:
            RuntimeError: If dependencies are missing, the model is not
                available after loading, or inference returned no audio.
            ValueError: If ``ref_audio_path`` is ``None``.
            FileNotFoundError: If the reference audio does not exist.
        """
        if self._f5 is None:
            self.load()
        if self._f5 is None:
            # Explicit check instead of assert, which is stripped under -O.
            raise RuntimeError("F5-TTS model is not loaded")

        processed_audio, ref_text = self._ensure_reference(ref_audio_path, emotion, ref_text)
        wav, _sr, _spec = self._f5.infer(
            ref_file=processed_audio,
            ref_text=ref_text,
            gen_text=text,
            nfe_step=self.nfe_step,
            cfg_strength=self.cfg_strength,
            sway_sampling_coef=self.sway_sampling_coef,
            speed=speed,
            cross_fade_duration=self.cross_fade_duration,
            target_rms=self.target_rms,
            remove_silence=self.remove_silence,
        )
        if wav is None:
            raise RuntimeError("F5-TTS produced no audio")
        return self._to_float32_waveform(wav)