"""Benchmark local TTS backends: Fish Speech 1.5 vs XTTS-v2.
Measures:
- cold load time
- per-sentence synthesis time
- real-time factor (RTF)
- Whisper ASR accuracy
Outputs WAV files and a JSON report in outputs/benchmark/.
"""
import asyncio
import json
import os
import sys
import time
import unicodedata
from pathlib import Path
import numpy as np
import torch
import torchaudio
from loguru import logger
# Ensure project source is importable when this script is run directly:
# ROOT is the directory two levels above this file, and ROOT/src is put at the
# front of the import path unless it is already listed.
ROOT = Path(__file__).resolve().parent.parent
if str(ROOT / "src") not in sys.path:
    # Use ``sys`` directly; ``os.sys`` only works because the ``os`` module
    # happens to import ``sys`` internally and is not a documented attribute.
    sys.path.insert(0, str(ROOT / "src"))
from voice_tts.audio.formats import float_to_wav_bytes # noqa: E402
from voice_tts.config import settings # noqa: E402
from voice_tts.tts.fish_speech_backend import FishSpeechEngine # noqa: E402
from voice_tts.tts.xtts_backend import XTTSv2Engine # noqa: E402
# Benchmark corpus: (language code, text) pairs, alternating English and Russian.
# benchmark_backend() synthesises every entry in this order; the language code
# is forwarded to the engine and is also used as the output WAV filename prefix.
SENTENCES = [
("en", "Hello, this is a short English sentence for the benchmark."),
(
"ru",
"Добрый вечер, меня зовут Евгений. Это тестовое предложение для проверки качества синтеза.",
),
(
"en",
"The quick brown fox jumps over the lazy dog, testing every letter of the alphabet.",
),
(
"ru",
"Наша цель — сделать речь естественной и понятной на английском и русском языках.",
),
]
def _load_whisper(model_name: str = "large-v3", device: str = "cuda"):
    """Create the faster-whisper model used to score synthesised audio.

    Args:
        model_name: Model size/name handed to ``WhisperModel``.
        device: Device string. Anything starting with ``"cuda"`` selects the
            GPU with ``float16`` weights; everything else runs on CPU with
            ``int8`` weights. A torch-style ordinal suffix (``"cuda:1"``) is
            translated into ``device_index`` because ``WhisperModel`` takes the
            device name and the device index as separate arguments.

    Returns:
        A ready-to-use ``faster_whisper.WhisperModel``.
    """
    # Imported lazily so the module can be imported without faster-whisper.
    from faster_whisper import WhisperModel

    download_root = ROOT / "models" / "faster-whisper"
    download_root.mkdir(parents=True, exist_ok=True)
    # NOTE(review): this existence check looks at <download_root>/<model_name>;
    # confirm that matches the on-disk layout faster-whisper actually uses,
    # otherwise the message below is logged on every run.
    if not (download_root / model_name).exists():
        logger.info("Downloading faster-whisper {} ...", model_name)

    if not device.startswith("cuda"):
        return WhisperModel(
            model_name,
            device="cpu",
            compute_type="int8",
            download_root=str(download_root),
        )

    # "cuda" -> index 0, "cuda:2" -> index 2; a malformed suffix falls back to 0.
    _, _, ordinal = device.partition(":")
    return WhisperModel(
        model_name,
        device="cuda",
        device_index=int(ordinal) if ordinal.isdigit() else 0,
        compute_type="float16",
        download_root=str(download_root),
    )
def _transcribe(model, wav_path: Path) -> str:
    """Transcribe ``wav_path`` with ``model`` and return one flat string.

    ``model.transcribe`` is called with the path as a string and no language
    hint (``language=None``). It must return a ``(segments, info)`` pair; the
    stripped ``text`` of every segment is joined with single spaces.
    """
    segments, _info = model.transcribe(str(wav_path), language=None)
    pieces = []
    for segment in segments:
        pieces.append(segment.text.strip())
    transcript = " ".join(pieces)
    return transcript.strip()
def _save_wav(audio: np.ndarray, sr: int, path: Path) -> None:
    """Encode ``audio`` at sample rate ``sr`` as WAV and write it to ``path``.

    The parent directory is created first if it does not exist yet; encoding
    is delegated to ``float_to_wav_bytes``.
    """
    target_dir = path.parent
    target_dir.mkdir(parents=True, exist_ok=True)
    payload = float_to_wav_bytes(audio, sr)
    path.write_bytes(payload)
def _normalise(text: str) -> str:
    """Canonicalise text for word-level comparison.

    The text is NFKC-normalised and lower-cased, every Unicode punctuation
    character (general category ``P*``) is replaced by a space, and runs of
    whitespace are collapsed to single spaces.

    Stripping all punctuation -- not just ``,``, ``.`` and ``—`` -- matters
    because ASR output may contain marks such as ``!``, ``?``, ``:``, quotes
    or hyphens that would otherwise stay glued to a word and be counted as a
    word error.

    Args:
        text: Arbitrary reference or hypothesis text.

    Returns:
        Space-separated lower-case words without punctuation; ``""`` when the
        input contains no word characters.
    """
    folded = unicodedata.normalize("NFKC", text).lower()
    without_punctuation = "".join(
        " " if unicodedata.category(char).startswith("P") else char
        for char in folded
    )
    return " ".join(without_punctuation.split())
def _wer(ref: str, hyp: str) -> float:
    """Word error rate of ``hyp`` against ``ref``.

    Both strings go through ``_normalise`` and are split into words. The
    result is the word-level Levenshtein distance divided by the number of
    reference words, so it can exceed 1.0. With an empty reference the result
    is 0.0 for an empty hypothesis and 1.0 otherwise.
    """
    ref_words = _normalise(ref).split()
    hyp_words = _normalise(hyp).split()
    if len(ref_words) == 0:
        return 1.0 if hyp_words else 0.0

    # Rolling single-row dynamic programme: ``row[k]`` holds the edit distance
    # between the reference prefix processed so far and the first k
    # hypothesis words.
    row = list(range(len(hyp_words) + 1))
    for row_no, ref_word in enumerate(ref_words, start=1):
        next_row = [row_no]
        for col, hyp_word in enumerate(hyp_words, start=1):
            insertion = next_row[col - 1] + 1
            deletion = row[col] + 1
            substitution = row[col - 1] + (0 if ref_word == hyp_word else 1)
            next_row.append(min(insertion, deletion, substitution))
        row = next_row
    return row[-1] / len(ref_words)
async def _synth(engine, text: str, lang: str, ref: Path, speed: float = 1.0):
    """Synthesise ``text`` with ``engine`` and return whatever it produces.

    ``engine.synthesize`` is awaited on a private event loop inside a worker
    thread (``asyncio.to_thread`` + ``asyncio.run``) to mimic how the server
    runs synthesis without blocking the caller's loop. The emotion is always
    ``"neutral"``; Fish Speech engines additionally receive
    ``settings.default_ref_text`` as ``ref_text``.
    """

    def _blocking_call():
        call_args = {
            "text": text,
            "ref_audio_path": ref,
            "language": lang,
            "speed": speed,
            "emotion": "neutral",
        }
        # Only the Fish Speech backend is given the reference transcript.
        if isinstance(engine, FishSpeechEngine):
            call_args["ref_text"] = settings.default_ref_text
        coroutine = engine.synthesize(**call_args)
        return asyncio.run(coroutine)

    result = await asyncio.to_thread(_blocking_call)
    return result
async def benchmark_backend(name: str, factory, ref: Path, output_dir: Path):
    """Benchmark one TTS backend over every entry in ``SENTENCES``.

    Args:
        name: Backend label; used in the report, in log lines and as the
            sub-directory of ``output_dir`` that receives the WAV files.
        factory: Zero-argument callable returning an engine with ``load()``,
            ``synthesize(...)`` and ``sample_rate``.
        ref: Reference voice audio passed to every synthesis call.
        output_dir: Root directory for the generated WAV files.

    Returns:
        A dict with ``backend``, ``load_seconds`` and a ``sentences`` list
        holding duration, synthesis time, RTF, Whisper transcript, WER and
        the WAV path (relative to ``output_dir``) for each sentence.
    """
    logger.info("Benchmarking {} ...", name)
    report = {"backend": name, "sentences": [], "load_seconds": None}

    # Cold load: engine construction plus load() are timed together.
    load_started = time.perf_counter()
    engine = factory()
    engine.load()
    report["load_seconds"] = round(time.perf_counter() - load_started, 3)

    # Optional warm-up to make per-sentence timing more representative.
    if name == "xtts_v2":
        warmup_text, warmup_lang = "One two three.", "en"
    else:
        warmup_text, warmup_lang = "Раз, два, три.", "ru"
    await _synth(engine, warmup_text, warmup_lang, ref, 1.0)

    whisper_model = _load_whisper(device=settings.device)
    results = report["sentences"]

    for index, (lang, text) in enumerate(SENTENCES):
        started = time.perf_counter()
        audio = await _synth(engine, text, lang, ref, settings.tts_speed)
        elapsed = time.perf_counter() - started

        # Guard both divisions so a missing sample rate or empty audio
        # yields 0.0 instead of raising.
        if engine.sample_rate:
            duration = len(audio) / engine.sample_rate
        else:
            duration = 0.0
        if duration:
            rtf = elapsed / duration
        else:
            rtf = 0.0

        wav_path = output_dir / name / f"{lang}_{index}.wav"
        _save_wav(audio, engine.sample_rate, wav_path)

        hyp = _transcribe(whisper_model, wav_path)
        wer = _wer(text, hyp)

        rtf_r = round(rtf, 3)
        duration_r = round(duration, 3)
        elapsed_r = round(elapsed, 3)
        wer_r = round(wer, 3)

        entry = {
            "language": lang,
            "text": text,
            "duration_seconds": duration_r,
            "synth_seconds": elapsed_r,
            "rtf": rtf_r,
            "whisper": hyp,
            "wer": wer_r,
            "wav": str(wav_path.relative_to(output_dir)),
        }
        results.append(entry)
        logger.info(
            "[{} {}] RTF={} dur={}s synth={}s WER={}",
            name,
            lang,
            rtf_r,
            duration_r,
            elapsed_r,
            wer_r,
        )

    return report
def _fish_factory():
    """Construct a ``FishSpeechEngine`` from the project settings.

    Only construction happens here; ``benchmark_backend`` calls ``load()`` on
    the returned engine itself so that loading can be timed.
    """
    engine_options = {
        "checkpoint_path": settings.tts_model_path,
        "source_root": settings.tts_vocab_path,
        "device": settings.device,
        "compile": settings.fish_compile,
        "use_memory_cache": settings.fish_use_memory_cache,
        "chunk_length": settings.fish_chunk_length,
    }
    return FishSpeechEngine(**engine_options)
def _xtts_factory():
    """Construct an ``XTTSv2Engine`` from the project settings.

    Only ``settings.tts_model_name`` and ``settings.device`` are forwarded.
    """
    engine_options = {
        "model_name": settings.tts_model_name,
        "device": settings.device,
    }
    return XTTSv2Engine(**engine_options)
async def main():
    """Run the selected benchmarks, write ``summary.json`` and print a table.

    Backends are chosen with the ``BENCHMARK_BACKENDS`` environment variable,
    a comma- and/or whitespace-separated list of ``fish_speech`` and
    ``xtts_v2`` (default: both). Selected backends always run in that fixed
    order. Output goes to ``ROOT/outputs/benchmark``.

    Raises:
        FileNotFoundError: If ``settings.default_voice_ref`` is unset or does
            not point to an existing file.
    """
    output_dir = ROOT / "outputs" / "benchmark"
    output_dir.mkdir(parents=True, exist_ok=True)

    ref = Path(settings.default_voice_ref) if settings.default_voice_ref else None
    if ref is None or not ref.exists():
        raise FileNotFoundError("Set DEFAULT_VOICE_REF to an existing reference WAV.")

    # Tokenise the selection instead of substring-matching the raw string, so
    # that e.g. "xtts_v2_fast" does not accidentally enable "xtts_v2".
    raw_selection = os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2")
    requested = set(raw_selection.replace(",", " ").split())

    available = (
        ("fish_speech", _fish_factory),
        ("xtts_v2", _xtts_factory),
    )
    unknown = requested - {backend_name for backend_name, _ in available}
    if unknown:
        logger.warning(
            "Ignoring unknown BENCHMARK_BACKENDS entries: {}",
            ", ".join(sorted(unknown)),
        )

    reports = []
    for backend_name, factory in available:
        if backend_name in requested:
            reports.append(await benchmark_backend(backend_name, factory, ref, output_dir))

    summary_path = output_dir / "summary.json"
    # ensure_ascii=False keeps the Cyrillic text readable, so the file must be
    # written as UTF-8 explicitly; the platform default encoding may be unable
    # to represent it.
    summary_path.write_text(
        json.dumps(reports, ensure_ascii=False, indent=2),
        encoding="utf-8",
    )
    logger.info("Report saved to {}", summary_path)

    # Print a quick Markdown table.
    print("\n## Summary")
    print("| Backend | Lang | Dur (s) | Synth (s) | RTF | WER |")
    print("|---------|------|---------|-----------|-----|-----|")
    for report in reports:
        for s in report["sentences"]:
            print(
                f"| {report['backend']} | {s['language']} | "
                f"{s['duration_seconds']} | {s['synth_seconds']} | "
                f"{s['rtf']} | {s['wer']} |"
            )


if __name__ == "__main__":
    asyncio.run(main())