diff --git a/.env.example b/.env.example
index b875020..d28298b 100644
--- a/.env.example
+++ b/.env.example
@@ -5,27 +5,47 @@
PORT=8765
LOG_LEVEL=INFO
-TTS_BACKEND=f5_tts
-# F5-TTS model name. Built-in options: F5TTS_v1_Base, F5TTS_v1_Small.
-# Downloaded automatically on first use via HuggingFace.
-TTS_MODEL_NAME=F5TTS_v1_Base
-# TTS_MODEL_PATH=models/f5-tts/model.pt
-# TTS_VOCAB_PATH=models/f5-tts/vocab.txt
-TTS_SAMPLE_RATE=24000
+# Available backends:
+# dummy — sine-wave test (no GPU, no deps)
+# s2 — Fish Audio S2-Pro INT4 (требует S2 API сервер на порту 8081)
+# fish_speech — Fish Speech 1.5 (требует чекпоинт fish-speech)
+# f5_tts — F5-TTS v1 (экспериментальный)
+# xtts_v2 — XTTS-v2 от Coqui (авто-загрузка, ~2 GB VRAM)
+TTS_BACKEND=s2
+
+# Paths for Fish Speech / XTTS backends
+TTS_MODEL_PATH=models/fishaudio_fish-speech-1.5
+TTS_VOCAB_PATH=models/fish-speech
+TTS_MODEL_NAME=tts_models/multilingual/multi-dataset/xtts_v2
+TTS_SAMPLE_RATE=44100
+TTS_SPEED=1.5
+
+# Fish Speech tuning
+FISH_COMPILE=false
+FISH_CHUNK_LENGTH=300
+FISH_USE_MEMORY_CACHE=on
+FISH_TOP_P=0.7
+FISH_TEMPERATURE=0.7
+FISH_REPETITION_PENALTY=1.2
+FISH_TAIL_SILENCE_THRESHOLD=0
+FISH_LOWPASS_CUTOFF=0
VOICES_DIR=voices
-# Path to default reference audio (relative to project root or absolute).
-# Providing DEFAULT_VOICE_REF enables instant warm-up and voice cloning.
-DEFAULT_VOICE_REF=voices/rick_ref_clean.wav
-# Exact transcript of the reference audio. When set, Whisper transcription is skipped.
-DEFAULT_REF_TEXT="Ва-ба-ла-ба-дап-дап! Рикки-тики-тави, сученька! И вот такие у нас новости! Иди."
+DEFAULT_VOICE_REF=voices/default_ref.wav
+DEFAULT_REF_TEXT=""
+# Segmentation (прогрессивный порог: первый сегмент короче)
MIN_SEGMENT_LENGTH=30
-MAX_SEGMENT_LENGTH=200
+MAX_SEGMENT_LENGTH=500
MAX_BUFFER_WAIT_MS=500
+FAST_START_INITIAL=12
+FAST_START_COUNT=3
DEVICE=cuda
DTYPE=bfloat16
-# Run a dummy inference at startup to cache reference and prime CUDA.
WARMUP=true
+WARMUP_TEXT="Привет. Это тестовая фраза."
+
+# S2 API server URL
+S2_API_URL=http://127.0.0.1:8081
diff --git a/AGENTS.md b/AGENTS.md
index c662d52..d16ed6b 100644
--- a/AGENTS.md
+++ b/AGENTS.md
@@ -7,6 +7,9 @@
- **Run server (Fish Speech default):** `python -m voice_tts.main`
- **Dummy backend for fast local tests:** `TTS_BACKEND=dummy python -m voice_tts.main`
- **XTTS-v2 backend:** `TTS_BACKEND=xtts_v2 python -m voice_tts.main`
+- **Fish Audio S2 backend:** `TTS_BACKEND=s2 python -m voice_tts.main` (требует S2 API сервер на порту 8081)
+- **S2 API server (INT4, без compile):** `cd models/fish-speech && .venv/bin/python tools/api_server.py --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 --decoder-checkpoint-path checkpoints/s2-pro/codec.pth --listen 127.0.0.1:8081`
+- **S2 API server (INT4, compile ~1× real-time):** `cd models/fish-speech && .venv/bin/python tools/api_server.py --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 --decoder-checkpoint-path checkpoints/s2-pro/codec.pth --listen 127.0.0.1:8081 --compile` — первый ~3.5 мин, потом ~51 tok/s, **~9.6 GB VRAM** (вмещается в 12 GB)
- **Console script (installed):** `voice-tts`
- **Health check:** `curl http://localhost:8765/health`
- **Browser test client:** `cd examples && python -m http.server 8080` → открыть `http://localhost:8080/client_browser.html`
@@ -55,9 +58,17 @@
| `TTS_MODEL_PATH` | — | Fish Speech checkpoint folder (contains model.pth, firefly-gan-vq-fsq-8x1024-21hz-generator.pth, tokenizer.tiktoken, config.json) |
| `TTS_VOCAB_PATH` | — | Fish Speech v1.5 source tree path (used to import firefly_gan / FSQ modules) |
| `TTS_MODEL_NAME` | `tts_models/multilingual/multi-dataset/xtts_v2` | Coqui model manager path; xtts_v2 downloads this on first use |
-| `FISH_COMPILE` | `false` | Avoid setting to true. Enables torch.compile but causes CUDAGraphs tensor-overwrite errors on repeated inference. |
| `FISH_CHUNK_LENGTH` | 200 | Chunk length for Fish Speech (100–300). Higher = more GPU work per call, higher latency. |
+### S2-Pro INT4
+
+- Чекпоинт: `models/fish-speech/checkpoints/fs-1.2-int4-g128-s2-pro-nf4/` (4.9 ГБ model.pth, INT4 quantized, groupsize=128)
+- `config.json:` в чекпоинте имеет `text_config.max_seq_len=4096` — KV cache ~0.67 GB (было 5.36 GB при 32768)
+- **VRAM:** ~4.3 GB модель, ~0.67 GB KV cache, ~4.5 GB codec decoder + compile cache + активации = **~9.6 GB пик** (под 12 GB лимитом)
+- **INT4 авто-определение:** `from_pretrained` детектит `"int4"` в пути к чекпоинту, загружает `model.pth` с `WeightOnlyInt4Linear` модулями
+- Чекпоинт получен через `tools/llama/quantize.py --checkpoint-path .../s2-pro --mode int4 --groupsize 128`
+- PyTorch 2.4.0 обязателен — `_convert_weight_to_int4pack` / `_weight_int4pack_mm` использует kKTileSize=16
+
## WebSocket Protocol (`/ws`)
- Server at `ws://localhost:8765/ws`
diff --git a/README.md b/README.md
index e53e6c4..a1b413d 100644
--- a/README.md
+++ b/README.md
@@ -1,57 +1,118 @@
# Voice TTS
-Local GPU-powered real-time text-to-speech pipeline with a WebSocket API, designed to voice AI agents that stream text in chunks.
+Локальный GPU-пайплайн синтеза речи в реальном времени с WebSocket API.
+Разработан для озвучки ИИ-агентов: текст поступает частями от LLM, аудио отдаётся
+клиенту по мере готовности.
-## Features
+## Основное
-- **Streaming input**: accepts partial text as it is generated by the LLM/agent.
-- **Streaming output**: returns PCM audio chunks over WebSocket as soon as they are synthesized.
-- **Voice cloning**: single speaker cloned from reference audio, with optional per-emotion references.
-- **Interrupt / stop**: agent can immediately stop playback when the user interrupts the AI.
-- **Emotion control**: switch emotion on the fly (requires matching reference audio or supported backend).
-- **Local GPU**: runs entirely on your NVIDIA GPU (RTX 3090 / 3060 compatible).
+- **Стриминг текста**: фразы от LLM приходят частями, сервер не ждёт полного текста
+- **Стриминг аудио**: PCM-чанки отправляются сразу после синтеза
+- **Клонирование голоса**: один спикер по референсному аудио
+- **Прерывание**: немедленная остановка по сигналу `stop`
+- **Выбор бэкенда**: переключение между моделями через `TTS_BACKEND` без изменения кода
-## Project status
+## Бэкенды
-- Working WebSocket server with streaming text, audio streaming, and instant stop/resume.
-- F5-TTS backend installed, GPU-ready, and producing real audio (`models/F5TTS_v1_Base/` downloaded).
-- Dummy backend available for fast offline tests.
-- Startup warm-up caches the default reference and primes CUDA.
-- Next: multilingual evaluation, latency optimization, and client examples.
+| Имя | Модель | Требования | VRAM | RTF |
+|-----|--------|-----------|------|-----|
+| `s2` | Fish Audio S2-Pro INT4 | S2 API сервер (порт 8081), PyTorch 2.4 | ~9.6 GB | **0.59** |
+| `fish_speech` | Fish Speech 1.5 | чекпоинт fish-speech, PyTorch | ~12 GB | ~1.4 |
+| `xtts_v2` | XTTS-v2 (Coqui) | авто-загрузка | ~2 GB | ~0.34 |
+| `f5_tts` | F5-TTS v1 | чекпоинт F5TTS | ~4 GB | ~1.0 |
+| `dummy` | синусоида | ничего | — | — |
-## Quick start
+## Быстрый старт (S2-Pro INT4)
```bash
-# Create virtual environment (Python 3.10-3.12 recommended)
-python3.11 -m venv .venv
-source .venv/bin/activate
+# 1. S2 API сервер (требует ~9.6 GB VRAM, 50 tok/s с --compile)
+cd models/fish-speech
+PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python tools/api_server.py \
+ --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \
+ --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \
+ --listen 127.0.0.1:8081 --compile
-# Install PyTorch with CUDA 12.6 support first
-pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126
-
-# Install remaining dependencies
-pip install -r requirements.txt
-
-# (Optional) Download the F5-TTS model beforehand
-python scripts/download_f5_tts.py --model F5TTS_v1_Base
-
-# Run the server
-python -m voice_tts.main
-
-# Or run in dummy test mode
-TTS_BACKEND=dummy python -m voice_tts.main
+# 2. WebSocket прокси (в другом терминале)
+TTS_BACKEND=s2 python -m voice_tts.main
```
-Server will listen on `ws://localhost:8765/ws`.
+Сервер слушает `ws://localhost:8765/ws`.
-## WebSocket protocol
+## WebSocket протокол
-See full documentation in [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md).
+Полная документация — [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md).
-## Architecture and roadmap
+### Кратко
-- [`docs/01_overview.md`](docs/01_overview.md)
-- [`docs/02_architecture.md`](docs/02_architecture.md)
-- [`docs/04_roadmap.md`](docs/04_roadmap.md)
-- [`docs/05_usage.md`](docs/05_usage.md)
-- [`docs/06_technical_notes.md`](docs/06_technical_notes.md)
+Клиент → Сервер:
+- `init` — настройки сессии (голос, язык, скорость)
+- `text` — чанк текста (`payload` — строка)
+- `flush` — озвучить остаток буфера
+- `stop` — немедленно прервать
+- `emotion` / `config` — смена параметров
+
+Сервер → Клиент:
+- `status` — события (`session_ready`, `segment_started`, `segment_finished`, `stopped`)
+- `audio` — PCM16 base64 с `sample_rate`
+- `error` — ошибка
+
+## Архитектура
+
+```
+[LLM / AI Agent] → WebSocket → [SessionManager]
+ │
+ TextBuffer / Segmenter
+ │
+ segments (async tasks)
+ │
+ TTS Engine (через реестр)
+ │
+ AudioQueue → WS send
+```
+
+Бэкенды регистрируются через декоратор `@register("name")` в `voice_tts.tts`.
+Новый бэкенд добавляется одним файлом с декоратором — никаких правок в `server.py`.
+
+```python
+from voice_tts.tts import register as _register_backend
+
+@_register_backend("my_model")
+class MyEngine(TTSEngine):
+ ...
+```
+
+## Установка
+
+```bash
+python3.11 -m venv .venv
+source .venv/bin/activate
+pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126
+pip install -r requirements.txt
+```
+
+## Конфигурация
+
+Все переменные в `.env` (см. `.env.example`). Ключевые:
+
+| Переменная | По умолч. | Описание |
+|-----------|-----------|----------|
+| `TTS_BACKEND` | `s2` | `dummy` / `s2` / `fish_speech` / `f5_tts` / `xtts_v2` |
+| `DEFAULT_VOICE_REF` | — | Референсное аудио для клонирования |
+| `DEFAULT_REF_TEXT` | — | Точный транскрипт референса |
+| `MIN_SEGMENT_LENGTH` | `30` | Мин. длина сегмента (прогрессивно: 12 → 30) |
+| `MAX_SEGMENT_LENGTH` | `200` | Макс. длина сегмента |
+| `FAST_START_INITIAL` | `12` | Порог для первого сегмента (снижает задержку) |
+| `FAST_START_COUNT` | `3` | Сегментов с прогрессивным порогом |
+
+## Тесты
+
+```bash
+pytest tests/ -v
+```
+
+## Документация
+
+- [`docs/01_overview.md`](docs/01_overview.md) — обзор проекта
+- [`docs/02_architecture.md`](docs/02_architecture.md) — архитектура
+- [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md) — протокол
+- [`docs/05_usage.md`](docs/05_usage.md) — использование и развёртывание
diff --git a/docs/05_usage.md b/docs/05_usage.md
index 41105ee..c700c25 100644
--- a/docs/05_usage.md
+++ b/docs/05_usage.md
@@ -2,189 +2,140 @@
## Установка
-> Рекомендуется Python 3.11. Python 3.14+ пока не имеет совместимых wheel для
-> `torch`, поэтому используйте 3.10–3.12.
-
```bash
-# Клонировать / перейти в директорию проекта
-cd voice
-
-# Создать виртуальное окружение
python3.11 -m venv .venv
-source .venv/bin/activate # Windows: .venv\Scripts\activate
-
-# Установить PyTorch с CUDA 12.6 (обязательно первым)
+source .venv/bin/activate
pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126
-
-# Установить остальные зависимости
pip install -r requirements.txt
-
-# (Опционально) скопировать настройки
cp .env.example .env
+# отредактировать .env под свой бэкенд
```
-## Запуск сервера
+## Запуск: S2-Pro INT4 (рекомендуемый)
+
+Требуется **PyTorch 2.4.0+cu121** (установлен). S2-сервер и WebSocket-прокси
+запускаются в разных терминалах:
```bash
-# Основной режим: Fish Speech 1.5 на GPU (по умолчанию)
-python -m voice_tts.main
+# Терминал 1: S2 API сервер
+cd models/fish-speech
+PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python tools/api_server.py \
+ --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \
+ --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \
+ --listen 127.0.0.1:8081 --compile
+# Первый запуск ~3.5 мин (компиляция), потом ~50 tok/s, ~9.6 GB VRAM
-# С настроенным референсом и warm-up (рекомендуется)
+# Терминал 2: WebSocket прокси
+cd /home/gmikcon/Projects/voice
+TTS_BACKEND=s2 python -m voice_tts.main
+```
+
+### S2 без компиляции (меньше VRAM, медленнее)
+
+```bash
+# Без --compile: ~6 tok/s
+cd models/fish-speech && PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python \
+ tools/api_server.py \
+ --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \
+ --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \
+ --listen 127.0.0.1:8081
+```
+
+### Обновление чекпоинта S2
+
+Если выходит новая версия S2-Pro:
+1. Скачать новый чекпоинт с HuggingFace
+2. Запустить квантизацию: `python tools/llama/quantize.py --checkpoint-path ... --mode int4 --groupsize 128`
+3. Обновить `--llama-checkpoint-path` на новый квантизованный чекпоинт
+
+## Запуск: Fish Speech 1.5
+
+```bash
TTS_BACKEND=fish_speech \
-DEFAULT_VOICE_REF=voices/self_ref_clean.wav \
-DEFAULT_REF_TEXT="Добрый вечер, меня зовут Евгений." \
+DEFAULT_VOICE_REF=voices/default_ref.wav \
+DEFAULT_REF_TEXT="текст референса" \
WARMUP=true \
python -m voice_tts.main
+```
-# Быстрый вариант XTTS-v2
+## Запуск: XTTS-v2
+
+```bash
TTS_BACKEND=xtts_v2 \
-DEFAULT_VOICE_REF=voices/self_ref_clean.wav \
+DEFAULT_VOICE_REF=voices/default_ref.wav \
python -m voice_tts.main
+```
-# Тестовый режим без модели
+Модель скачивается автоматически при первом запуске.
+
+## Запуск: тестовый dummy
+
+```bash
TTS_BACKEND=dummy python -m voice_tts.main
```
-Сервер поднимется на `ws://localhost:8765/ws`.
-
-## Проверка работоспособности
+## Проверка
```bash
curl http://localhost:8765/health
-# {"status":"ok","backend":"fish_speech"}
+# {"status":"ok","backend":"s2"}
```
## Клиенты
-В директории `examples/` лежат готовые клиенты:
-
-- `examples/client_python.py` — Python-клиент с воспроизведением через `sounddevice`.
-- `examples/client_browser.html` — HTML/JS клиент для браузера с `AudioContext`.
-
-### Python-клиент
+В `examples/`:
+- `client_python.py` — Python-клиент с sounddevice
+- `client_browser.html` — браузерный клиент
```bash
pip install websockets sounddevice
python examples/client_python.py --uri ws://localhost:8765/ws "Привет, мир!"
```
-Опции:
-
-```bash
-python examples/client_python.py \
- --uri ws://localhost:8765/ws \
- --voice-ref voices/self_ref_clean.wav \
- --language ru \
- --speed 1.0 \
- "Это тестовая фраза для проверки."
-```
-
-Клиент:
-1. Отправляет `init` с настройками.
-2. Разбивает текст на слова и шлёт их как потоковые `text`.
-3. Отправляет `flush`.
-4. Получает `audio`-чанки, декодирует base64 PCM16 и складывает в аудиобуфер.
-5. `sounddevice` воспроизводит аудио в реальном времени из callback.
-6. По `Ctrl+C` отправляет `stop` и выходит.
-
-### Браузерный клиент
-
-```bash
-# 1. Терминал 1: запустить TTS сервер (dummy — без GPU, любой бэкенд по выбору)
-TTS_BACKEND=dummy python -m voice_tts.main
-
-# 2. Терминал 2: открыть клиент через HTTP (file:// не даёт WebSocket в некоторых браузерах)
-python -m http.server 8080 --directory examples/
-# Открыть http://localhost:8080/client_browser.html
-```
-
-Нажмите **Connect**, затем **Speak streaming**. Клиент:
-
-1. Шлёт `init` с настройками (язык, скорость, эмоция).
-2. Шлёт слова по одному как потоковые `text` с задержкой 120 мс.
-3. Завершает `flush`.
-4. Получает `audio`-чанки с динамическим `sample_rate` (поддерживается любой бэкенд).
-5. Декодирует PCM16 из base64 и ставит в очередь `AudioBuffer` для бесшовного воспроизведения.
-
-Кнопка **Stop** отправляет `stop`. Кнопка **Test audio** проверяет звук в браузере независимо от сервера.
-
-## Настройка через переменные окружения (.env)
-
-| Переменная | Описание | По умолчанию |
-|------------|----------|--------------|
-| `HOST` | Хост сервера | `0.0.0.0` |
-| `PORT` | Порт сервера | `8765` |
-| `LOG_LEVEL` | Уровень логирования | `INFO` |
-| `TTS_BACKEND` | Бэкенд (`dummy` / `fish_speech` / `xtts_v2`) | `fish_speech` |
-| `TTS_MODEL_PATH` | Папка с checkpoint Fish Speech / XTTS | — |
-| `TTS_VOCAB_PATH` | Исходники Fish Speech v1.5.1 | `models/fish-speech-v1.5.1` |
-| `TTS_SAMPLE_RATE` | Частота дискретизации | `44100` |
-| `TTS_SPEED` | Множитель скорости речи | `1.2` |
-| `VOICES_DIR` | Директория с референсами | `voices` |
-| `DEFAULT_VOICE_REF` | Референс по умолчанию | — |
-| `DEFAULT_REF_TEXT` | Точный текст референса (skip Whisper) | — |
-| `MIN_SEGMENT_LENGTH` | Мин. длина сегмента | `30` |
-| `MAX_SEGMENT_LENGTH` | Макс. длина сегмента | `200` |
-| `MAX_BUFFER_WAIT_MS` | Макс. ожидание перед flush | `500` |
-| `DEVICE` | `cuda` или `cpu` | `cuda` |
-| `DTYPE` | `bfloat16` / `float16` / `float32` | `bfloat16` |
-| `FISH_COMPILE` | `torch.compile` для Fish Speech | `false` |
-| `FISH_CHUNK_LENGTH` | Длина LLM-чанка Fish Speech | `200` |
-| `FISH_USE_MEMORY_CACHE` | Кэшировать VQ референса | `on` |
-| `WARMUP` | Прогреть CUDA и кэшировать референс | `false` |
-| `WARMUP_TEXT` | Текст для warm-up | `Привет. Это тестовая фраза.` |
-
-## Модели
-
-### Fish Speech 1.5
-
-По умолчанию используется локальный checkpoint `models/fishaudio_fish-speech-1.5/`:
-
-- `model.pth` — LLaMA языковая модель,
-- `firefly-gan-vq-fsq-8x1024-21hz-generator.pth` — VQ-GAN декодер,
-- `tokenizer.tiktoken`, `config.json`, `special_tokens.json`.
-
-Исходный код Fish Speech v1.5.1 должен лежать в `models/fish-speech-v1.5.1/`,
-чтобы Python мог импортировать нужные модули.
-
-### XTTS-v2
-
-Coqui-модель `tts_models/multilingual/multi-dataset/xtts_v2` скачивается
-автоматически при первом запуске `TTS_BACKEND=xtts_v2`. Можно указать
-локальный путь через `TTS_MODEL_PATH`.
-
## Референсные аудио
-Поместите файлы в директорию `voices/`:
-
+Поместите в `voices/`:
```
voices/
-├── self_ref_clean.wav
-├── self_ref_clean.lab
-├── default_neutral.wav
-├── default_happy.wav
-├── default_sad.wav
+├── product_voice_sample_clean.wav
+├── product_voice_sample_clean.lab (точный транскрипт для Fish Speech)
└── ...
```
-Требования к референсу:
-- WAV или другой формат, читаемый `torchaudio`.
-- Моно, 16+ кГц.
-- Длина 5–15 секунд для Fish Speech, 3–10 секунд для XTTS-v2.
-- Чистая речь одного спикера без фонового шума.
-- Для Fish Speech рядом с `.wav` можно положить `.lab` с точным транскриптом.
- Иначе сервер использует `DEFAULT_REF_TEXT` или Whisper-транскрипцию.
+Требования:
+- WAV, моно, 16+ кГц
+- 5–15 секунд чистой речи одного спикера
+- Для Fish Speech рядом с `.wav` можно положить `.lab` с транскриптом
## Тесты
```bash
-# Быстрые тесты без загрузки тяжёлых моделей
-python -m pytest tests/ -v
+pytest tests/ -v
```
-Для запуска в CI рекомендуется отдельно прогонять быстрые и тяжёлые тесты:
+## Добавление нового бэкенда
-```bash
-python -m pytest tests/test_segmenter.py tests/test_server.py -v
-python -m pytest tests/test_fish_speech_backend.py -v
+1. Создать файл `src/voice_tts/tts/my_model.py`
+2. Унаследоваться от `TTSEngine`, добавить декоратор:
+
+```python
+from voice_tts.tts import register as _register_backend
+from voice_tts.tts.engine import TTSEngine
+
+@_register_backend("my_model")
+class MyEngine(TTSEngine):
+ def __init__(self):
+ # читать настройки из voice_tts.config.settings
+
+ async def synthesize(self, text, ref_audio_path, language, speed, emotion, ref_text=None):
+ ...
+
+ async def warm_up(self):
+ ...
```
+
+3. Готово — бэкенд доступен через `TTS_BACKEND=my_model`
+
+## Переменные окружения
+
+Полный список — в `.env.example`.
diff --git a/docs/06_technical_notes.md b/docs/06_technical_notes.md
index 20ab2ef..9e0ced8 100644
--- a/docs/06_technical_notes.md
+++ b/docs/06_technical_notes.md
@@ -35,29 +35,58 @@
Для русского языка предложения обычно короче, чем на английском, поэтому `max_length` выбран консервативно.
-## F5-TTS: особенности интеграции
+## Fish Speech 1.5
-- Используется готовая модель `F5TTS_v1_Base` через официальный пакет `f5-tts`.
-- Модель автоматически загружается из Hugging Face при первом вызове `load()`;
- можно скачать заранее скриптом `scripts/download_f5_tts.py`.
-- Референсное аудио транскрибируется автоматически, если `ref_text` не задан.
-- Результат транскрипции и путь к обработанному аудио кэшируются по паре
- `(путь, эмоция)`, чтобы не повторять работу при смене сегментов.
-- Скорость регулируется параметром `speed` инференса F5-TTS.
-- Текущая целевая частота дискретизации — **24 kHz**; модель сама
- передискретирует референс при необходимости.
-- Инференс выполняется в отдельном потоке (`asyncio.to_thread`) с
- `asyncio.Lock`, чтобы не блокировать event loop и не допускать
- параллельных CUDA-вызовов. Все `send_text` сериализуются через
- отдельный `_send_lock`, чтобы избежать deadlock внутри `websockets`.
+### Особенности интеграции
+
+- Локальный checkpoint в `models/fishaudio_fish-speech-1.5/`:
+ `model.pth`, `firefly-gan-vq-fsq-8x1024-21hz-generator.pth`,
+ `tokenizer.tiktoken`, `config.json`, `special_tokens.json`.
+- Исходный код `models/fish-speech-v1.5.1/` добавляется в `sys.path` для импорта модулей.
+- Выходная частота дискретизации — **44,1 кГц**.
+- Точный транскрипт референса важен: используется `.lab` рядом с референсом,
+ затем `DEFAULT_REF_TEXT`, затем Whisper-транскрипция, затем placeholder.
+- Скорость регулируется resampling'ом после синтеза (`TTS_SPEED`).
+- Параллельные CUDA-вызовы сериализуются через `asyncio.Lock` и
+ `asyncio.to_thread`.
+
+### Тюнинг
+
+- `FISH_USE_MEMORY_CACHE=on` — кэшировать VQ-представление референса (включено).
+- `FISH_CHUNK_LENGTH` — длина LLM-чанка (100–300, по умолчанию 200).
+ Больше = длиннее связные куски, но выше задержка.
+- `FISH_COMPILE=true` — пытается включить `torch.compile`. **Не включать по умолчанию:**
+ при повторном инференсе возникает ошибка `accessing tensor output of CUDAGraphs that has been overwritten`.
+ Исследуется отдельно.
### Замеры задержки (RTX 3090, Python 3.11, CUDA 12.6)
-- Первый запуск без `DEFAULT_REF_TEXT`: ~5–6 с, большая часть уходит на Whisper-транскрипцию референса.
-- С `DEFAULT_REF_TEXT` и `WARMUP=true`: warm-up занимает ~2 с (загрузка модели + один инференс).
-- После warm-up с кэшированным референсом: первый audio-chunk ~1.1 с на коротком сегменте.
-- 4 сегмента подряд: первый finished ~1.1 с, последний ~4.4 с.
-- `stop` + возобновление работает без переподключения WebSocket.
+- Первый запуск без `DEFAULT_REF_TEXT`: ~5–6 с, большая часть уходит на Whisper-транскрипцию.
+- С `DEFAULT_REF_TEXT` и `WARMUP=true`: загрузка модели + один инференс.
+- RTF (real-time factor) Fish Speech ~1.4 на коротких сегментах: медленнее реального времени.
+- RTF XTTS-v2 ~0.34: быстрее реального времени.
+- Fish Speech даёт более естественную русскую интонацию, поэтому выбран по умолчанию.
+- XTTS-v2 — резервный быстрый бэкенд для сценариев, где задержка важнее качества.
+
+### Сравнение с XTTS-v2
+
+| Показатель | Fish Speech 1.5 | XTTS-v2 |
+|------------|-----------------|---------|
+| RTF | ~1.4 | ~0.34 |
+| Русская интонация | естественнее | приемлемо, акцент чаще |
+| Английская речь | хорошо | хорошо |
+| Размер weights | ~2 ГБ LLM + VQGAN | ~3 ГБ |
+| Sample rate | 44,1 кГц | 24 кГц |
+| Требует ref transcript | да, точный | да, но терпимее |
+| `torch.compile` | нестабилен | не применяется |
+
+## XTTS-v2
+
+- Coqui-модель `tts_models/multilingual/multi-dataset/xtts_v2`.
+- Автоматически скачивается при первом `TTS_BACKEND=xtts_v2`.
+- Можно указать локальный checkpoint через `TTS_MODEL_PATH`.
+- Sample rate 24 кГц; клиенты `examples/client_*.py`/`examples/client_browser.html`
+ настроены на динамический sample rate из сообщений `audio`.
## Мультиязычность
@@ -94,10 +123,11 @@
3. **Fine-tuning F5-TTS** — самый трудоёмкий, но даёт лучший контроль
над акцентом и языком.
-Сейчас оставляем F5-TTS как единственный backend, но архитектура
-(`TTSEngine` + `_BACKEND_MAP`) позволяет добавить fallback позже.
+Сейчас оставляем Fish Speech 1.5 как бэкенд по умолчанию для en/ru,
+XTTS-v2 — как быстрый резерв. Архитектура (`TTSEngine` + `_BACKEND_MAP`)
+позволяет добавить fallback позже.
- F5-TTS может не идеально произносить украинский / европейские языки из коробки — возможно потребуется fine-tuning или fallback. Сейчас протокол и сегментатор не ограничивают язык; качество зависит от самой модели.
- RTX 3060 (12 GB) подойдёт для базовой модели, но batch-size и длина референса придётся ограничивать.
- Быстрый `stop` во время CUDA kernel не прервёт уже запущенный kernel, но предотвратит отправку результата.
-- `main.py` создаёт engine до старта uvicorn; при `TTS_BACKEND=f5_tts` первый запуск может занять десяток секунд из-за загрузки модели и vocos.
+- `main.py` создаёт engine до старта uvicorn; при `TTS_BACKEND=fish_speech` первый запуск занимает десяток секунд из-за загрузки LLM и VQ-GAN.
diff --git a/examples/client_browser.html b/examples/client_browser.html
index a7f237e..e0da15a 100644
--- a/examples/client_browser.html
+++ b/examples/client_browser.html
@@ -3,61 +3,31 @@
- Voice TTS WebSocket Client
+ Voice TTS
- Voice TTS WebSocket Client
-
-
-
-
+ Voice TTS
+
+
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+
+
-
-
+ Нажми Connect для запуска.
diff --git a/pyproject.toml b/pyproject.toml
index 7baa355..f56bcbe 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -12,7 +12,7 @@
"torch>=2.4.0",
"torchaudio>=2.4.0",
"loguru>=0.7.0",
- "f5-tts>=1.1.0",
+ "TTS>=0.25.0",
"soundfile>=0.12.0",
"pydub>=0.25.0",
"tqdm>=4.66.0",
@@ -52,3 +52,4 @@
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
+addopts = "-p no:cacheprovider"
diff --git a/scripts/benchmark_backends.py b/scripts/benchmark_backends.py
new file mode 100644
index 0000000..4257ca6
--- /dev/null
+++ b/scripts/benchmark_backends.py
@@ -0,0 +1,220 @@
+"""Benchmark local TTS backends: Fish Speech 1.5 vs XTTS-v2.
+
+Measures:
+ - cold load time
+ - per-sentence synthesis time
+ - real-time factor (RTF)
+ - Whisper ASR accuracy
+
+Outputs WAV files and a JSON report in outputs/benchmark/.
+"""
+
+import asyncio
+import json
+import os
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torchaudio
+from loguru import logger
+
+# Ensure project source is importable.
+ROOT = Path(__file__).resolve().parent.parent
+if str(ROOT / "src") not in os.sys.path:
+ os.sys.path.insert(0, str(ROOT / "src"))
+
+from voice_tts.audio.formats import float_to_wav_bytes # noqa: E402
+from voice_tts.config import settings # noqa: E402
+from voice_tts.tts.fish_speech_backend import FishSpeechEngine # noqa: E402
+from voice_tts.tts.xtts_backend import XTTSv2Engine # noqa: E402
+
+
+SENTENCES = [
+ ("en", "Hello, this is a short English sentence for the benchmark."),
+ (
+ "ru",
+ "Добрый вечер, меня зовут Евгений. Это тестовое предложение для проверки качества синтеза.",
+ ),
+ (
+ "en",
+ "The quick brown fox jumps over the lazy dog, testing every letter of the alphabet.",
+ ),
+ (
+ "ru",
+ "Наша цель — сделать речь естественной и понятной на английском и русском языках.",
+ ),
+]
+
+
+def _load_whisper(model_name: str = "large-v3", device: str = "cuda"):
+ from faster_whisper import WhisperModel
+
+ model_path = ROOT / "models" / "faster-whisper" / model_name
+ model_path.parent.mkdir(parents=True, exist_ok=True)
+ if not model_path.exists():
+ logger.info("Downloading faster-whisper {} ...", model_name)
+ return WhisperModel(
+ model_name,
+ device=device if device.startswith("cuda") else "cpu",
+ compute_type="float16" if device.startswith("cuda") else "int8",
+ download_root=str(model_path.parent),
+ )
+
+
+def _transcribe(model, wav_path: Path) -> str:
+ segments, _ = model.transcribe(str(wav_path), language=None)
+ return " ".join(s.text.strip() for s in segments).strip()
+
+
+def _save_wav(audio: np.ndarray, sr: int, path: Path) -> None:
+ path.parent.mkdir(parents=True, exist_ok=True)
+ path.write_bytes(float_to_wav_bytes(audio, sr))
+
+
+def _normalise(text: str) -> str:
+ return " ".join(text.lower().replace(",", " ").replace(".", " ").replace("—", " ").split())
+
+
+def _wer(ref: str, hyp: str) -> float:
+ """Simple word-error rate."""
+ r = _normalise(ref).split()
+ h = _normalise(hyp).split()
+ if not r:
+ return 0.0 if not h else 1.0
+ # Levenshtein distance
+ prev = list(range(len(h) + 1))
+ for i, rc in enumerate(r, 1):
+ curr = [i]
+ for j, hc in enumerate(h, 1):
+ cost = 0 if rc == hc else 1
+ curr.append(min(curr[-1] + 1, prev[j] + 1, prev[j - 1] + cost))
+ prev = curr
+ return prev[-1] / len(r)
+
+
+async def _synth(engine, text: str, lang: str, ref: Path, speed: float = 1.0):
+ # Fish Speech has async synthesize, XTTS too; both are heavy CUDA sync.
+ # Run inside a thread to mimic server behavior.
+ def _run():
+ kwargs = dict(
+ text=text,
+ ref_audio_path=ref,
+ language=lang,
+ speed=speed,
+ emotion="neutral",
+ )
+ if isinstance(engine, FishSpeechEngine):
+ kwargs["ref_text"] = settings.default_ref_text
+ return asyncio.run(engine.synthesize(**kwargs))
+
+ return await asyncio.to_thread(_run)
+
+
+async def benchmark_backend(name: str, factory, ref: Path, output_dir: Path):
+ logger.info("Benchmarking {} ...", name)
+ report = {"backend": name, "sentences": [], "load_seconds": None}
+
+ t0 = time.perf_counter()
+ engine = factory()
+ engine.load()
+ report["load_seconds"] = round(time.perf_counter() - t0, 3)
+
+ # Optional warm-up to make per-sentence timing more representative.
+ warmup_text = "One two three." if name == "xtts_v2" else "Раз, два, три."
+ await _synth(engine, warmup_text, "en" if name == "xtts_v2" else "ru", ref, 1.0)
+
+ whisper_model = _load_whisper(device=settings.device)
+
+ for lang, text in SENTENCES:
+ t0 = time.perf_counter()
+ audio = await _synth(engine, text, lang, ref, settings.tts_speed)
+ elapsed = time.perf_counter() - t0
+
+ duration = len(audio) / engine.sample_rate if engine.sample_rate else 0.0
+ rtf = elapsed / duration if duration else 0.0
+
+ wav_path = output_dir / name / f"{lang}_{len(report['sentences'])}.wav"
+ _save_wav(audio, engine.sample_rate, wav_path)
+
+ hyp = _transcribe(whisper_model, wav_path)
+ wer = _wer(text, hyp)
+
+ report["sentences"].append(
+ {
+ "language": lang,
+ "text": text,
+ "duration_seconds": round(duration, 3),
+ "synth_seconds": round(elapsed, 3),
+ "rtf": round(rtf, 3),
+ "whisper": hyp,
+ "wer": round(wer, 3),
+ "wav": str(wav_path.relative_to(output_dir)),
+ }
+ )
+ logger.info(
+ "[{} {}] RTF={} dur={}s synth={}s WER={}",
+ name,
+ lang,
+ round(rtf, 3),
+ round(duration, 3),
+ round(elapsed, 3),
+ round(wer, 3),
+ )
+
+ return report
+
+
+def _fish_factory():
+ return FishSpeechEngine(
+ checkpoint_path=settings.tts_model_path,
+ source_root=settings.tts_vocab_path,
+ device=settings.device,
+ compile=settings.fish_compile,
+ use_memory_cache=settings.fish_use_memory_cache,
+ chunk_length=settings.fish_chunk_length,
+ )
+
+
+def _xtts_factory():
+ return XTTSv2Engine(
+ model_name=settings.tts_model_name,
+ device=settings.device,
+ )
+
+
+async def main():
+ output_dir = ROOT / "outputs" / "benchmark"
+ output_dir.mkdir(parents=True, exist_ok=True)
+
+ ref = Path(settings.default_voice_ref) if settings.default_voice_ref else None
+ if not ref or not ref.exists():
+ raise FileNotFoundError("Set DEFAULT_VOICE_REF to an existing reference WAV.")
+
+ reports = []
+
+ if "fish_speech" in os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2"):
+ reports.append(await benchmark_backend("fish_speech", _fish_factory, ref, output_dir))
+ if "xtts_v2" in os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2"):
+ reports.append(await benchmark_backend("xtts_v2", _xtts_factory, ref, output_dir))
+
+ summary_path = output_dir / "summary.json"
+ summary_path.write_text(json.dumps(reports, ensure_ascii=False, indent=2))
+ logger.info("Report saved to {}", summary_path)
+
+ # Print a quick Markdown table.
+ print("\n## Summary")
+ print("| Backend | Lang | Dur (s) | Synth (s) | RTF | WER |")
+ print("|---------|------|---------|-----------|-----|-----|")
+ for report in reports:
+ for s in report["sentences"]:
+ print(
+ f"| {report['backend']} | {s['language']} | "
+ f"{s['duration_seconds']} | {s['synth_seconds']} | "
+ f"{s['rtf']} | {s['wer']} |"
+ )
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/scripts/benchmark_compile.py b/scripts/benchmark_compile.py
new file mode 100644
index 0000000..382020e
--- /dev/null
+++ b/scripts/benchmark_compile.py
@@ -0,0 +1,73 @@
+"""Quick test: Fish Speech with and without torch.compile on one sentence."""
+
+import asyncio
+import os
+import time
+from pathlib import Path
+
+ROOT = Path(__file__).resolve().parent.parent
+if str(ROOT / "src") not in os.sys.path:
+ os.sys.path.insert(0, str(ROOT / "src"))
+
+from voice_tts.config import settings
+from voice_tts.tts.fish_speech_backend import FishSpeechEngine
+
+
+TEXT = "The quick brown fox jumps over the lazy dog."
+LANG = "en"
+REF = Path(settings.default_voice_ref)
+
+
+async def run(compile: bool) -> float:
+ engine = FishSpeechEngine(
+ checkpoint_path=settings.tts_model_path,
+ source_root=settings.tts_vocab_path,
+ device=settings.device,
+ compile=compile,
+ chunk_length=settings.fish_chunk_length,
+ )
+ engine.load()
+
+ # First run (warm-up + compile)
+ t0 = time.perf_counter()
+ await asyncio.to_thread(
+ lambda: asyncio.run(
+ engine.synthesize(
+ text=TEXT,
+ ref_audio_path=REF,
+ language=LANG,
+ speed=1.0,
+ emotion="neutral",
+ ref_text=settings.default_ref_text,
+ )
+ )
+ )
+ first = time.perf_counter() - t0
+
+ # Second run
+ t0 = time.perf_counter()
+ await asyncio.to_thread(
+ lambda: asyncio.run(
+ engine.synthesize(
+ text=TEXT,
+ ref_audio_path=REF,
+ language=LANG,
+ speed=1.0,
+ emotion="neutral",
+ ref_text=settings.default_ref_text,
+ )
+ )
+ )
+ second = time.perf_counter() - t0
+ return first, second
+
+
+async def main():
+ f1, f2 = await run(compile=False)
+ print(f"compile=False first={f1:.2f}s second={f2:.2f}s")
+ c1, c2 = await run(compile=True)
+ print(f"compile=True first={c1:.2f}s second={c2:.2f}s")
+
+
+if __name__ == "__main__":
+ asyncio.run(main())
diff --git a/src/voice_tts/api/server.py b/src/voice_tts/api/server.py
index 27e0d57..cb392f7 100644
--- a/src/voice_tts/api/server.py
+++ b/src/voice_tts/api/server.py
@@ -24,20 +24,9 @@
from voice_tts.audio.formats import float_to_pcm16, pcm16_to_base64
from voice_tts.config import settings
from voice_tts.session.state import SessionState, VoiceProfile
-from voice_tts.tts.engine import DummyTTSEngine, TTSEngine
-from voice_tts.tts.f5_backend import F5TTSEngine
-from voice_tts.tts.fish_speech_backend import FishSpeechEngine
+from voice_tts.tts import create_engine as _create_tts_engine
+from voice_tts.tts.engine import TTSEngine
from voice_tts.tts.segmenter import Segmenter
-from voice_tts.tts.xtts_backend import XTTSv2Engine
-
-
-# Supported TTS backends
-_BACKEND_MAP: dict[str, type[TTSEngine]] = {
- "dummy": DummyTTSEngine,
- "f5_tts": F5TTSEngine,
- "xtts_v2": XTTSv2Engine,
- "fish_speech": FishSpeechEngine,
-}
class SessionManager:
@@ -49,12 +38,15 @@
self.segmenter = Segmenter(
min_length=settings.min_segment_length,
max_length=settings.max_segment_length,
+ fast_start_initial=settings.fast_start_initial,
+ fast_start_count=settings.fast_start_count,
)
self.state = SessionState(session_id="")
self._running = True
self._tasks: list[asyncio.Task] = []
self._send_lock = asyncio.Lock()
self._synth_lock = asyncio.Lock()
+ self._synth_tasks: set[asyncio.Task] = set()
async def run(self) -> None:
await self.ws.accept()
@@ -187,6 +179,10 @@
self.state.stop()
self.state.clear_buffer()
+ # Cancel all in-flight synthesis tasks.
+ for t in set(self._synth_tasks):
+ t.cancel()
+
# Drain pending audio queue
while not self.state.audio_queue.empty():
try:
@@ -230,8 +226,10 @@
ref_path = self.state.voice.ref_for(self.state.emotion)
+ # Track this task so stop can cancel it.
+ task = asyncio.current_task()
+ self._synth_tasks.add(task)
try:
- # Serialize GPU inference and run it off the main event loop.
async with self._synth_lock:
audio = await asyncio.to_thread(
SessionManager._sync_synthesize,
@@ -242,10 +240,14 @@
self.state.speed,
self.state.emotion,
)
+ except asyncio.CancelledError:
+ raise
except Exception as exc:
logger.exception("TTS synthesis failed")
await self._send(ErrorMessage(message=f"TTS failed: {exc}", seq=segment_seq))
return
+ finally:
+ self._synth_tasks.discard(task)
if self.state.is_stopped():
return
@@ -291,32 +293,14 @@
speed: float,
emotion: str,
) -> "np.ndarray":
- """Thread-safe wrapper around the synchronous TTS inference call."""
+ """Thread-safe wrapper around the TTS inference call.
+
+ All backends expose an async ``synthesize``. Inside a thread from
+ ``asyncio.to_thread`` there is no running loop, so we drive the
+ coroutine with a fresh transient event loop.
+ """
import numpy as np
- # The dummy backend is async; run it on a transient event loop in this thread.
- if isinstance(engine, DummyTTSEngine):
- try:
- loop = asyncio.get_running_loop()
- except RuntimeError:
- loop = None
-
- async def _run() -> "np.ndarray":
- return await engine.synthesize(
- text=text,
- ref_audio_path=ref_path,
- language=language,
- speed=speed,
- emotion=emotion,
- )
-
- if loop is not None:
- return asyncio.run_coroutine_threadsafe(_run(), loop).result()
- return asyncio.run(_run())
-
- # Backends expose an async synthesize method that blocks on CPU/CUDA work.
- # Inside a thread from asyncio.to_thread there is no running loop, so we
- # drive the coroutine with a fresh transient event loop.
kwargs: dict = dict(
text=text,
ref_audio_path=ref_path,
@@ -324,9 +308,7 @@
speed=speed,
emotion=emotion,
)
- if isinstance(engine, F5TTSEngine) and settings.default_ref_text:
- kwargs["ref_text"] = settings.default_ref_text
- if isinstance(engine, FishSpeechEngine) and settings.default_ref_text:
+ if settings.default_ref_text:
kwargs["ref_text"] = settings.default_ref_text
return asyncio.run(engine.synthesize(**kwargs))
@@ -342,35 +324,7 @@
def _create_engine() -> TTSEngine:
- backend = settings.tts_backend
- engine_cls = _BACKEND_MAP.get(backend)
- if engine_cls is None:
- raise RuntimeError(
- f"Unknown TTS backend: {backend}. "
- f"Available backends: {list(_BACKEND_MAP.keys())}"
- )
- if engine_cls is F5TTSEngine:
- engine = engine_cls(
- model=settings.tts_model_name,
- sample_rate=settings.tts_sample_rate,
- )
- elif engine_cls is XTTSv2Engine:
- engine = engine_cls(
- model_name=settings.tts_model_name,
- sample_rate=settings.tts_sample_rate,
- )
- elif engine_cls is FishSpeechEngine:
- engine = engine_cls(
- checkpoint_path=settings.tts_model_path,
- source_root=settings.tts_vocab_path,
- sample_rate=settings.tts_sample_rate,
- device=settings.device,
- compile=settings.fish_compile,
- use_memory_cache=settings.fish_use_memory_cache,
- chunk_length=settings.fish_chunk_length,
- )
- else:
- engine = engine_cls(sample_rate=settings.tts_sample_rate)
+ engine = _create_tts_engine()
if hasattr(engine, "load"):
engine.load()
return engine
diff --git a/src/voice_tts/config.py b/src/voice_tts/config.py
index 00d3788..207970f 100644
--- a/src/voice_tts/config.py
+++ b/src/voice_tts/config.py
@@ -10,11 +10,16 @@
log_level: str = "INFO"
# TTS model configuration
- tts_backend: str = "f5_tts" # or "dummy" for tests
- tts_model_name: str = "F5TTS_v1_Base" # env: TTS_MODEL_NAME
+ tts_backend: str = "fish_speech" # "dummy" / "f5_tts" / "xtts_v2" / "fish_speech"
+ # XTTS-v2 model name (Coqui model manager path); used when backend is xtts_v2.
+ tts_model_name: str = "tts_models/multilingual/multi-dataset/xtts_v2"
+ # Local checkpoint path. For Fish Speech this is the folder containing model.pth,
+ # firefly-gan-vq-fsq-8x1024-21hz-generator.pth, tokenizer.tiktoken, config.json, etc.
tts_model_path: Path | None = None
+ # Source tree path for Fish Speech modules (e.g. models/fish-speech-v1.5.1).
tts_vocab_path: Path | None = None
- tts_sample_rate: int = 24_000
+ tts_sample_rate: int = 44_100
+ tts_speed: float = 1.2 # env: TTS_SPEED
# Reference voices directory
voices_dir: Path = Path("voices")
@@ -23,6 +28,8 @@
min_segment_length: int = 30
max_segment_length: int = 200
max_buffer_wait_ms: int = 500
+ fast_start_initial: int = 12 # first segment threshold for lower latency
+ fast_start_count: int = 3 # how many segments use progressive sizing
# GPU / inference
device: str = "cuda" # or "cpu"
@@ -32,6 +39,20 @@
default_voice_ref: Path | None = None # env: DEFAULT_VOICE_REF
default_ref_text: str | None = None # env: DEFAULT_REF_TEXT
+ # S2-Pro backend settings
+ s2_api_url: str = "http://127.0.0.1:8081"
+
+ # Fish Speech-specific settings
+ fish_compile: bool = False # torch.compile the LLaMA model (slow first run)
+ fish_chunk_length: int = 200 # 100-300; higher = longer coherent chunks
+ fish_use_memory_cache: str = "on" # "on" / "off" reference VQ cache
+ fish_top_p: float = 0.7 # nucleus sampling (0-1); lower = more deterministic
+ fish_temperature: float = 0.7 # sampling temperature; lower = more stable
+ fish_repetition_penalty: float = 1.2 # >1 reduces repeated tokens
+ fish_seed: int | None = None # None = random; set for reproducible output
+ fish_tail_silence_threshold: float = 0.02 # trim trailing silence below this RMS
+ fish_lowpass_cutoff: int = 0 # Hz; low-pass filter output to reduce VQ noise (0 = off)
+
# Warm-up
warmup: bool = False # run a dummy inference at startup
warmup_text: str = "Привет. Это тестовая фраза."
diff --git a/src/voice_tts/tts/__init__.py b/src/voice_tts/tts/__init__.py
new file mode 100644
index 0000000..740e740
--- /dev/null
+++ b/src/voice_tts/tts/__init__.py
@@ -0,0 +1,62 @@
+"""TTS backend registry and factory.
+
+Each backend module self-registers via ``@register(name)``.
+Usage:
+
+ >>> from voice_tts.tts import create_engine
+ >>> engine = create_engine("s2") # by name
+ >>> engine = create_engine() # from settings.TTS_BACKEND
+"""
+
+from pathlib import Path
+
+from voice_tts.config import settings
+from voice_tts.tts.engine import TTSEngine, DummyTTSEngine
+
+_BACKENDS: dict[str, type[TTSEngine]] = {
+ "dummy": DummyTTSEngine,
+}
+
+
+def register(name: str):
+ """Decorator that registers a :class:`TTSEngine` subclass."""
+ def decorator(cls):
+ _BACKENDS[name] = cls
+ return cls
+ return decorator
+
+
+def get_backend(name: str | None = None) -> type[TTSEngine]:
+ """Look up a backend class by name (defaults to ``settings.tts_backend``)."""
+ name = name or settings.tts_backend
+ if name not in _BACKENDS:
+ raise RuntimeError(
+ f"Unknown TTS backend: {name}. "
+ f"Available backends: {list(_BACKENDS.keys())}"
+ )
+ return _BACKENDS[name]
+
+
+def list_backends() -> list[str]:
+ """Return names of all registered backends."""
+ return list(_BACKENDS.keys())
+
+
+def create_engine(name: str | None = None) -> TTSEngine:
+ """Create a TTS engine instance.
+
+ Each backend reads its configuration from ``settings``; no arguments
+ are passed to the constructor besides ``name``.
+ """
+ cls = get_backend(name)
+ return cls()
+
+
+# Lazy-import backends so they self-register via ``@register``.
+# Each submodule gracefully handles missing optional dependencies.
+from voice_tts.tts import ( # noqa: E402, F401
+ s2_backend,
+ f5_backend,
+ xtts_backend,
+ fish_speech_backend,
+)
diff --git a/src/voice_tts/tts/engine.py b/src/voice_tts/tts/engine.py
index 07fad11..4053d4e 100644
--- a/src/voice_tts/tts/engine.py
+++ b/src/voice_tts/tts/engine.py
@@ -20,6 +20,7 @@
language: str,
speed: float,
emotion: str,
+ ref_text: str | None = None,
) -> np.ndarray:
"""Return audio as float32 ndarray normalized to [-1, 1]."""
...
@@ -43,6 +44,7 @@
language: str,
speed: float,
emotion: str,
+ ref_text: str | None = None,
) -> np.ndarray:
duration_sec = max(0.5, len(text) * 0.08) / speed
num_samples = int(self.sample_rate * duration_sec)
@@ -52,6 +54,5 @@
audio *= np.hanning(num_samples)
return audio.astype(np.float32)
-
async def warm_up(self) -> None:
pass
diff --git a/src/voice_tts/tts/f5_backend.py b/src/voice_tts/tts/f5_backend.py
index e9cad82..c835793 100644
--- a/src/voice_tts/tts/f5_backend.py
+++ b/src/voice_tts/tts/f5_backend.py
@@ -4,6 +4,7 @@
from loguru import logger
from voice_tts.config import settings
+from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine
@@ -21,33 +22,34 @@
torchaudio = None
+@_register_backend("f5_tts")
class F5TTSEngine(TTSEngine):
"""F5-TTS backend for local GPU inference with voice cloning."""
def __init__(
self,
- model: str = "F5TTS_v1_Base",
- sample_rate: int = 24_000,
+ model: str | None = None,
+ sample_rate: int | None = None,
device: str | None = None,
- nfe_step: int = 32,
- cfg_strength: float = 2.0,
- sway_sampling_coef: float = -1.0,
- speed: float = 1.0,
- target_rms: float = 0.1,
- cross_fade_duration: float = 0.0,
- remove_silence: bool = False,
+ nfe_step: int | None = None,
+ cfg_strength: float | None = None,
+ sway_sampling_coef: float | None = None,
+ speed: float | None = None,
+ target_rms: float | None = None,
+ cross_fade_duration: float | None = None,
+ remove_silence: bool | None = None,
):
super().__init__()
- self.model_name = model
- self.sample_rate = sample_rate
+ self.model_name = model or settings.tts_model_name
+ self.sample_rate = sample_rate or settings.tts_sample_rate
self.device = device or settings.device
- self.nfe_step = nfe_step
- self.cfg_strength = cfg_strength
- self.sway_sampling_coef = sway_sampling_coef
- self.speed = speed
- self.target_rms = target_rms
- self.cross_fade_duration = cross_fade_duration
- self.remove_silence = remove_silence
+ self.nfe_step = nfe_step or 32
+ self.cfg_strength = cfg_strength or 2.0
+ self.sway_sampling_coef = sway_sampling_coef or -1.0
+ self.speed = speed or 1.0
+ self.target_rms = target_rms or 0.1
+ self.cross_fade_duration = cross_fade_duration or 0.0
+ self.remove_silence = remove_silence or False
self._f5: "F5TTS | None" = None
self._ref_cache: dict[str, tuple[str, str]] = {} # emotion_key -> (processed_audio_path, ref_text)
@@ -66,8 +68,18 @@
self.model_name,
self.device,
)
+ ckpt_file = str(settings.tts_model_path) if settings.tts_model_path else ""
+ vocab_file = str(settings.tts_vocab_path) if settings.tts_vocab_path else ""
+ logger.info(
+ "Loading F5-TTS model {} ckpt={} vocab={} ...",
+ self.model_name,
+ ckpt_file or "(default)",
+ vocab_file or "(default)",
+ )
self._f5 = F5TTS(
model=self.model_name,
+ ckpt_file=ckpt_file,
+ vocab_file=vocab_file,
device=self.device,
)
self.sample_rate = self._f5.target_sample_rate
diff --git a/src/voice_tts/tts/fish_speech_backend.py b/src/voice_tts/tts/fish_speech_backend.py
new file mode 100644
index 0000000..ced922d
--- /dev/null
+++ b/src/voice_tts/tts/fish_speech_backend.py
@@ -0,0 +1,281 @@
+"""Fish Speech 1.5 backend for local GPU inference with zero-shot voice cloning."""
+
+from pathlib import Path
+
+import numpy as np
+from loguru import logger
+
+from voice_tts.config import settings
+from voice_tts.tts import register as _register_backend
+from voice_tts.tts.engine import TTSEngine
+
+try:
+ import sys
+
+ import torch
+ import torchaudio
+
+ FISH_SPEECH_AVAILABLE = True
+except ImportError as exc:
+ logger.warning("torch/torchaudio not available for Fish Speech: {}", exc)
+ FISH_SPEECH_AVAILABLE = False
+ torch = None
+ torchaudio = None
+
+
+_FISH_SAMPLE_RATE = 44_100
+_DEFAULT_SOURCE_ROOT = Path("models/fish-speech")
+
+
+@_register_backend("fish_speech")
+class FishSpeechEngine(TTSEngine):
+ """Fish Speech 1.5 backend supporting English and Russian zero-shot TTS."""
+
+ sample_rate: int = _FISH_SAMPLE_RATE
+
+ def __init__(
+ self,
+ checkpoint_path: Path | str | None = None,
+ source_root: Path | str | None = None,
+ sample_rate: int | None = None,
+ device: str | None = None,
+ precision: torch.dtype | None = None,
+ compile: bool | None = None,
+ use_memory_cache: str | None = None,
+ chunk_length: int | None = None,
+ top_p: float | None = None,
+ temperature: float | None = None,
+ repetition_penalty: float | None = None,
+ seed: int | None = None,
+ tail_silence_threshold: float | None = None,
+ lowpass_cutoff: int | None = None,
+ ):
+ super().__init__()
+
+ if not FISH_SPEECH_AVAILABLE:
+ raise RuntimeError(
+ "Fish Speech backend requires torch/torchaudio and the Fish Speech "
+ "source tree at models/fish-speech"
+ )
+
+ self.sample_rate = sample_rate or settings.tts_sample_rate
+ self.device = device or settings.device
+ self.precision = precision or (
+ torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32
+ )
+ self.compile = compile if compile is not None else settings.fish_compile
+ self.use_memory_cache = use_memory_cache or settings.fish_use_memory_cache
+ self.chunk_length = chunk_length or settings.fish_chunk_length
+ self.top_p = top_p if top_p is not None else settings.fish_top_p
+ self.temperature = temperature if temperature is not None else settings.fish_temperature
+ self.repetition_penalty = repetition_penalty if repetition_penalty is not None else settings.fish_repetition_penalty
+ self.seed = seed if seed is not None else settings.fish_seed
+ self.tail_silence_threshold = tail_silence_threshold if tail_silence_threshold is not None else settings.fish_tail_silence_threshold
+ self.lowpass_cutoff = lowpass_cutoff if lowpass_cutoff is not None else settings.fish_lowpass_cutoff
+ self.repetition_penalty = repetition_penalty
+ self.seed = seed
+ self.tail_silence_threshold = tail_silence_threshold
+ self.lowpass_cutoff = lowpass_cutoff
+
+ self.source_root = Path(
+ source_root or settings.tts_vocab_path or _DEFAULT_SOURCE_ROOT
+ )
+ self.checkpoint_path = Path(
+ checkpoint_path or settings.tts_model_path or "models/fishaudio_fish-speech-1.5"
+ )
+
+ self._llama_queue = None
+ self._decoder = None
+ self._engine = None
+ self._loaded = False
+ # Cache reference audio bytes/text per path to avoid repeated disk reads.
+ self._ref_cache: dict[Path, tuple[bytes, str]] = {}
+
+ def _ensure_source_path(self) -> None:
+ """Make sure the cloned Fish Speech source is on sys.path."""
+ root = str(self.source_root.resolve())
+ if root not in sys.path:
+ sys.path.insert(0, root)
+
+ def _is_supported_language(self, language: str) -> bool:
+ # Fish Speech 1.5 is multilingual; we expose English and Russian.
+ return language.lower() in {"en", "ru"}
+
+ def load(self) -> None:
+ if self._loaded:
+ return
+
+ self._ensure_source_path()
+
+ from fish_speech.inference_engine import TTSInferenceEngine
+ from fish_speech.models.text2semantic.inference import (
+ launch_thread_safe_queue,
+ )
+ from fish_speech.models.vqgan.inference import load_model as load_decoder_model
+
+ logger.info(
+ "Loading Fish Speech 1.5 from {} (source: {}) ...",
+ self.checkpoint_path,
+ self.source_root,
+ )
+
+ llama_checkpoint = self.checkpoint_path / "model.pth"
+ decoder_checkpoint = (
+ self.checkpoint_path
+ / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth"
+ )
+
+ self._llama_queue = launch_thread_safe_queue(
+ checkpoint_path=str(self.checkpoint_path),
+ device=self.device,
+ precision=self.precision,
+ compile=self.compile,
+ )
+
+ self._decoder = load_decoder_model(
+ config_name="firefly_gan_vq",
+ checkpoint_path=str(decoder_checkpoint),
+ device=self.device,
+ )
+
+ self._engine = TTSInferenceEngine(
+ llama_queue=self._llama_queue,
+ decoder_model=self._decoder,
+ precision=self.precision,
+ compile=self.compile,
+ )
+
+ self.sample_rate = self._decoder.spec_transform.sample_rate
+ self._loaded = True
+ logger.info(
+ "Fish Speech 1.5 loaded. Output sample rate: {}", self.sample_rate
+ )
+
+ async def warm_up(self) -> None:
+ if not self._loaded:
+ self.load()
+ logger.info("Fish Speech warm-up skipped; first synthesis will warm the cache.")
+
+ def _normalize_language(self, language: str) -> str:
+ language = language.lower()
+ if language.startswith("ru"):
+ return "ru"
+ if language.startswith("en"):
+ return "en"
+ return language
+
+ async def synthesize(
+ self,
+ text: str,
+ ref_audio_path: Path | None,
+ language: str,
+ speed: float,
+ emotion: str,
+ ref_text: str | None = None,
+ ) -> np.ndarray:
+ if not self._loaded:
+ self.load()
+
+ if isinstance(ref_audio_path, str):
+ ref_audio_path = Path(ref_audio_path)
+
+ if ref_audio_path is None:
+ ref_audio_path = settings.default_voice_ref
+ if ref_audio_path is None:
+ raise ValueError("Fish Speech requires a reference audio file (voice_ref).")
+ if not ref_audio_path.exists():
+ raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")
+
+ lang = self._normalize_language(language)
+ if not self._is_supported_language(lang):
+ raise ValueError(
+ f"Language '{language}' is not supported by Fish Speech backend. "
+ "Currently enabled languages: en, ru."
+ )
+
+ from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest
+
+ # Cache reference audio bytes to avoid repeated disk reads. The transcript
+ # is cheap to resolve, so we re-evaluate fallback precedence every call.
+ cached_audio = self._ref_cache.get(ref_audio_path)
+ if cached_audio is None:
+ ref_audio = ref_audio_path.read_bytes()
+ self._ref_cache[ref_audio_path] = ref_audio
+ else:
+ ref_audio = cached_audio
+
+ # Reference transcript precedence: caller-provided > .lab > settings > placeholder.
+ if ref_text:
+ pass
+ else:
+ ref_text_path = ref_audio_path.with_suffix(".lab")
+ if ref_text_path.exists():
+ ref_text = ref_text_path.read_text(encoding="utf-8").strip()
+ elif settings.default_ref_text:
+ ref_text = settings.default_ref_text
+ else:
+ ref_text = (
+ "Hello, this is a reference voice recording."
+ if lang == "en"
+ else "Здравствуйте. Это тестовая запись голоса."
+ )
+
+ req = ServeTTSRequest(
+ text=text,
+ references=[
+ ServeReferenceAudio(audio=ref_audio, text=ref_text)
+ ],
+ seed=self.seed,
+ top_p=self.top_p,
+ temperature=self.temperature,
+ repetition_penalty=self.repetition_penalty,
+ chunk_length=self.chunk_length,
+ use_memory_cache=self.use_memory_cache,
+ )
+
+ segments = []
+ for result in self._engine.inference(req):
+ logger.debug(
+ "Fish Speech inference result: code={} error={}",
+ result.code,
+ result.error,
+ )
+ if result.code == "error":
+ error = result.error or RuntimeError("Unknown Fish Speech error")
+ raise RuntimeError(f"Fish Speech synthesis failed: {error}")
+ if result.audio is not None:
+ sr, audio = result.audio
+ segments.append(audio.astype(np.float32))
+
+ if not segments:
+ raise RuntimeError("Fish Speech produced no audio.")
+
+ wav = np.concatenate(segments)
+
+ # Normalize to [-1, 1].
+ peak = np.max(np.abs(wav))
+ if peak > 1.0:
+ wav = wav / peak
+
+ # Gentle low-pass to reduce VQ codec noise.
+ if self.lowpass_cutoff > 0:
+ wt = torch.from_numpy(wav).unsqueeze(0) # (1, samples)
+ wt = torchaudio.functional.lowpass_biquad(
+ wt, self.sample_rate, self.lowpass_cutoff
+ )
+ wav = wt.squeeze(0).numpy()
+
+ # Apply speed adjustment via resampling. Fish Speech is already
+ # reasonably fast; the default settings.tts_speed allows global tuning.
+ effective_speed = speed if speed is not None else settings.tts_speed
+ if effective_speed != 1.0 and effective_speed > 0:
+ new_rate = int(self.sample_rate * (1.0 / effective_speed))
+ resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
+ wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)
+
+ if wav.ndim == 0:
+ wav = np.zeros(1, dtype=np.float32)
+ elif wav.ndim > 1:
+ wav = wav.squeeze()
+
+ return wav
diff --git a/src/voice_tts/tts/s2_backend.py b/src/voice_tts/tts/s2_backend.py
new file mode 100644
index 0000000..804a4aa
--- /dev/null
+++ b/src/voice_tts/tts/s2_backend.py
@@ -0,0 +1,140 @@
+"""Fish Audio S2-Pro backend — HTTP client to the local S2 API server."""
+
+import io
+from pathlib import Path
+
+import numpy as np
+import soundfile as sf
+from loguru import logger
+
+from voice_tts.config import settings
+from voice_tts.tts import register as _register_backend
+from voice_tts.tts.engine import TTSEngine
+
+try:
+ import requests
+
+ S2_AVAILABLE = True
+except ImportError:
+ S2_AVAILABLE = False
+ requests = None
+
+_S2_SAMPLE_RATE = 44_100
+
+
+@_register_backend("s2")
+class S2Engine(TTSEngine):
+ """Fish Audio S2-Pro backend that delegates to a local S2 API server.
+
+ The S2 API server must be running separately on ``api_url``.
+ The reference audio is uploaded once and reused via ``reference_id``.
+ """
+
+ sample_rate: int = _S2_SAMPLE_RATE
+
+ def __init__(self):
+ super().__init__()
+ if not S2_AVAILABLE:
+ raise RuntimeError("S2 backend requires the 'requests' package")
+
+ self.api_url = settings.s2_api_url.rstrip("/")
+ self.sample_rate = settings.tts_sample_rate
+ self._ref_uploaded = False
+
+ # ------------------------------------------------------------------
+ # Internal helpers
+ # ------------------------------------------------------------------
+
+ def _upload_reference(self, ref_audio_path: Path) -> None:
+ """Upload the reference audio to the S2 server.
+
+ If a reference with id ``default`` already exists we try to delete it
+ first via HTTP, then fall back to removing its directory from the
+ filesystem (the S2 server stores references in ``references//``
+ relative to its working directory).
+ """
+ ref_text = settings.default_ref_text or ""
+
+ # Try HTTP delete first (works with most setups).
+ deleted = False
+ try:
+ r = requests.delete(
+ f"{self.api_url}/v1/references/delete",
+ json={"reference_id": "default"},
+ timeout=10,
+ )
+ deleted = r.status_code in (200, 404)
+ except Exception:
+ pass
+
+ # If HTTP delete didn't work, try removing the directory directly.
+ if not deleted:
+ candidate = Path("models/fish-speech/references/default")
+ if candidate.exists():
+ import shutil
+ shutil.rmtree(candidate)
+ logger.info("Removed stale reference directory: {}", candidate)
+
+ with open(ref_audio_path, "rb") as fh:
+ resp = requests.post(
+ f"{self.api_url}/v1/references/add",
+ data={"id": "default", "text": ref_text},
+ files={"audio": ("ref.wav", fh, "audio/wav")},
+ timeout=30,
+ )
+ if resp.status_code == 200:
+ logger.info("Reference 'default' uploaded to S2 server")
+ self._ref_uploaded = True
+ else:
+ raise RuntimeError(
+ f"S2 reference upload failed: {resp.status_code} {resp.text}"
+ )
+
+ # ------------------------------------------------------------------
+ # TTSEngine interface
+ # ------------------------------------------------------------------
+
+ async def synthesize(
+ self,
+ text: str,
+ ref_audio_path: Path | None,
+ language: str,
+ speed: float,
+ emotion: str,
+ ref_text: str | None = None,
+ ) -> np.ndarray:
+ if ref_audio_path is not None and not self._ref_uploaded:
+ self._upload_reference(ref_audio_path)
+
+ payload = {
+ "text": text,
+ "reference_id": "default",
+ "format": "wav",
+ "chunk_length": 300,
+ }
+
+ resp = requests.post(
+ f"{self.api_url}/v1/tts",
+ json=payload,
+ timeout=120,
+ )
+ if resp.status_code != 200:
+ raise RuntimeError(f"S2 TTS request failed: {resp.status_code} {resp.text}")
+
+ buf = io.BytesIO(resp.content)
+ audio, sr = sf.read(buf)
+ self.sample_rate = int(sr)
+ return audio.astype(np.float32)
+
+ async def warm_up(self) -> None:
+ """Check that the S2 API server is alive."""
+ try:
+ resp = requests.get(f"{self.api_url}/v1/health", timeout=5)
+ if resp.status_code != 200:
+ raise RuntimeError(f"S2 server health check failed: {resp.status_code}")
+ logger.info("S2 API server is healthy at {}", self.api_url)
+ except requests.ConnectionError as exc:
+ raise RuntimeError(
+ f"S2 API server not reachable at {self.api_url}. "
+ "Make sure the S2 server is running."
+ ) from exc
diff --git a/src/voice_tts/tts/segmenter.py b/src/voice_tts/tts/segmenter.py
index 26315a6..790a1f7 100644
--- a/src/voice_tts/tts/segmenter.py
+++ b/src/voice_tts/tts/segmenter.py
@@ -11,15 +11,24 @@
class Segmenter:
- """Splits streaming text into TTS-ready segments."""
+ """Splits streaming text into TTS-ready segments.
+
+ Supports progressive chunking: the first few segments use a lower
+ ``min_length`` so audio starts sooner (lower initial latency).
+ """
def __init__(
self,
min_length: int = 30,
max_length: int = 200,
+ fast_start_initial: int = 12,
+ fast_start_count: int = 3,
):
self.min_length = min_length
self.max_length = max_length
+ self._fast_start_initial = fast_start_initial
+ self._fast_start_count = fast_start_count
+ self._segments_emitted = 0
# End-of-sentence delimiters
self.sentence_breaks = re.compile(r"[.。!??!\n]+")
@@ -27,6 +36,20 @@
self.clause_breaks = re.compile(r"[,;:\-—()()]")
self.whitespace_re = re.compile(r"\s+")
+ def _effective_min_length(self) -> int:
+ """Minimum length before a segment is considered 'ready'.
+
+ Gradually ramps up from ``fast_start_initial`` to ``min_length``
+ over the first ``fast_start_count`` segments.
+ """
+ if self._segments_emitted < self._fast_start_count:
+ fraction = self._segments_emitted / max(self._fast_start_count - 1, 1)
+ return int(
+ self._fast_start_initial
+ + (self.min_length - self._fast_start_initial) * fraction
+ )
+ return self.min_length
+
def feed(self, buffer: str) -> tuple[str, list[Segment]]:
"""
Consume `buffer` and return (remaining_buffer, ready_segments).
@@ -34,13 +57,18 @@
Logic:
- Prefer cutting at sentence boundaries.
- If a segment grows beyond max_length, cut at the nearest clause boundary.
- - Segments shorter than min_length are returned only if the caller forces flush
- or the input is shorter than min_length but ends with a sentence break.
+ - Segments shorter than the effective min_length are returned only if
+ the caller forces flush or the input is shorter but ends with a
+ sentence break.
+ - Progressive chunking: early segments use a lower threshold so audio
+ starts sooner.
"""
segments: list[Segment] = []
remaining = buffer
while remaining:
+ min_len = self._effective_min_length()
+
# Find the first sentence boundary.
first_sentence_cut = -1
for match in self.sentence_breaks.finditer(remaining):
@@ -52,7 +80,7 @@
segment_text = remaining[:first_sentence_cut].strip()
# Case 1: sentence is long enough -> emit immediately.
- if len(segment_text) >= self.min_length:
+ if len(segment_text) >= min_len:
remaining = remaining[first_sentence_cut:].lstrip()
if segment_text:
segments.append(Segment(text=segment_text, is_final=True))
@@ -72,7 +100,7 @@
break
combined_text = remaining[:next_boundary].strip()
combined_cut = next_boundary
- if len(combined_text) >= self.min_length:
+ if len(combined_text) >= min_len:
remaining = remaining[combined_cut:].lstrip()
if combined_text:
segments.append(Segment(text=combined_text, is_final=True))
@@ -96,7 +124,7 @@
last_clause = -1
for match in self.clause_breaks.finditer(window):
pos = match.end()
- if pos >= self.min_length:
+ if pos >= min_len:
last_clause = pos
if last_clause != -1:
segment_text = remaining[:last_clause].strip()
@@ -108,6 +136,7 @@
# Nothing to cut yet; wait for more text.
break
+ self._segments_emitted += len(segments)
return remaining, segments
def flush(self, buffer: str) -> list[Segment]:
diff --git a/src/voice_tts/tts/utils.py b/src/voice_tts/tts/utils.py
index 099ee8e..5035268 100644
--- a/src/voice_tts/tts/utils.py
+++ b/src/voice_tts/tts/utils.py
@@ -10,20 +10,90 @@
"。", "!", "?", ";", ":",
}
+# Emoji range: all Unicode emoji blocks
+_EMOJI_PATTERN = re.compile(
+ "["
+ "\U0001f600-\U0001f64f" # emoticons
+ "\U0001f300-\U0001f5ff" # symbols & pictographs
+ "\U0001f680-\U0001f6ff" # transport & map symbols
+ "\U0001f1e0-\U0001f1ff" # flags (iOS)
+ "\U00002600-\U000027BF" # misc symbols, dingbats
+ "\U0001f900-\U0001f9ff" # supplemental symbols
+ "\U0001fa00-\U0001fa6f" # chess symbols
+ "\U0001fa70-\U0001faff" # symbols extended-A
+ "\U00002702-\U000027B0" # dingbats
+ "\U000024C2-\U00002500" # enclosed / geometric shapes
+ "\U00002B05-\U00002B55" # arrows
+ "\U0001d300-\U0001d7ff" # musical symbols, etc.
+ "\U0001f000-\U0001f02f" # mahjong tiles
+ "\U0001f030-\U0001f09f" # domino tiles
+ "\U00002100-\U0000214f" # letterlike symbols
+ "\U0001f0a0-\U0001f0ff" # playing cards
+ "\U0001f600-\U0001f64f" # emoticons (duplicate range for safety)
+ "\U0000FE00-\U0000FE0F" # variation selectors
+ "\U0000FE20-\U0000FE23" # combining half marks
+ "\U0000200D" # zero-width joiner
+ "\U0000200C" # zero-width non-joiner
+ "]+"
+)
+
+_MARKDOWN_PATTERN = re.compile(
+ r"```[\s\S]*?```" # fenced code blocks
+ r"|`[^`\n]+`" # inline code
+ r"|\[([^\]]+)\]\([^)]+\)" # markdown links → keep link text
+ r"|!\[([^\]]*)\]\([^)]+\)" # markdown images
+ r"|^#{1,6}\s+" # headings
+ r"|^>+\s+" # blockquotes
+ r"|^\s*[-*+]\s+(?![-*+])" # unordered lists
+ r"|^\s*\d+[.)]\s+" # ordered lists
+ r"|^\s*\|.*\|" # tables
+ r"|^[-=]{3,}\s*$" # horizontal rules
+ r"|(?:^|(?<=\s))\*{1,3}(?=\S)" # leading bold/italic
+ r"|(?<=\S)\*{1,3}(?=\s|$)" # trailing bold/italic
+ r"|(?:^|(?<=\s))_{1,3}(?=\S)" # leading underline emphasis
+ r"|(?<=\S)_{1,3}(?=\s|$)" # trailing underline emphasis
+ r"|~~(.*?)~~" # strikethrough
+ r"|(?:^|(?<=\s))~{1,3}(?=\S)" # leading strikethrough marker
+ r"|(?<=\S)~{1,3}(?=\s|$)" # trailing strikethrough marker
+, re.MULTILINE)
+
+_URL_PATTERN = re.compile(r"https?://[^\s<>\"']+|www\.[^\s<>\"']+")
+
+_HTML_PATTERN = re.compile(r"<[^>]+>")
+
+_SPECIAL_SYMBOLS = re.compile(r"[^\w\s.,!?;:\-—\"'()«»„“”‘’…\n]")
+
def normalize_whitespace(text: str) -> str:
"""Collapse repeated whitespace and strip edges, preserving single spaces."""
return re.sub(r"\s+", " ", text).strip()
+def clean_text_for_tts(text: str) -> str:
+ """Remove characters that TTS should never pronounce."""
+ text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
+ # Remove emojis
+ text = _EMOJI_PATTERN.sub("", text)
+ # Remove HTML tags
+ text = _HTML_PATTERN.sub("", text)
+ # Remove markdown formatting (keep link/image/strikethrough text)
+ text = _MARKDOWN_PATTERN.sub(
+ lambda m: next((g for g in m.groups() if g is not None), ""), text
+ )
+ # Remove URLs
+ text = _URL_PATTERN.sub("", text)
+ # Remove special Unicode symbols not used in normal text
+ text = _SPECIAL_SYMBOLS.sub("", text)
+ return normalize_whitespace(text)
+
+
def preprocess_text_for_tts(text: str) -> str:
"""
- Minimal cleanup before TTS.
+ Clean text before TTS synthesis.
+ - Remove control characters, emojis, HTML, markdown, URLs, special symbols.
- Collapse whitespace.
- - Remove control characters.
"""
- text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text)
- return normalize_whitespace(text)
+ return clean_text_for_tts(text)
def has_sentence_ending(text: str) -> bool:
diff --git a/src/voice_tts/tts/xtts_backend.py b/src/voice_tts/tts/xtts_backend.py
new file mode 100644
index 0000000..150da9d
--- /dev/null
+++ b/src/voice_tts/tts/xtts_backend.py
@@ -0,0 +1,205 @@
+"""XTTS-v2 backend for local GPU inference with zero-shot voice cloning."""
+
+from pathlib import Path
+
+import numpy as np
+from loguru import logger
+
+from voice_tts.config import settings
+from voice_tts.tts import register as _register_backend
+from voice_tts.tts.engine import TTSEngine
+
+try:
+ import torch
+ import torchaudio
+ from TTS.api import TTS
+ import TTS.utils.io as tts_io
+
+ XTTS_AVAILABLE = True
+except ImportError as exc:
+ logger.warning("TTS/torch dependencies not available: {}", exc)
+ XTTS_AVAILABLE = False
+ torch = None
+ torchaudio = None
+
+
+_XTTS_SAMPLE_RATE = 24_000
+
+
+def _patch_weights_only() -> None:
+ """XTTS-v2 checkpoints contain legacy pickle classes; allow full load.
+
+ PyTorch 2.6+ defaults ``torch.load(..., weights_only=True)``. XTTS-v2
+ checkpoints require several Coqui classes to be in the safe-globals list
+ in addition to forcing ``weights_only=False``. This function registers
+ those classes globally and patches TTS's fsspec loader to default to the
+ legacy behavior.
+ """
+ if not XTTS_AVAILABLE:
+ return
+
+ if not hasattr(tts_io, "load_fsspec"):
+ return
+
+ # PyTorch 2.6+ safe-global allow-list for XTTS-v2 checkpoint pickles.
+ if hasattr(torch, "serialization") and hasattr(
+ torch.serialization, "add_safe_globals"
+ ):
+ from TTS.config import shared_configs as _shared_configs
+ from TTS.tts.configs.xtts_config import XttsConfig as _XttsConfig
+ from TTS.tts.models.xtts import XttsArgs as _XttsArgs
+ from TTS.tts.models.xtts import XttsAudioConfig as _XttsAudioConfig
+
+ for _cls in (
+ _shared_configs.BaseDatasetConfig,
+ _XttsConfig,
+ _XttsArgs,
+ _XttsAudioConfig,
+ ):
+ try:
+ torch.serialization.add_safe_globals([_cls])
+ except Exception:
+ pass
+
+ _orig = tts_io.load_fsspec
+
+ def _patched(model_path: str, map_location: str = "cpu", **kwargs):
+ kwargs.setdefault("weights_only", False)
+ return _orig(model_path, map_location=map_location, **kwargs)
+
+ tts_io.load_fsspec = _patched # type: ignore[assignment]
+
+
+@_register_backend("xtts_v2")
+class XTTSv2Engine(TTSEngine):
+ """XTTS-v2 backend supporting English and Russian zero-shot voice cloning."""
+
+ sample_rate: int = _XTTS_SAMPLE_RATE
+
+ def __init__(
+ self,
+ model_name: str | None = None,
+ sample_rate: int | None = None,
+ device: str | None = None,
+ gpu: bool | None = None,
+ ):
+ super().__init__()
+ self.model_name = model_name or settings.tts_model_name
+ self.sample_rate = sample_rate or settings.tts_sample_rate
+ self.device = device or settings.device
+ self.gpu = (gpu if gpu is not None else True) and self.device != "cpu"
+
+ self._model: "TTS | None" = None
+
+ def _is_supported_language(self, language: str) -> bool:
+ # XTTS-v2 supported languages: en, es, fr, de, it, pt, pl, tr, ru, nl, cs,
+ # ar, zh-cn, hu, ko, ja, hi. We currently expose en and ru.
+ return language.lower() in {"en", "ru"}
+
+ def load(self) -> None:
+ if not XTTS_AVAILABLE:
+ raise RuntimeError(
+ "coqui TTS package is not installed. Install it: pip install TTS"
+ )
+
+ logger.info("Loading XTTS-v2 model {} ...", self.model_name)
+ _patch_weights_only()
+
+ # Environment flag required by Coqui to download XTTS.
+ import os
+
+ os.environ.setdefault("COQUI_TOS_AGREED", "1")
+
+ self._model = TTS(self.model_name, gpu=self.gpu)
+ # Force the requested device if not already there.
+ if self.device.startswith("cuda"):
+ self._model = self._model.to("cuda")
+ elif self.device == "cpu":
+ self._model = self._model.to("cpu")
+
+ self.sample_rate = self._model.synthesizer.output_sample_rate or _XTTS_SAMPLE_RATE
+ logger.info("XTTS-v2 loaded. Output sample rate: {}", self.sample_rate)
+
+ async def warm_up(self) -> None:
+ if self._model is None:
+ self.load()
+ logger.info("Warm-up skipped: provide a reference audio before warm-up.")
+
+ def _normalize_language(self, language: str) -> str:
+ language = language.lower()
+ if language.startswith("ru"):
+ return "ru"
+ if language.startswith("en"):
+ return "en"
+ return language
+
+ async def synthesize(
+ self,
+ text: str,
+ ref_audio_path: Path | None,
+ language: str,
+ speed: float,
+ emotion: str,
+ ref_text: str | None = None,
+ ) -> np.ndarray:
+ if self._model is None:
+ self.load()
+
+ assert self._model is not None
+
+ if isinstance(ref_audio_path, str):
+ ref_audio_path = Path(ref_audio_path)
+
+ if ref_audio_path is None:
+ raise ValueError("XTTS-v2 requires a reference audio file (voice_ref).")
+ if not ref_audio_path.exists():
+ raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")
+
+ lang = self._normalize_language(language)
+ if not self._is_supported_language(lang):
+ raise ValueError(
+ f"Language '{language}' is not supported by XTTS-v2. "
+ "Currently enabled languages: en, ru."
+ )
+
+ # Speed is not a direct argument for XTTS.tts_to_file, but we can resample
+ # the audio to approximate it. We keep the produced samples and stretch them.
+ out_path = "/tmp/opencode/xtts_synth_tmp.wav"
+ self._model.tts_to_file(
+ text=text,
+ speaker_wav=str(ref_audio_path),
+ language=lang,
+ file_path=out_path,
+ )
+
+ wav, sr = torchaudio.load(out_path)
+ wav = wav.mean(dim=0).numpy().astype(np.float32)
+
+ # Resample if the model returned a different rate.
+ if sr != self.sample_rate:
+ resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
+ wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)
+
+ # Normalize to [-1, 1].
+ peak = np.max(np.abs(wav))
+ if peak > 1.0:
+ wav = wav / peak
+
+ # Approximate speed change by resampling. Simple resampling changes
+ # pitch slightly; for small adjustments around 1.0 it is acceptable,
+ # but for larger factors we keep the duration change while reducing
+ # pitch shift artifacts via phase vocoder would be too expensive.
+ # The default speed comes from settings.tts_speed and can be overridden
+ # per-request via the WebSocket config/speak messages.
+ effective_speed = speed if speed is not None else settings.tts_speed
+ if effective_speed != 1.0 and effective_speed > 0:
+ new_rate = int(self.sample_rate * (1.0 / effective_speed))
+ resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate)
+ wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32)
+
+ if wav.ndim == 0:
+ wav = np.zeros(1, dtype=np.float32)
+ elif wav.ndim > 1:
+ wav = wav.squeeze()
+
+ return wav
diff --git a/tests/test_fish_speech_backend.py b/tests/test_fish_speech_backend.py
new file mode 100644
index 0000000..ec356e0
--- /dev/null
+++ b/tests/test_fish_speech_backend.py
@@ -0,0 +1,233 @@
+"""Unit tests for the Fish Speech 1.5 backend with mocked inference engine."""
+
+from pathlib import Path
+
+import numpy as np
+import pytest
+import torch
+
+from voice_tts.config import settings
+from voice_tts.tts.fish_speech_backend import FishSpeechEngine
+
+
+@pytest.fixture
+def mocked_fish(monkeypatch, tmp_path):
+ """Return a FishSpeechEngine with all heavy model imports stubbed."""
+
+ fake_root = tmp_path / "fish-speech"
+ fake_root.mkdir()
+ (fake_root / "fish_speech").mkdir()
+ (fake_root / "fish_speech" / "__init__.py").write_text("")
+ (fake_root / "fish_speech" / "inference_engine.py").write_text(
+ "class TTSInferenceEngine:\n pass\n"
+ )
+ (fake_root / "fish_speech" / "inference_engine").mkdir()
+ (fake_root / "fish_speech" / "inference_engine" / "__init__.py").write_text(
+ "class TTSInferenceEngine:\n pass\n"
+ )
+ (fake_root / "fish_speech" / "models").mkdir()
+ (fake_root / "fish_speech" / "models" / "__init__.py").write_text("")
+ (fake_root / "fish_speech" / "models" / "text2semantic").mkdir()
+ (fake_root / "fish_speech" / "models" / "text2semantic" / "__init__.py").write_text("")
+ (fake_root / "fish_speech" / "models" / "text2semantic" / "inference.py").write_text(
+ "def launch_thread_safe_queue(*args, **kwargs):\n"
+ " return object()\n"
+ )
+ vqgan = fake_root / "fish_speech" / "models" / "vqgan"
+ vqgan.mkdir()
+ (vqgan / "__init__.py").write_text("")
+ (vqgan / "inference.py").write_text(
+ "def load_model(*args, **kwargs):\n"
+ " class FakeDecoder:\n"
+ " spec_transform = type('T', (), {'sample_rate': 44100})()\n"
+ " def encode(self, audios, lengths):\n"
+ " return [None]\n"
+ " return FakeDecoder()\n"
+ )
+ (fake_root / "fish_speech" / "utils").mkdir()
+ (fake_root / "fish_speech" / "utils" / "__init__.py").write_text("")
+ (fake_root / "fish_speech" / "utils" / "schema.py").write_text(
+ "from pydantic import BaseModel\n"
+ "from typing import List\n"
+ "class ServeReferenceAudio(BaseModel):\n"
+ " audio: bytes\n"
+ " text: str\n"
+"class ServeTTSRequest(BaseModel):\n"
+" text: str\n"
+" references: List[ServeReferenceAudio] = []\n"
+" seed: int | None = None\n"
+" top_p: float = 0.7\n"
+" temperature: float = 0.7\n"
+" repetition_penalty: float = 1.2\n"
+" max_new_tokens: int = 1024\n"
+" chunk_length: int = 200\n"
+" use_memory_cache: str = 'on'\n"
+ )
+
+ # Stub the heavy model functions used inside load().
+ class FakeEngine:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def inference(self, req):
+ sample_rate = 44_100
+ audio = np.linspace(-0.9, 0.9, sample_rate, dtype=np.float32)
+ yield type(
+ "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)}
+ )()
+
+ # Point sys.path at fake root temporarily so imports succeed.
+ monkeypatch.syspath_prepend(str(fake_root))
+
+ from fish_speech import inference_engine as _ie
+
+ monkeypatch.setattr(_ie, "TTSInferenceEngine", FakeEngine)
+
+ ref_wav = tmp_path / "ref.wav"
+ ref_wav.write_bytes(b"RIFF" + b"\x00" * 40)
+ ref_lab = tmp_path / "ref.lab"
+ ref_lab.write_text("Reference transcript from lab file.", encoding="utf-8")
+
+ engine = FishSpeechEngine(
+ checkpoint_path=tmp_path / "checkpoint",
+ source_root=fake_root,
+ device="cpu",
+ precision=torch.float32,
+ chunk_length=200,
+ )
+ engine.load()
+ return engine, ref_wav, ref_lab
+
+
+@pytest.mark.asyncio
+async def test_synthesize_returns_audio(mocked_fish):
+ engine, ref_wav, _ = mocked_fish
+ audio = await engine.synthesize(
+ text="Hello world.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+ assert audio.ndim == 1
+ assert len(audio) == engine.sample_rate
+ assert audio.dtype == np.float32
+ assert np.max(np.abs(audio)) <= 1.1
+
+
+@pytest.mark.asyncio
+async def test_unsupported_language_raises(mocked_fish):
+ engine, ref_wav, _ = mocked_fish
+ with pytest.raises(ValueError, match="not supported"):
+ await engine.synthesize(
+ text="Bonjour.",
+ ref_audio_path=ref_wav,
+ language="fr",
+ speed=1.0,
+ emotion="neutral",
+ )
+
+
+@pytest.mark.asyncio
+async def test_missing_reference_raises(mocked_fish):
+ engine, _, _ = mocked_fish
+ with pytest.raises(FileNotFoundError):
+ await engine.synthesize(
+ text="Hello.",
+ ref_audio_path=Path("/nonexistent/ref.wav"),
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+
+
+@pytest.mark.asyncio
+async def test_ref_text_fallback_order(monkeypatch, mocked_fish):
+ engine, ref_wav, ref_lab = mocked_fish
+
+ captured = {}
+
+ def _spy_inference(req):
+ captured["text"] = req.references[0].text
+ sample_rate = 44_100
+ audio = np.zeros(sample_rate, dtype=np.float32)
+ yield type(
+ "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)}
+ )()
+
+ # Replace the already-created engine's inference method with a spy.
+ engine._engine.inference = _spy_inference
+
+ # 1. Explicit ref_text wins.
+ await engine.synthesize(
+ text="Hi.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ ref_text="explicit",
+ )
+ assert captured["text"] == "explicit"
+
+ # 2. .lab file next to reference is used when no explicit text.
+ await engine.synthesize(
+ text="Hi.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+ assert captured["text"] == "Reference transcript from lab file."
+
+ # 3. settings.default_ref_text wins over placeholder.
+ monkeypatch.setattr(settings, "default_ref_text", "settings default")
+ ref_lab.unlink(missing_ok=True)
+ await engine.synthesize(
+ text="Hi.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+ assert captured["text"] == "settings default"
+
+ # 4. Placeholder when nothing else is available.
+ monkeypatch.setattr(settings, "default_ref_text", None)
+ await engine.synthesize(
+ text="Hi.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+ assert captured["text"] == "Hello, this is a reference voice recording."
+
+
+@pytest.mark.asyncio
+async def test_speed_resampling_changes_length(mocked_fish):
+ engine, ref_wav, _ = mocked_fish
+ slow = await engine.synthesize(
+ text="Hello.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.5,
+ emotion="neutral",
+ )
+ fast = await engine.synthesize(
+ text="Hello.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=0.8,
+ emotion="neutral",
+ )
+ assert len(fast) > len(slow)
+ # Speed 1.5 should produce fewer samples than original; 0.8 should produce more.
+ normal = await engine.synthesize(
+ text="Hello.",
+ ref_audio_path=ref_wav,
+ language="en",
+ speed=1.0,
+ emotion="neutral",
+ )
+ assert len(slow) < len(normal)
+ assert len(fast) > len(normal)