# voice/src/voice_tts/tts/fish_speech_backend.py
"""Fish Speech 1.5 backend for local GPU inference with zero-shot voice cloning."""

from pathlib import Path

import numpy as np
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine

# ``sys`` is stdlib and cannot fail to import; keep it outside the guarded
# block so it is defined even when the optional torch stack is missing.
import sys

try:
    import torch
    import torchaudio

    FISH_SPEECH_AVAILABLE = True
except ImportError as exc:
    # Degrade gracefully: record the failure and leave torch/torchaudio as None
    # so callers can consult FISH_SPEECH_AVAILABLE instead of crashing here.
    logger.warning("torch/torchaudio not available for Fish Speech: {}", exc)
    FISH_SPEECH_AVAILABLE = False
    torch = None
    torchaudio = None


# Class-level default output rate of the Fish Speech engine.
_FISH_SAMPLE_RATE = 44_100
# Default location of the cloned Fish Speech source tree.
_DEFAULT_SOURCE_ROOT = Path("models/fish-speech")


@_register_backend("fish_speech")
class FishSpeechEngine(TTSEngine):
    """Fish Speech 1.5 backend supporting English and Russian zero-shot TTS.

    Every synthesis call clones the voice of a reference recording
    (``ref_audio_path`` or ``settings.default_voice_ref``) paired with its
    transcript. Models are loaded lazily by :meth:`load` on first use.
    """

    # Class-level default; instances overwrite it in __init__ and again with
    # the decoder's actual output rate in load().
    sample_rate: int = _FISH_SAMPLE_RATE

    def __init__(
        self,
        checkpoint_path: Path | str | None = None,
        source_root: Path | str | None = None,
        sample_rate: int | None = None,
        device: str | None = None,
        # Quoted so it is not evaluated at class-creation time: when torch
        # failed to import, ``torch`` is None and ``torch.dtype`` would raise
        # AttributeError, breaking import of this whole module.
        precision: "torch.dtype | None" = None,
        compile: bool | None = None,
        use_memory_cache: str | None = None,
        chunk_length: int | None = None,
        top_p: float | None = None,
        temperature: float | None = None,
        repetition_penalty: float | None = None,
        seed: int | None = None,
        tail_silence_threshold: float | None = None,
        lowpass_cutoff: int | None = None,
    ):
        """Configure the backend; ``None`` arguments fall back to ``settings``.

        No model is loaded here; that happens in :meth:`load`.

        Raises:
            RuntimeError: if torch/torchaudio could not be imported.
        """
        super().__init__()

        if not FISH_SPEECH_AVAILABLE:
            raise RuntimeError(
                "Fish Speech backend requires torch/torchaudio and the Fish Speech "
                "source tree at models/fish-speech"
            )

        self.sample_rate = sample_rate or settings.tts_sample_rate
        self.device = device or settings.device
        self.precision = precision or (
            torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32
        )
        self.compile = compile if compile is not None else settings.fish_compile
        self.use_memory_cache = use_memory_cache or settings.fish_use_memory_cache
        self.chunk_length = chunk_length or settings.fish_chunk_length
        # Explicit ``is not None`` checks so legitimate falsy values (0, 0.0)
        # passed by the caller are honoured rather than replaced by settings.
        self.top_p = top_p if top_p is not None else settings.fish_top_p
        self.temperature = (
            temperature if temperature is not None else settings.fish_temperature
        )
        self.repetition_penalty = (
            repetition_penalty
            if repetition_penalty is not None
            else settings.fish_repetition_penalty
        )
        self.seed = seed if seed is not None else settings.fish_seed
        # Stored for configuration purposes; not read anywhere in this class.
        self.tail_silence_threshold = (
            tail_silence_threshold
            if tail_silence_threshold is not None
            else settings.fish_tail_silence_threshold
        )
        self.lowpass_cutoff = (
            lowpass_cutoff if lowpass_cutoff is not None else settings.fish_lowpass_cutoff
        )

        # settings.tts_vocab_path is reused as the Fish Speech source-tree location.
        self.source_root = Path(
            source_root or settings.tts_vocab_path or _DEFAULT_SOURCE_ROOT
        )
        self.checkpoint_path = Path(
            checkpoint_path or settings.tts_model_path or "models/fishaudio_fish-speech-1.5"
        )

        # Populated by load().
        self._llama_queue = None
        self._decoder = None
        self._engine = None
        self._loaded = False
        # Reference audio bytes cached per path to avoid repeated disk reads.
        # Entries are never invalidated: edits to a reference file on disk are
        # not picked up until this engine instance is recreated.
        self._ref_cache: dict[Path, bytes] = {}

    def _ensure_source_path(self) -> None:
        """Make sure the cloned Fish Speech source is on sys.path."""
        root = str(self.source_root.resolve())
        if root not in sys.path:
            sys.path.insert(0, root)

    def _is_supported_language(self, language: str) -> bool:
        """Return True for the language codes this backend exposes."""
        # Fish Speech 1.5 is multilingual; we expose English and Russian.
        return language.lower() in {"en", "ru"}

    def load(self) -> None:
        """Load the Fish Speech LLaMA queue, VQ-GAN decoder and inference engine.

        Idempotent. On success ``self.sample_rate`` is set to the decoder's
        output sample rate.

        Raises:
            FileNotFoundError: if an expected checkpoint file is missing.
        """
        if self._loaded:
            return

        self._ensure_source_path()

        from fish_speech.inference_engine import TTSInferenceEngine
        from fish_speech.models.text2semantic.inference import (
            launch_thread_safe_queue,
        )
        from fish_speech.models.vqgan.inference import load_model as load_decoder_model

        logger.info(
            "Loading Fish Speech 1.5 from {} (source: {}) ...",
            self.checkpoint_path,
            self.source_root,
        )

        llama_checkpoint = self.checkpoint_path / "model.pth"
        decoder_checkpoint = (
            self.checkpoint_path
            / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
        )
        # Fail fast with an actionable message instead of an opaque error from
        # deep inside the fish_speech loaders.
        for required in (llama_checkpoint, decoder_checkpoint):
            if not required.is_file():
                raise FileNotFoundError(f"Fish Speech checkpoint not found: {required}")

        # The LLaMA loader is given the checkpoint directory, not model.pth.
        self._llama_queue = launch_thread_safe_queue(
            checkpoint_path=str(self.checkpoint_path),
            device=self.device,
            precision=self.precision,
            compile=self.compile,
        )

        self._decoder = load_decoder_model(
            config_name="firefly_gan_vq",
            checkpoint_path=str(decoder_checkpoint),
            device=self.device,
        )

        self._engine = TTSInferenceEngine(
            llama_queue=self._llama_queue,
            decoder_model=self._decoder,
            precision=self.precision,
            compile=self.compile,
        )

        self.sample_rate = self._decoder.spec_transform.sample_rate
        self._loaded = True
        logger.info(
            "Fish Speech 1.5 loaded. Output sample rate: {}", self.sample_rate
        )

    async def warm_up(self) -> None:
        """Load the models if needed; no dummy synthesis is run."""
        if not self._loaded:
            self.load()
        logger.info("Fish Speech warm-up skipped; first synthesis will warm the cache.")

    def _normalize_language(self, language: str) -> str:
        """Collapse regional tags (e.g. ``en-US``, ``ru_RU``) to ``en``/``ru``.

        Other values are returned lower-cased and otherwise unchanged.
        """
        language = language.lower()
        if language.startswith("ru"):
            return "ru"
        if language.startswith("en"):
            return "en"
        return language

    def _resolve_ref_text(self, ref_audio_path: Path, lang: str) -> str:
        """Return a transcript for the reference when the caller supplied none.

        Precedence: non-empty sibling ``.lab`` file > ``settings.default_ref_text``
        > a generic placeholder sentence in the target language.
        """
        lab_path = ref_audio_path.with_suffix(".lab")
        if lab_path.exists():
            lab_text = lab_path.read_text(encoding="utf-8").strip()
            if lab_text:
                return lab_text
        if settings.default_ref_text:
            return settings.default_ref_text
        # NOTE(review): a placeholder that does not match the recording
        # presumably degrades cloning quality; prefer providing a .lab file.
        return (
            "Hello, this is a reference voice recording."
            if lang == "en"
            else "Здравствуйте. Это тестовая запись голоса."
        )

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize ``text`` in the voice of a reference recording.

        Args:
            text: Text to speak.
            ref_audio_path: Reference recording (``str`` also accepted); falls
                back to ``settings.default_voice_ref`` when not given.
            language: Language tag; ``en*`` and ``ru*`` are accepted.
            speed: Speed factor applied by resampling; ``None`` falls back to
                ``settings.tts_speed`` and non-positive values are ignored.
            emotion: Accepted for interface compatibility; unused here.
            ref_text: Transcript of the reference; when falsy it is resolved
                from a ``.lab`` file, settings, or a placeholder.

        Returns:
            1-D float32 waveform nominally at ``self.sample_rate``.

        Raises:
            ValueError: no reference audio is available, or unsupported language.
            FileNotFoundError: the reference audio file does not exist.
            RuntimeError: Fish Speech reported an error or produced no audio.
        """
        if not self._loaded:
            self.load()

        if not ref_audio_path:
            ref_audio_path = settings.default_voice_ref
        if not ref_audio_path:
            raise ValueError("Fish Speech requires a reference audio file (voice_ref).")
        # Normalize only after the fallback: settings may hold a plain string.
        ref_audio_path = Path(ref_audio_path)
        if not ref_audio_path.is_file():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        lang = self._normalize_language(language)
        if not self._is_supported_language(lang):
            raise ValueError(
                f"Language '{language}' is not supported by Fish Speech backend. "
                "Currently enabled languages: en, ru."
            )

        from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

        # Reference audio bytes are cached; the transcript is cheap to resolve,
        # so its fallback precedence is re-evaluated on every call.
        ref_audio = self._ref_cache.get(ref_audio_path)
        if ref_audio is None:
            ref_audio = ref_audio_path.read_bytes()
            self._ref_cache[ref_audio_path] = ref_audio

        # Precedence: caller-provided > .lab > settings > placeholder.
        if not ref_text:
            ref_text = self._resolve_ref_text(ref_audio_path, lang)

        req = ServeTTSRequest(
            text=text,
            references=[
                ServeReferenceAudio(audio=ref_audio, text=ref_text)
            ],
            seed=self.seed,
            top_p=self.top_p,
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            chunk_length=self.chunk_length,
            use_memory_cache=self.use_memory_cache,
        )

        # NOTE(review): inference runs synchronously inside this coroutine and
        # blocks the event loop for the whole synthesis.
        segments: list[np.ndarray] = []
        for result in self._engine.inference(req):
            logger.debug(
                "Fish Speech inference result: code={} error={}",
                result.code,
                result.error,
            )
            if result.code == "error":
                error = result.error or RuntimeError("Unknown Fish Speech error")
                raise RuntimeError(f"Fish Speech synthesis failed: {error}") from (
                    error if isinstance(error, BaseException) else None
                )
            if result.audio is not None:
                _sr, audio = result.audio
                # Flatten each chunk so stray singleton axes cannot break the
                # concatenation or the 1-D filtering below (assumes mono audio).
                segments.append(np.asarray(audio, dtype=np.float32).ravel())

        wav = np.concatenate(segments) if segments else np.zeros(0, dtype=np.float32)
        if wav.size == 0:
            raise RuntimeError("Fish Speech produced no audio.")

        # Normalize to [-1, 1] before filtering: torchaudio's biquad filters
        # clamp their output, which would otherwise clip hot signals.
        peak = float(np.max(np.abs(wav)))
        if peak > 1.0:
            wav = wav / peak

        # Gentle low-pass to reduce VQ codec noise. A cutoff at or above
        # Nyquist yields an invalid biquad, so such values disable the filter.
        cutoff = self.lowpass_cutoff or 0
        if 0 < cutoff < self.sample_rate / 2:
            wt = torch.from_numpy(wav).unsqueeze(0)  # (1, samples)
            wt = torchaudio.functional.lowpass_biquad(wt, self.sample_rate, cutoff)
            wav = wt.squeeze(0).numpy()
        elif cutoff > 0:
            logger.debug(
                "Skipping low-pass: cutoff {} Hz is not below Nyquist for {} Hz audio",
                cutoff,
                self.sample_rate,
            )

        # Apply speed adjustment via resampling. Fish Speech is already
        # reasonably fast; the default settings.tts_speed allows global tuning.
        effective_speed = speed if speed is not None else settings.tts_speed
        if effective_speed and effective_speed > 0 and effective_speed != 1.0:
            new_rate = int(self.sample_rate * (1.0 / effective_speed))
            resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
            wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        return wav