# voice/scripts/benchmark_backends.py
"""Benchmark local TTS backends: Fish Speech 1.5 vs XTTS-v2.

Measures:
  - cold load time
  - per-sentence synthesis time
  - real-time factor (RTF)
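  - ASR intelligibility: word error rate (WER) of a faster-whisper transcript
    of each generated WAV against the input sentence

Outputs WAV files and a JSON report (summary.json) in outputs/benchmark/.
The reference voice comes from DEFAULT_VOICE_REF; BENCHMARK_BACKENDS selects
which backends run (default: both).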
"""

import asyncio
import gc
import json
import os
import sys
import time
import unicodedata
from pathlib import Path

import numpy as np
import torch
import torchaudio
from loguru import logger

# Make the project's `src/` package importable when this script is run
# directly (e.g. `python scripts/benchmark_backends.py`). Uses the public
# `sys.path` rather than the undocumented `os.sys` alias.
ROOT = Path(__file__).resolve().parent.parent
_SRC_DIR = str(ROOT / "src")
if _SRC_DIR not in sys.path:
    sys.path.insert(0, _SRC_DIR)

from voice_tts.audio.formats import float_to_wav_bytes  # noqa: E402
from voice_tts.config import settings  # noqa: E402
from voice_tts.tts.fish_speech_backend import FishSpeechEngine  # noqa: E402
from voice_tts.tts.xtts_backend import XTTSv2Engine  # noqa: E402


# Benchmark corpus of (language code, text) pairs. Each text is synthesised by
# every backend and also serves as the reference for the Whisper WER score.
SENTENCES = [
    ("en", "Hello, this is a short English sentence for the benchmark."),
    (
        "ru",
        "Добрый вечер, меня зовут Евгений. "
        "Это тестовое предложение для проверки качества синтеза.",
    ),
    (
        "en",
        "The quick brown fox jumps over the lazy dog, "
        "testing every letter of the alphabet.",
    ),
    (
        "ru",
        "Наша цель — сделать речь естественной "
        "и понятной на английском и русском языках.",
    ),
]


def _load_whisper(model_name: str = "large-v3", device: str = "cuda"):
    """Construct the faster-whisper model used to transcribe synthesised audio.

    Args:
        model_name: faster-whisper model name / size (e.g. ``"large-v3"``).
        device: Torch-style device string. ``"cuda"`` or ``"cuda:N"`` runs on
            GPU ``N`` (default 0) with float16 weights; any other value falls
            back to CPU with int8 weights.

    Returns:
        A ``faster_whisper.WhisperModel`` instance.

    Weights are cached under ``ROOT/models/faster-whisper``; faster-whisper
    downloads them there on first use.
    """
    from faster_whisper import WhisperModel

    download_root = ROOT / "models" / "faster-whisper"
    download_root.mkdir(parents=True, exist_ok=True)

    # CTranslate2 (the faster-whisper backend) accepts only "cpu"/"cuda"/"auto"
    # as the device and takes the GPU ordinal separately, so split "cuda:N".
    use_cuda = device.startswith("cuda")
    _, _, index = device.partition(":")
    device_index = int(index) if use_cuda and index.isdigit() else 0

    # The old "model_path.exists()" check looked at download_root/<model_name>,
    # which faster-whisper never creates, so it always claimed a download.
    logger.info(
        "Loading faster-whisper {} on {} (weights are downloaded on first use) ...",
        model_name,
        f"cuda:{device_index}" if use_cuda else "cpu",
    )
    return WhisperModel(
        model_name,
        device="cuda" if use_cuda else "cpu",
        device_index=device_index,
        compute_type="float16" if use_cuda else "int8",
        download_root=str(download_root),
    )


def _transcribe(model, wav_path: Path) -> str:
    segments, _ = model.transcribe(str(wav_path), language=None)
    return " ".join(s.text.strip() for s in segments).strip()


def _save_wav(audio: np.ndarray, sr: int, path: Path) -> None:
    """Encode *audio* at sample rate *sr* with ``float_to_wav_bytes`` and write it to *path*.

    Missing parent directories are created before encoding.
    """
    destination_dir = path.parent
    destination_dir.mkdir(exist_ok=True, parents=True)
    encoded = float_to_wav_bytes(audio, sr)
    path.write_bytes(encoded)


def _normalise(text: str) -> str:
    return " ".join(text.lower().replace(",", " ").replace(".", " ").replace("—", " ").split())


def _wer(ref: str, hyp: str) -> float:
    """Return the word error rate of *hyp* measured against *ref*.

    Both strings are passed through ``_normalise`` and split on whitespace.
    The score is the word-level Levenshtein distance (substitutions +
    insertions + deletions) divided by the number of reference words, so it
    can exceed 1.0. With an empty reference the result is 0.0 for an empty
    hypothesis and 1.0 otherwise.
    """
    ref_words = _normalise(ref).split()
    hyp_words = _normalise(hyp).split()
    if not ref_words:
        return 1.0 if hyp_words else 0.0

    # Single-row edit-distance table updated in place: before row[j] is
    # overwritten it holds the previous row's value, and `diagonal` carries the
    # previous row's value at j - 1.
    row = list(range(len(hyp_words) + 1))
    for i, ref_word in enumerate(ref_words, start=1):
        diagonal, row[0] = row[0], i
        for j, hyp_word in enumerate(hyp_words, start=1):
            substitution = diagonal + (ref_word != hyp_word)
            diagonal = row[j]
            row[j] = min(row[j - 1] + 1, row[j] + 1, substitution)
    return row[-1] / len(ref_words)


async def _synth(engine, text: str, lang: str, ref: Path, speed: float = 1.0):
    """Synthesise *text* with *engine* on a worker thread and return the result.

    ``engine.synthesize`` is awaited as a coroutine, but the work inside it is
    heavy and blocking (CUDA inference). To mimic the server, it is driven by a
    fresh event loop inside ``asyncio.to_thread``, which leaves the caller's
    loop free. Fish Speech engines also receive ``settings.default_ref_text``
    as ``ref_text`` (presumably the transcript of *ref*).
    """

    def _blocking_call():
        call_kwargs = {
            "text": text,
            "ref_audio_path": ref,
            "language": lang,
            "speed": speed,
            "emotion": "neutral",
        }
        if isinstance(engine, FishSpeechEngine):
            call_kwargs["ref_text"] = settings.default_ref_text
        return asyncio.run(engine.synthesize(**call_kwargs))

    result = await asyncio.to_thread(_blocking_call)
    return result


async def benchmark_backend(name: str, factory, ref: Path, output_dir: Path):
    """Load one TTS backend, synthesise every ``SENTENCES`` entry and score it.

    Args:
        name: Backend label. Used as the report's ``"backend"`` value and as
            the sub-directory of *output_dir* that receives the WAV files.
        factory: Zero-argument callable that returns an engine exposing
            ``load()``, ``sample_rate`` and an async ``synthesize(...)``.
        ref: Reference voice WAV passed to every synthesis call.
        output_dir: Root directory of the benchmark outputs.

    Returns:
        A dict with ``backend``, ``load_seconds`` (wall time of construction
        plus ``load()``) and a ``sentences`` list. Each sentence entry holds
        the audio duration, synthesis time, real-time factor, Whisper
        transcript, WER and the WAV path relative to *output_dir*.
    """
    logger.info("Benchmarking {} ...", name)
    report = {"backend": name, "sentences": [], "load_seconds": None}
    whisper_model = None

    t0 = time.perf_counter()
    engine = factory()
    try:
        engine.load()
        report["load_seconds"] = round(time.perf_counter() - t0, 3)

        # Warm up every benchmarked language, identically for all backends, so
        # one-off first-call costs are not charged to the first measured
        # sentence of a language and the backends are compared on equal terms.
        warmup_texts = {"en": "One two three.", "ru": "Раз, два, три."}
        first_text_by_lang = {}
        for lang, text in SENTENCES:
            first_text_by_lang.setdefault(lang, text)
        for lang, fallback_text in first_text_by_lang.items():
            await _synth(engine, warmup_texts.get(lang, fallback_text), lang, ref, 1.0)

        whisper_model = _load_whisper(device=settings.device)

        for lang, text in SENTENCES:
            t0 = time.perf_counter()
            audio = await _synth(engine, text, lang, ref, settings.tts_speed)
            elapsed = time.perf_counter() - t0

            sample_rate = engine.sample_rate
            # assumes `audio` is a 1-D sample array — TODO confirm per backend
            duration = len(audio) / sample_rate if sample_rate else 0.0
            if not duration:
                # RTF is undefined here. 0.0 is kept for report compatibility,
                # so surface the anomaly in the log instead.
                logger.warning(
                    "[{} {}] empty audio or unknown sample rate; RTF reported as 0",
                    name,
                    lang,
                )
            rtf = elapsed / duration if duration else 0.0

            wav_path = output_dir / name / f"{lang}_{len(report['sentences'])}.wav"
            _save_wav(audio, sample_rate, wav_path)

            hyp = _transcribe(whisper_model, wav_path)
            wer = _wer(text, hyp)

            report["sentences"].append(
                {
                    "language": lang,
                    "text": text,
                    "duration_seconds": round(duration, 3),
                    "synth_seconds": round(elapsed, 3),
                    "rtf": round(rtf, 3),
                    "whisper": hyp,
                    "wer": round(wer, 3),
                    "wav": str(wav_path.relative_to(output_dir)),
                }
            )
            logger.info(
                "[{} {}] RTF={} dur={}s synth={}s WER={}",
                name,
                lang,
                round(rtf, 3),
                round(duration, 3),
                round(elapsed, 3),
                round(wer, 3),
            )
    finally:
        # Drop this backend's TTS engine and Whisper model before the caller
        # loads the next backend. This makes leftover GPU allocations less
        # likely to distort its load time or cause out-of-memory errors.
        # torch.cuda.empty_cache() does nothing if CUDA was never initialised.
        del engine, whisper_model
        gc.collect()
        torch.cuda.empty_cache()

    return report


def _fish_factory():
    """Build a FishSpeechEngine configured from ``settings``.

    The caller is responsible for calling ``load()`` on the returned engine.
    """
    engine_kwargs = {
        "checkpoint_path": settings.tts_model_path,
        # NOTE(review): source_root is fed from `tts_vocab_path`; confirm that
        # setting really points at the Fish Speech source tree.
        "source_root": settings.tts_vocab_path,
        "device": settings.device,
        "compile": settings.fish_compile,
        "use_memory_cache": settings.fish_use_memory_cache,
        "chunk_length": settings.fish_chunk_length,
    }
    return FishSpeechEngine(**engine_kwargs)


def _xtts_factory():
    """Build an XTTSv2Engine from ``settings``; the caller invokes ``load()``."""
    engine_kwargs = {
        "model_name": settings.tts_model_name,
        "device": settings.device,
    }
    return XTTSv2Engine(**engine_kwargs)


async def main():
    """Benchmark the selected TTS backends, save ``summary.json`` and print a table.

    Backends are chosen by the comma-separated ``BENCHMARK_BACKENDS``
    environment variable (default ``"fish_speech,xtts_v2"``). Names are
    matched exactly after trimming whitespace, and selected backends always
    run in the order Fish Speech, then XTTS-v2.

    Raises:
        FileNotFoundError: if ``settings.default_voice_ref`` is unset or does
            not point at an existing file.
    """
    output_dir = ROOT / "outputs" / "benchmark"
    output_dir.mkdir(parents=True, exist_ok=True)

    ref = Path(settings.default_voice_ref) if settings.default_voice_ref else None
    if not ref or not ref.exists():
        raise FileNotFoundError("Set DEFAULT_VOICE_REF to an existing reference WAV.")

    factories = {"fish_speech": _fish_factory, "xtts_v2": _xtts_factory}
    # Parse exact names. A substring test on the raw string would also match
    # entries such as "not_fish_speech" and run the wrong backend.
    raw_selection = os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2")
    requested = {item.strip() for item in raw_selection.split(",") if item.strip()}
    unknown = requested.difference(factories)
    if unknown:
        logger.warning("Ignoring unknown BENCHMARK_BACKENDS entries: {}", sorted(unknown))

    reports = []
    for name, factory in factories.items():
        if name in requested:
            reports.append(await benchmark_backend(name, factory, ref, output_dir))
    if not reports:
        logger.warning("No backends selected; writing an empty report.")

    summary_path = output_dir / "summary.json"
    # Explicit UTF-8: the report contains Cyrillic text (ensure_ascii=False),
    # which a non-UTF-8 platform default encoding cannot represent.
    summary_path.write_text(
        json.dumps(reports, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    logger.info("Report saved to {}", summary_path)

    # Print a quick Markdown table.
    print("\n## Summary")
    print("| Backend | Lang | Dur (s) | Synth (s) | RTF | WER |")
    print("|---------|------|---------|-----------|-----|-----|")
    for report in reports:
        for s in report["sentences"]:
            print(
                f"| {report['backend']} | {s['language']} | "
                f"{s['duration_seconds']} | {s['synth_seconds']} | "
                f"{s['rtf']} | {s['wer']} |"
            )


if __name__ == "__main__":
    # Script entry point: run the async benchmark driver on a new event loop.
    asyncio.run(main())