# voice/src/voice_tts/tts/s2_backend.py
"""Fish Audio S2-Pro backend — HTTP client to the local S2 API server."""

import asyncio
import io
import threading
from pathlib import Path

import numpy as np
import soundfile as sf
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine

# ``requests`` is treated as an optional dependency: importing this module
# never fails because of it. Instead the outcome is recorded in
# ``S2_AVAILABLE`` and ``S2Engine.__init__`` raises a RuntimeError when the
# package is missing.
try:
    import requests

    S2_AVAILABLE = True
except ImportError:
    S2_AVAILABLE = False
    # Keep the name bound so later references do not raise NameError.
    requests = None

# Class-level default for ``S2Engine.sample_rate``. Instances overwrite it
# from ``settings.tts_sample_rate`` and again from each decoded response.
_S2_SAMPLE_RATE = 44_100


@_register_backend("s2")
class S2Engine(TTSEngine):
    """Fish Audio S2-Pro backend that delegates to a local S2 API server.

    The S2 API server must be running separately on ``api_url``.
    The reference audio is uploaded once and reused via ``reference_id``.

    All HTTP traffic goes through the blocking ``requests`` library, so the
    async entry points run it in a worker thread (``asyncio.to_thread``)
    instead of stalling the event loop for the duration of a request.
    """

    # Class-level default; replaced per instance in ``__init__`` and again by
    # ``synthesize`` with the rate reported by the decoded audio.
    sample_rate: int = _S2_SAMPLE_RATE

    def __init__(self):
        """Read the server URL and sample rate from ``settings``.

        Raises:
            RuntimeError: If the ``requests`` package is not installed.
        """
        super().__init__()
        if not S2_AVAILABLE:
            raise RuntimeError("S2 backend requires the 'requests' package")

        self.api_url = settings.s2_api_url.rstrip("/")
        self.sample_rate = settings.tts_sample_rate
        self._ref_uploaded = False
        # Synthesis runs in worker threads, so overlapping calls could both
        # see ``_ref_uploaded == False``; the lock keeps the upload one-shot.
        self._upload_lock = threading.Lock()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------

    def _upload_reference(self, ref_audio_path: Path) -> None:
        """Upload the reference audio to the S2 server (blocking).

        If a reference with id ``default`` already exists we try to delete it
        first via HTTP, then fall back to removing its directory from the
        filesystem (the S2 server stores references in ``references/<id>/``
        relative to its working directory).

        Args:
            ref_audio_path: Audio file sent as the ``default`` reference. The
                accompanying transcript is ``settings.default_ref_text``
                (empty string when unset).

        Raises:
            RuntimeError: If the server answers the upload with a status
                other than 200.
        """
        ref_text = settings.default_ref_text or ""

        # Try HTTP delete first (works with most setups). A 404 means there
        # was nothing to delete, which is just as good as a successful delete.
        deleted = False
        try:
            delete_resp = requests.delete(
                f"{self.api_url}/v1/references/delete",
                json={"reference_id": "default"},
                timeout=10,
            )
        except Exception as exc:
            # Deliberately best-effort: the filesystem fallback below still
            # runs, but the failure is logged rather than silently dropped.
            logger.debug("S2 reference delete request failed: {}", exc)
        else:
            deleted = delete_resp.status_code in (200, 404)

        # If HTTP delete didn't work, try removing the directory directly.
        if not deleted:
            candidate = Path("models/fish-speech/references/default")
            if candidate.exists():
                import shutil

                shutil.rmtree(candidate)
                logger.info("Removed stale reference directory: {}", candidate)

        with open(ref_audio_path, "rb") as fh:
            resp = requests.post(
                f"{self.api_url}/v1/references/add",
                data={"id": "default", "text": ref_text},
                files={"audio": ("ref.wav", fh, "audio/wav")},
                timeout=30,
            )
        if resp.status_code != 200:
            raise RuntimeError(
                f"S2 reference upload failed: {resp.status_code} {resp.text}"
            )
        logger.info("Reference 'default' uploaded to S2 server")
        self._ref_uploaded = True

    def _request_speech(
        self, text: str, ref_audio_path: Path | None
    ) -> tuple[np.ndarray, int]:
        """Blocking part of ``synthesize``: upload if needed, request, decode.

        Returns:
            The decoded audio as ``float32`` and the sample rate reported by
            the decoder.

        Raises:
            RuntimeError: If the reference upload or the TTS request fails.
        """
        if ref_audio_path is not None:
            with self._upload_lock:
                # Re-checked under the lock so only the first caller uploads.
                if not self._ref_uploaded:
                    self._upload_reference(ref_audio_path)

        payload = {
            "text": text,
            "reference_id": "default",
            "format": "wav",
            "chunk_length": 300,
        }

        resp = requests.post(
            f"{self.api_url}/v1/tts",
            json=payload,
            timeout=120,
        )
        if resp.status_code != 200:
            raise RuntimeError(f"S2 TTS request failed: {resp.status_code} {resp.text}")

        audio, sr = sf.read(io.BytesIO(resp.content))
        return audio.astype(np.float32), int(sr)

    # ------------------------------------------------------------------
    # TTSEngine interface
    # ------------------------------------------------------------------

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize ``text`` through the S2 server and return the audio.

        The reference audio is uploaded on the first call that provides
        ``ref_audio_path``; later calls reuse the ``default`` reference.
        ``self.sample_rate`` is updated to the rate of the returned audio.

        Args:
            text: Text to speak.
            ref_audio_path: Reference audio file, or ``None`` to skip upload.
            language: Accepted for interface compatibility; not used here.
            speed: Accepted for interface compatibility; not used here.
            emotion: Accepted for interface compatibility; not used here.
            ref_text: Accepted for interface compatibility; not used here.

        Returns:
            The decoded audio samples as a ``float32`` array.

        Raises:
            RuntimeError: If the reference upload or the TTS request fails.
        """
        # ``requests`` blocks for up to 120 s; keep that off the event loop.
        audio, sr = await asyncio.to_thread(
            self._request_speech, text, ref_audio_path
        )
        self.sample_rate = sr
        return audio

    async def warm_up(self) -> None:
        """Check that the S2 API server is alive.

        Raises:
            RuntimeError: If the server cannot be reached, does not answer in
                time, or answers the health check with a non-200 status.
        """
        try:
            resp = await asyncio.to_thread(
                requests.get, f"{self.api_url}/v1/health", timeout=5
            )
        except (requests.ConnectionError, requests.Timeout) as exc:
            # A read timeout is not a ConnectionError, so it is listed
            # explicitly to give callers one exception type to handle.
            raise RuntimeError(
                f"S2 API server not reachable at {self.api_url}. "
                "Make sure the S2 server is running."
            ) from exc

        if resp.status_code != 200:
            raise RuntimeError(f"S2 server health check failed: {resp.status_code}")
        logger.info("S2 API server is healthy at {}", self.api_url)