"""Unit tests for the Fish Speech 1.5 backend with mocked inference engine."""
from pathlib import Path
import numpy as np
import pytest
import torch
from voice_tts.config import settings
from voice_tts.tts.fish_speech_backend import FishSpeechEngine
@pytest.fixture
def mocked_fish(monkeypatch, tmp_path):
    """Yield a loaded FishSpeechEngine backed by a stub ``fish_speech`` tree.

    A minimal ``fish_speech`` package is written under ``tmp_path`` and put at
    the front of ``sys.path`` so the imports performed by the backend resolve
    to lightweight stubs instead of the real models.  ``TTSInferenceEngine``
    is replaced by a fake whose ``inference`` yields a single "final" result
    holding one ramp of ``44_100`` float32 samples.

    Any ``fish_speech`` modules already cached in ``sys.modules`` are removed
    for the duration of the test (and restored by ``monkeypatch`` afterwards),
    and the stub modules are evicted on teardown so they cannot leak into
    later tests.

    Yields:
        tuple: ``(engine, ref_wav, ref_lab)`` -- the loaded engine, the path
        of a placeholder reference WAV, and the path of the ``.lab``
        transcript written next to it.
    """
    import sys  # local import: only this fixture touches the module cache

    def _cached_fish_modules():
        """Return the names of all cached ``fish_speech`` (sub)modules."""
        return [
            name
            for name in list(sys.modules)
            if name == "fish_speech" or name.startswith("fish_speech.")
        ]

    fake_root = tmp_path / "fish-speech"
    pkg = fake_root / "fish_speech"
    pkg.mkdir(parents=True)
    (pkg / "__init__.py").write_text("")

    # Both layouts of ``inference_engine`` are provided (flat module and
    # package).  When both exist the import system picks the package.
    engine_stub = "class TTSInferenceEngine:\n    pass\n"
    (pkg / "inference_engine.py").write_text(engine_stub)
    (pkg / "inference_engine").mkdir()
    (pkg / "inference_engine" / "__init__.py").write_text(engine_stub)

    models = pkg / "models"
    models.mkdir()
    (models / "__init__.py").write_text("")

    text2semantic = models / "text2semantic"
    text2semantic.mkdir()
    (text2semantic / "__init__.py").write_text("")
    (text2semantic / "inference.py").write_text(
        "def launch_thread_safe_queue(*args, **kwargs):\n"
        "    return object()\n"
    )

    vqgan = models / "vqgan"
    vqgan.mkdir()
    (vqgan / "__init__.py").write_text("")
    # The stub source must be valid Python: the class body and the method
    # body each need their own indentation level.
    (vqgan / "inference.py").write_text(
        "def load_model(*args, **kwargs):\n"
        "    class FakeDecoder:\n"
        "        spec_transform = type('T', (), {'sample_rate': 44100})()\n"
        "        def encode(self, audios, lengths):\n"
        "            return [None]\n"
        "    return FakeDecoder()\n"
    )

    utils = pkg / "utils"
    utils.mkdir()
    (utils / "__init__.py").write_text("")
    (utils / "schema.py").write_text(
        "from pydantic import BaseModel\n"
        "from typing import List\n"
        "class ServeReferenceAudio(BaseModel):\n"
        "    audio: bytes\n"
        "    text: str\n"
        "class ServeTTSRequest(BaseModel):\n"
        "    text: str\n"
        "    references: List[ServeReferenceAudio] = []\n"
        "    seed: int | None = None\n"
        "    top_p: float = 0.7\n"
        "    temperature: float = 0.7\n"
        "    repetition_penalty: float = 1.2\n"
        "    max_new_tokens: int = 1024\n"
        "    chunk_length: int = 200\n"
        "    use_memory_cache: str = 'on'\n"
    )

    # Stub the heavy model functions used inside load().
    class FakeEngine:
        """Stand-in for ``TTSInferenceEngine`` that needs no model weights."""

        def __init__(self, *args, **kwargs):
            pass

        def inference(self, req):
            """Yield one "final" result carrying a linear ramp of samples."""
            sample_rate = 44_100
            audio = np.linspace(-0.9, 0.9, sample_rate, dtype=np.float32)
            yield type(
                "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)}
            )()

    # Evict modules cached by an earlier import (a real installation or a
    # previous test's stub tree); otherwise the files written above would be
    # ignored.  monkeypatch restores the evicted entries at teardown.
    for name in _cached_fish_modules():
        monkeypatch.delitem(sys.modules, name)

    # Point sys.path at fake root temporarily so imports succeed.
    monkeypatch.syspath_prepend(str(fake_root))

    try:
        from fish_speech import inference_engine as _ie

        monkeypatch.setattr(_ie, "TTSInferenceEngine", FakeEngine)

        ref_wav = tmp_path / "ref.wav"
        ref_wav.write_bytes(b"RIFF" + b"\x00" * 40)
        ref_lab = tmp_path / "ref.lab"
        ref_lab.write_text("Reference transcript from lab file.", encoding="utf-8")

        engine = FishSpeechEngine(
            checkpoint_path=tmp_path / "checkpoint",
            source_root=fake_root,
            device="cpu",
            precision=torch.float32,
            chunk_length=200,
        )
        engine.load()
        yield engine, ref_wav, ref_lab
    finally:
        # Remove the stub modules so later tests start from a clean cache.
        # This runs before monkeypatch's own teardown, which then restores
        # whatever was evicted above.
        for name in _cached_fish_modules():
            sys.modules.pop(name, None)
@pytest.mark.asyncio
async def test_synthesize_returns_audio(mocked_fish):
    """Synthesis returns 1-D float32 audio of ``engine.sample_rate`` samples."""
    engine, ref_wav, _ = mocked_fish
    request = {
        "text": "Hello world.",
        "ref_audio_path": ref_wav,
        "language": "en",
        "speed": 1.0,
        "emotion": "neutral",
    }

    waveform = await engine.synthesize(**request)

    # Shape: mono, and as long as the stub engine's single-result output.
    assert waveform.ndim == 1
    assert len(waveform) == engine.sample_rate
    # Sample format and a loose peak-amplitude bound.
    assert waveform.dtype == np.float32
    peak = np.abs(waveform).max()
    assert peak <= 1.1
@pytest.mark.asyncio
async def test_unsupported_language_raises(mocked_fish):
    """Requesting language ``"fr"`` raises ``ValueError`` ("not supported")."""
    engine, ref_wav, _ = mocked_fish
    french_request = {
        "text": "Bonjour.",
        "ref_audio_path": ref_wav,
        "language": "fr",
        "speed": 1.0,
        "emotion": "neutral",
    }

    with pytest.raises(ValueError, match="not supported"):
        await engine.synthesize(**french_request)
@pytest.mark.asyncio
async def test_missing_reference_raises(mocked_fish):
    """A reference path that does not exist raises ``FileNotFoundError``."""
    engine, _, _ = mocked_fish
    absent_reference = Path("/nonexistent/ref.wav")

    with pytest.raises(FileNotFoundError):
        await engine.synthesize(
            text="Hello.",
            ref_audio_path=absent_reference,
            language="en",
            speed=1.0,
            emotion="neutral",
        )
@pytest.mark.asyncio
async def test_ref_text_fallback_order(monkeypatch, mocked_fish):
    """Reference text resolves in a fixed priority order.

    The order checked is: explicit ``ref_text`` argument, then the ``.lab``
    file beside the reference audio, then ``settings.default_ref_text``, and
    finally the built-in placeholder sentence.
    """
    engine, ref_wav, ref_lab = mocked_fish
    seen = {}

    def _recording_inference(req):
        # Remember which reference transcript reached the inference engine.
        seen["text"] = req.references[0].text
        rate = 44_100
        silence = np.zeros(rate, dtype=np.float32)
        result_cls = type(
            "R", (), {"code": "final", "error": None, "audio": (rate, silence)}
        )
        yield result_cls()

    # Swap the spy in for the inference method of the already-loaded engine.
    engine._engine.inference = _recording_inference

    async def _reference_text_used(**extra):
        """Run one synthesis and return the reference text the spy recorded."""
        await engine.synthesize(
            text="Hi.",
            ref_audio_path=ref_wav,
            language="en",
            speed=1.0,
            emotion="neutral",
            **extra,
        )
        return seen["text"]

    # 1. An explicit ref_text takes priority over everything else.
    assert await _reference_text_used(ref_text="explicit") == "explicit"

    # 2. Without explicit text, the .lab file beside the reference is used.
    assert await _reference_text_used() == "Reference transcript from lab file."

    # 3. With the .lab file gone, settings.default_ref_text beats the placeholder.
    monkeypatch.setattr(settings, "default_ref_text", "settings default")
    ref_lab.unlink(missing_ok=True)
    assert await _reference_text_used() == "settings default"

    # 4. With nothing configured, the placeholder sentence is the last resort.
    monkeypatch.setattr(settings, "default_ref_text", None)
    assert (
        await _reference_text_used()
        == "Hello, this is a reference voice recording."
    )
@pytest.mark.asyncio
async def test_speed_resampling_changes_length(mocked_fish):
    """Output length shrinks for speed > 1 and grows for speed < 1.

    The stub engine always produces the same audio, so any difference in
    length between the calls comes from the ``speed`` argument alone.
    """
    engine, ref_wav, _ = mocked_fish

    async def _length_at(speed):
        """Synthesize the fixed prompt at ``speed`` and return the sample count."""
        audio = await engine.synthesize(
            text="Hello.",
            ref_audio_path=ref_wav,
            language="en",
            speed=speed,
            emotion="neutral",
        )
        return len(audio)

    # Named after the playback effect, not the resulting length: speed 1.5
    # is the sped-up (shorter) rendition, speed 0.8 the slowed-down (longer).
    sped_up = await _length_at(1.5)
    slowed_down = await _length_at(0.8)
    normal = await _length_at(1.0)

    # Speed 1.5 should produce fewer samples than original; 0.8 should produce more.
    assert sped_up < normal
    assert slowed_down > normal