"""Fish Speech 1.5 backend for local GPU inference with zero-shot voice cloning."""
from pathlib import Path
import numpy as np
from loguru import logger
from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine
# torch/torchaudio are optional at import time: when they are missing the
# module still imports, logs a warning, and exposes FISH_SPEECH_AVAILABLE=False
# with both module names bound to None.
try:
    import sys
    import torch
    import torchaudio
except ImportError as import_error:
    logger.warning("torch/torchaudio not available for Fish Speech: {}", import_error)
    FISH_SPEECH_AVAILABLE = False
    torch = torchaudio = None
else:
    FISH_SPEECH_AVAILABLE = True

# Class-level default output sample rate (Hz) for the engine.
_FISH_SAMPLE_RATE = 44_100
# Last-resort location of the cloned Fish Speech source tree.
_DEFAULT_SOURCE_ROOT = Path("models/fish-speech")
@_register_backend("fish_speech")
class FishSpeechEngine(TTSEngine):
    """Fish Speech 1.5 backend supporting English and Russian zero-shot TTS.

    Constructing the engine only resolves configuration; the Fish Speech
    modules and model weights are imported/loaded by :meth:`load`, which is
    invoked lazily by :meth:`warm_up` and :meth:`synthesize`.
    """

    sample_rate: int = _FISH_SAMPLE_RATE

    def __init__(
        self,
        checkpoint_path: Path | str | None = None,
        source_root: Path | str | None = None,
        sample_rate: int | None = None,
        device: str | None = None,
        # String annotation: ``torch`` is None when the optional import failed,
        # and an eagerly evaluated ``torch.dtype`` would crash module import.
        precision: "torch.dtype | None" = None,
        compile: bool | None = None,
        use_memory_cache: str | None = None,
        chunk_length: int | None = None,
        top_p: float | None = None,
        temperature: float | None = None,
        repetition_penalty: float | None = None,
        seed: int | None = None,
        tail_silence_threshold: float | None = None,
        lowpass_cutoff: int | None = None,
    ):
        """Resolve the engine configuration; no model is loaded here.

        Every argument left as ``None`` falls back to the matching value in
        ``settings``.

        Args:
            checkpoint_path: Directory holding the Fish Speech checkpoints.
            source_root: Directory of the cloned Fish Speech source tree that
                is put on ``sys.path`` before importing ``fish_speech``.
            sample_rate: Initial output sample rate; replaced by the decoder's
                own rate once :meth:`load` has run.
            device: Device string handed to the model loaders.
            precision: Torch dtype handed to the model loaders.
            compile: Compile flag handed to the model loaders.
            use_memory_cache: Forwarded to ``ServeTTSRequest``.
            chunk_length: Forwarded to ``ServeTTSRequest``.
            top_p: Forwarded to ``ServeTTSRequest``.
            temperature: Forwarded to ``ServeTTSRequest``.
            repetition_penalty: Forwarded to ``ServeTTSRequest``.
            seed: Forwarded to ``ServeTTSRequest``.
            tail_silence_threshold: Stored on the instance; not read by any
                method of this class.
            lowpass_cutoff: Cutoff frequency for the post-synthesis low-pass
                filter; a value <= 0 disables the filter.

        Raises:
            RuntimeError: If torch/torchaudio could not be imported.
        """
        super().__init__()
        if not FISH_SPEECH_AVAILABLE:
            raise RuntimeError(
                "Fish Speech backend requires torch/torchaudio and the Fish Speech "
                "source tree at models/fish-speech"
            )

        self.sample_rate = sample_rate or settings.tts_sample_rate
        self.device = device or settings.device
        self.precision = precision or (
            torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32
        )
        self.compile = compile if compile is not None else settings.fish_compile
        self.use_memory_cache = use_memory_cache or settings.fish_use_memory_cache
        self.chunk_length = chunk_length or settings.fish_chunk_length

        # ``is not None`` (rather than ``or``) keeps explicit zero values.
        # These resolved values must not be reassigned from the raw arguments
        # afterwards, otherwise the settings fallback is lost.
        self.top_p = top_p if top_p is not None else settings.fish_top_p
        self.temperature = (
            temperature if temperature is not None else settings.fish_temperature
        )
        self.repetition_penalty = (
            repetition_penalty
            if repetition_penalty is not None
            else settings.fish_repetition_penalty
        )
        self.seed = seed if seed is not None else settings.fish_seed
        self.tail_silence_threshold = (
            tail_silence_threshold
            if tail_silence_threshold is not None
            else settings.fish_tail_silence_threshold
        )
        self.lowpass_cutoff = (
            lowpass_cutoff
            if lowpass_cutoff is not None
            else settings.fish_lowpass_cutoff
        )

        # NOTE(review): the source tree falls back to settings.tts_vocab_path;
        # confirm that setting is really meant to point at the Fish Speech repo.
        self.source_root = Path(
            source_root or settings.tts_vocab_path or _DEFAULT_SOURCE_ROOT
        )
        self.checkpoint_path = Path(
            checkpoint_path
            or settings.tts_model_path
            or "models/fishaudio_fish-speech-1.5"
        )

        self._llama_queue = None
        self._decoder = None
        self._engine = None
        self._loaded = False
        # Reference audio bytes per path, to avoid repeated disk reads.
        self._ref_cache: dict[Path, bytes] = {}

    def _ensure_source_path(self) -> None:
        """Make sure the cloned Fish Speech source is on sys.path."""
        root = str(self.source_root.resolve())
        if root not in sys.path:
            sys.path.insert(0, root)

    def _is_supported_language(self, language: str) -> bool:
        """Return True when ``language`` is one of the exposed codes (en, ru)."""
        # Fish Speech 1.5 is multilingual; we expose English and Russian.
        return language.lower() in {"en", "ru"}

    def load(self) -> None:
        """Import Fish Speech and build the inference engine (idempotent)."""
        if self._loaded:
            return

        self._ensure_source_path()

        # Imported here because the package only becomes importable once
        # source_root has been added to sys.path.
        from fish_speech.inference_engine import TTSInferenceEngine
        from fish_speech.models.text2semantic.inference import (
            launch_thread_safe_queue,
        )
        from fish_speech.models.vqgan.inference import load_model as load_decoder_model

        logger.info(
            "Loading Fish Speech 1.5 from {} (source: {}) ...",
            self.checkpoint_path,
            self.source_root,
        )

        decoder_checkpoint = (
            self.checkpoint_path
            / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
        )

        # The text2semantic loader receives the checkpoint directory itself.
        self._llama_queue = launch_thread_safe_queue(
            checkpoint_path=str(self.checkpoint_path),
            device=self.device,
            precision=self.precision,
            compile=self.compile,
        )
        self._decoder = load_decoder_model(
            config_name="firefly_gan_vq",
            checkpoint_path=str(decoder_checkpoint),
            device=self.device,
        )
        self._engine = TTSInferenceEngine(
            llama_queue=self._llama_queue,
            decoder_model=self._decoder,
            precision=self.precision,
            compile=self.compile,
        )

        # The decoder's own rate replaces whatever was configured.
        self.sample_rate = self._decoder.spec_transform.sample_rate
        self._loaded = True
        logger.info(
            "Fish Speech 1.5 loaded. Output sample rate: {}", self.sample_rate
        )

    async def warm_up(self) -> None:
        """Load the model if needed; no dummy synthesis is performed."""
        if not self._loaded:
            self.load()
        logger.info("Fish Speech warm-up skipped; first synthesis will warm the cache.")

    def _normalize_language(self, language: str) -> str:
        """Collapse tags starting with ``ru``/``en`` to the bare two-letter code."""
        language = language.lower()
        if language.startswith("ru"):
            return "ru"
        if language.startswith("en"):
            return "en"
        return language

    def _load_reference_audio(self, ref_audio_path: Path) -> bytes:
        """Return the reference audio bytes, reading each path from disk once."""
        ref_audio = self._ref_cache.get(ref_audio_path)
        if ref_audio is None:
            ref_audio = ref_audio_path.read_bytes()
            self._ref_cache[ref_audio_path] = ref_audio
        return ref_audio

    def _resolve_ref_text(
        self, ref_audio_path: Path, lang: str, ref_text: str | None
    ) -> str:
        """Pick the reference transcript: caller > .lab file > settings > placeholder."""
        if ref_text:
            return ref_text

        lab_path = ref_audio_path.with_suffix(".lab")
        if lab_path.exists():
            lab_text = lab_path.read_text(encoding="utf-8").strip()
            # An empty .lab file carries no transcript; keep falling back.
            if lab_text:
                return lab_text

        if settings.default_ref_text:
            return settings.default_ref_text

        if lang == "en":
            return "Hello, this is a reference voice recording."
        return "Здравствуйте. Это тестовая запись голоса."

    def _postprocess(self, wav: np.ndarray, speed: float | None) -> np.ndarray:
        """Peak-limit, optionally low-pass and speed-adjust the raw waveform."""
        # Scale down only when the signal exceeds [-1, 1]; quieter audio is
        # left untouched.
        peak = np.max(np.abs(wav))
        if peak > 1.0:
            wav = wav / peak

        # Gentle low-pass to reduce VQ codec noise.
        if self.lowpass_cutoff is not None and self.lowpass_cutoff > 0:
            tensor = torch.from_numpy(wav).unsqueeze(0)  # (1, samples)
            tensor = torchaudio.functional.lowpass_biquad(
                tensor, self.sample_rate, self.lowpass_cutoff
            )
            wav = tensor.squeeze(0).numpy()

        # Apply speed adjustment via resampling. Fish Speech is already
        # reasonably fast; the default settings.tts_speed allows global tuning.
        # NOTE(review): resampling changes pitch together with tempo, assuming
        # callers play the result back at self.sample_rate.
        effective_speed = speed if speed is not None else settings.tts_speed
        if effective_speed != 1.0 and effective_speed > 0:
            new_rate = int(self.sample_rate * (1.0 / effective_speed))
            # An extreme speed can truncate the target rate to 0, which is not
            # a valid resampling frequency; skip the adjustment in that case.
            if new_rate > 0:
                resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
                wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)

        if wav.ndim == 0:
            wav = np.zeros(1, dtype=np.float32)
        elif wav.ndim > 1:
            wav = wav.squeeze()
        return wav

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize ``text`` in the voice of the reference recording.

        Args:
            text: Text to speak.
            ref_audio_path: Reference recording; falls back to
                ``settings.default_voice_ref`` when ``None``.
            language: Language tag; normalized to ``en``/``ru``.
            speed: Speed factor; ``None`` falls back to ``settings.tts_speed``
                and values <= 0 leave the audio unchanged.
            emotion: Accepted for interface compatibility; not used here.
            ref_text: Transcript of the reference recording. When empty, a
                sibling ``.lab`` file, ``settings.default_ref_text`` or a
                built-in placeholder is used, in that order.

        Returns:
            The synthesized waveform as a float32 numpy array.

        Raises:
            ValueError: If no reference audio is available or the language is
                not supported.
            FileNotFoundError: If the reference audio path does not exist.
            RuntimeError: If Fish Speech reports an error or yields no audio.
        """
        if not self._loaded:
            self.load()

        if ref_audio_path is None:
            ref_audio_path = settings.default_voice_ref
        if ref_audio_path is None:
            raise ValueError("Fish Speech requires a reference audio file (voice_ref).")
        # Coerce after the settings fallback so a string-valued setting works
        # just like a string argument.
        ref_audio_path = Path(ref_audio_path)
        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        lang = self._normalize_language(language)
        if not self._is_supported_language(lang):
            raise ValueError(
                f"Language '{language}' is not supported by Fish Speech backend. "
                "Currently enabled languages: en, ru."
            )

        from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest

        # Audio bytes are cached per path; the transcript is cheap to resolve,
        # so its fallback precedence is re-evaluated on every call.
        ref_audio = self._load_reference_audio(ref_audio_path)
        ref_text = self._resolve_ref_text(ref_audio_path, lang, ref_text)

        req = ServeTTSRequest(
            text=text,
            references=[
                ServeReferenceAudio(audio=ref_audio, text=ref_text)
            ],
            seed=self.seed,
            top_p=self.top_p,
            temperature=self.temperature,
            repetition_penalty=self.repetition_penalty,
            chunk_length=self.chunk_length,
            use_memory_cache=self.use_memory_cache,
        )

        segments = []
        for result in self._engine.inference(req):
            logger.debug(
                "Fish Speech inference result: code={} error={}",
                result.code,
                result.error,
            )
            if result.code == "error":
                error = result.error or RuntimeError("Unknown Fish Speech error")
                raise RuntimeError(f"Fish Speech synthesis failed: {error}")
            if result.audio is not None:
                # The per-result sample rate is ignored; self.sample_rate is
                # used for all post-processing.
                _, audio = result.audio
                segments.append(audio.astype(np.float32))

        if not segments:
            raise RuntimeError("Fish Speech produced no audio.")

        wav = np.concatenate(segments)
        if wav.size == 0:
            # Zero-length segments would otherwise crash the peak computation.
            raise RuntimeError("Fish Speech produced no audio.")

        return self._postprocess(wav, speed)