"""Fish Audio S2-Pro backend — HTTP client to the local S2 API server."""
import asyncio
import io
import shutil
from pathlib import Path

import numpy as np
import soundfile as sf
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine
# ``requests`` is an optional dependency: record whether it could be imported
# so the engine can fail with a clear message at construction time instead of
# breaking at module import.
try:
    import requests
except ImportError:
    requests = None
    S2_AVAILABLE = False
else:
    S2_AVAILABLE = True

# Class-level default sample rate (Hz) for the S2 engine.
_S2_SAMPLE_RATE = 44_100
@_register_backend("s2")
class S2Engine(TTSEngine):
    """Fish Audio S2-Pro backend that delegates to a local S2 API server.

    The S2 API server must be running separately on ``api_url``.
    The reference audio is uploaded once (under the fixed id ``default``)
    and reused via ``reference_id`` for every subsequent request.

    All HTTP calls are made with the blocking ``requests`` library, so the
    async entry points run them in a worker thread to keep the event loop
    responsive while the server is synthesising.
    """

    sample_rate: int = _S2_SAMPLE_RATE

    def __init__(self):
        """Read the server URL and sample rate from ``settings``.

        Raises:
            RuntimeError: if the ``requests`` package is not installed.
        """
        super().__init__()
        if not S2_AVAILABLE:
            raise RuntimeError("S2 backend requires the 'requests' package")
        self.api_url = settings.s2_api_url.rstrip("/")
        self.sample_rate = settings.tts_sample_rate
        self._ref_uploaded = False
        # Serialises the one-time reference upload between concurrent
        # ``synthesize`` calls (the upload runs in a worker thread, so the
        # coroutine yields while it is in flight).
        self._upload_lock = asyncio.Lock()

    # ------------------------------------------------------------------
    # Internal helpers
    # ------------------------------------------------------------------
    def _upload_reference(
        self,
        ref_audio_path: Path,
        ref_text: str | None = None,
    ) -> None:
        """Upload the reference audio to the S2 server (blocking).

        If a reference with id ``default`` already exists we try to delete it
        first via HTTP, then fall back to removing its directory from the
        filesystem (the S2 server stores references in ``references/<id>/``
        relative to its working directory).

        Args:
            ref_audio_path: Audio file to upload as the ``default`` reference.
            ref_text: Transcript of the reference audio. When empty or
                ``None``, ``settings.default_ref_text`` is used instead.

        Raises:
            RuntimeError: if the server answers the upload with a non-200
                status.
        """
        transcript = ref_text or settings.default_ref_text or ""

        # Try HTTP delete first (works with most setups). This is best
        # effort: a failure here only means we fall back to the filesystem.
        deleted = False
        try:
            delete_resp = requests.delete(
                f"{self.api_url}/v1/references/delete",
                json={"reference_id": "default"},
                timeout=10,
            )
        except requests.RequestException as exc:
            logger.debug("S2 reference delete over HTTP failed: {}", exc)
        else:
            # 404 means there was nothing to delete, which is just as good.
            deleted = delete_resp.status_code in (200, 404)

        # If HTTP delete didn't work, try removing the directory directly.
        if not deleted:
            candidate = Path("models/fish-speech/references/default")
            if candidate.exists():
                shutil.rmtree(candidate)
                logger.info("Removed stale reference directory: {}", candidate)

        with open(ref_audio_path, "rb") as fh:
            resp = requests.post(
                f"{self.api_url}/v1/references/add",
                data={"id": "default", "text": transcript},
                files={"audio": ("ref.wav", fh, "audio/wav")},
                timeout=30,
            )

        if resp.status_code != 200:
            raise RuntimeError(
                f"S2 reference upload failed: {resp.status_code} {resp.text}"
            )
        logger.info("Reference 'default' uploaded to S2 server")
        self._ref_uploaded = True

    # ------------------------------------------------------------------
    # TTSEngine interface
    # ------------------------------------------------------------------
    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesise ``text`` through the S2 server and return the samples.

        The first call that carries a ``ref_audio_path`` uploads it as the
        ``default`` reference; later calls reuse that upload. ``language``,
        ``speed`` and ``emotion`` are accepted for interface compatibility
        but are not sent to the server.

        Args:
            text: Text to speak.
            ref_audio_path: Reference voice sample, or ``None`` to rely on
                whatever ``default`` reference the server already has.
            language: Unused by this backend.
            speed: Unused by this backend.
            emotion: Unused by this backend.
            ref_text: Transcript of the reference audio, forwarded with the
                upload when given.

        Returns:
            The decoded audio as a ``float32`` array. ``self.sample_rate`` is
            updated to the rate reported by the returned WAV data.

        Raises:
            RuntimeError: if the reference upload or the TTS request returns
                a non-200 status.
        """
        if ref_audio_path is not None and not self._ref_uploaded:
            async with self._upload_lock:
                # Re-check: another task may have finished the upload while
                # we were waiting for the lock.
                if not self._ref_uploaded:
                    await asyncio.to_thread(
                        self._upload_reference, ref_audio_path, ref_text
                    )

        payload = {
            "text": text,
            "reference_id": "default",
            "format": "wav",
            "chunk_length": 300,
        }
        # ``requests`` is blocking; run it off the event loop so other
        # coroutines keep running during the (potentially long) synthesis.
        resp = await asyncio.to_thread(
            requests.post,
            f"{self.api_url}/v1/tts",
            json=payload,
            timeout=120,
        )
        if resp.status_code != 200:
            raise RuntimeError(f"S2 TTS request failed: {resp.status_code} {resp.text}")

        audio, sr = sf.read(io.BytesIO(resp.content))
        self.sample_rate = int(sr)
        return audio.astype(np.float32)

    async def warm_up(self) -> None:
        """Check that the S2 API server is alive.

        Raises:
            RuntimeError: if the server cannot be reached (connection error,
                timeout or any other ``requests`` failure) or if the health
                endpoint answers with a non-200 status.
        """
        try:
            resp = await asyncio.to_thread(
                requests.get, f"{self.api_url}/v1/health", timeout=5
            )
        except requests.RequestException as exc:
            raise RuntimeError(
                f"S2 API server not reachable at {self.api_url}. "
                "Make sure the S2 server is running."
            ) from exc

        if resp.status_code != 200:
            raise RuntimeError(f"S2 server health check failed: {resp.status_code}")
        logger.info("S2 API server is healthy at {}", self.api_url)