diff --git a/.env.example b/.env.example index b875020..d28298b 100644 --- a/.env.example +++ b/.env.example @@ -5,27 +5,47 @@ PORT=8765 LOG_LEVEL=INFO -TTS_BACKEND=f5_tts -# F5-TTS model name. Built-in options: F5TTS_v1_Base, F5TTS_v1_Small. -# Downloaded automatically on first use via HuggingFace. -TTS_MODEL_NAME=F5TTS_v1_Base -# TTS_MODEL_PATH=models/f5-tts/model.pt -# TTS_VOCAB_PATH=models/f5-tts/vocab.txt -TTS_SAMPLE_RATE=24000 +# Available backends: +# dummy — sine-wave test (no GPU, no deps) +# s2 — Fish Audio S2-Pro INT4 (требует S2 API сервер на порту 8081) +# fish_speech — Fish Speech 1.5 (требует чекпоинт fish-speech) +# f5_tts — F5-TTS v1 (экспериментальный) +# xtts_v2 — XTTS-v2 от Coqui (авто-загрузка, ~2 GB VRAM) +TTS_BACKEND=s2 + +# Paths for Fish Speech / XTTS backends +TTS_MODEL_PATH=models/fishaudio_fish-speech-1.5 +TTS_VOCAB_PATH=models/fish-speech +TTS_MODEL_NAME=tts_models/multilingual/multi-dataset/xtts_v2 +TTS_SAMPLE_RATE=44100 +TTS_SPEED=1.5 + +# Fish Speech tuning +FISH_COMPILE=false +FISH_CHUNK_LENGTH=300 +FISH_USE_MEMORY_CACHE=on +FISH_TOP_P=0.7 +FISH_TEMPERATURE=0.7 +FISH_REPETITION_PENALTY=1.2 +FISH_TAIL_SILENCE_THRESHOLD=0 +FISH_LOWPASS_CUTOFF=0 VOICES_DIR=voices -# Path to default reference audio (relative to project root or absolute). -# Providing DEFAULT_VOICE_REF enables instant warm-up and voice cloning. -DEFAULT_VOICE_REF=voices/rick_ref_clean.wav -# Exact transcript of the reference audio. When set, Whisper transcription is skipped. -DEFAULT_REF_TEXT="Ва-ба-ла-ба-дап-дап! Рикки-тики-тави, сученька! И вот такие у нас новости! Иди." +DEFAULT_VOICE_REF=voices/default_ref.wav +DEFAULT_REF_TEXT="" +# Segmentation (прогрессивный порог: первый сегмент короче) MIN_SEGMENT_LENGTH=30 -MAX_SEGMENT_LENGTH=200 +MAX_SEGMENT_LENGTH=500 MAX_BUFFER_WAIT_MS=500 +FAST_START_INITIAL=12 +FAST_START_COUNT=3 DEVICE=cuda DTYPE=bfloat16 -# Run a dummy inference at startup to cache reference and prime CUDA. WARMUP=true +WARMUP_TEXT="Привет. Это тестовая фраза." + +# S2 API server URL +S2_API_URL=http://127.0.0.1:8081 diff --git a/AGENTS.md b/AGENTS.md index c662d52..d16ed6b 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -7,6 +7,9 @@ - **Run server (Fish Speech default):** `python -m voice_tts.main` - **Dummy backend for fast local tests:** `TTS_BACKEND=dummy python -m voice_tts.main` - **XTTS-v2 backend:** `TTS_BACKEND=xtts_v2 python -m voice_tts.main` +- **Fish Audio S2 backend:** `TTS_BACKEND=s2 python -m voice_tts.main` (требует S2 API сервер на порту 8081) +- **S2 API server (INT4, без compile):** `cd models/fish-speech && .venv/bin/python tools/api_server.py --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 --decoder-checkpoint-path checkpoints/s2-pro/codec.pth --listen 127.0.0.1:8081` +- **S2 API server (INT4, compile ~1× real-time):** `cd models/fish-speech && .venv/bin/python tools/api_server.py --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 --decoder-checkpoint-path checkpoints/s2-pro/codec.pth --listen 127.0.0.1:8081 --compile` — первый ~3.5 мин, потом ~51 tok/s, **~9.6 GB VRAM** (вмещается в 12 GB) - **Console script (installed):** `voice-tts` - **Health check:** `curl http://localhost:8765/health` - **Browser test client:** `cd examples && python -m http.server 8080` → открыть `http://localhost:8080/client_browser.html` @@ -55,9 +58,17 @@ | `TTS_MODEL_PATH` | — | Fish Speech checkpoint folder (contains model.pth, firefly-gan-vq-fsq-8x1024-21hz-generator.pth, tokenizer.tiktoken, config.json) | | `TTS_VOCAB_PATH` | — | Fish Speech v1.5 source tree path (used to import firefly_gan / FSQ modules) | | `TTS_MODEL_NAME` | `tts_models/multilingual/multi-dataset/xtts_v2` | Coqui model manager path; xtts_v2 downloads this on first use | -| `FISH_COMPILE` | `false` | Avoid setting to true. Enables torch.compile but causes CUDAGraphs tensor-overwrite errors on repeated inference. | | `FISH_CHUNK_LENGTH` | 200 | Chunk length for Fish Speech (100–300). Higher = more GPU work per call, higher latency. | +### S2-Pro INT4 + +- Чекпоинт: `models/fish-speech/checkpoints/fs-1.2-int4-g128-s2-pro-nf4/` (4.9 ГБ model.pth, INT4 quantized, groupsize=128) +- `config.json:` в чекпоинте имеет `text_config.max_seq_len=4096` — KV cache ~0.67 GB (было 5.36 GB при 32768) +- **VRAM:** ~4.3 GB модель, ~0.67 GB KV cache, ~4.5 GB codec decoder + compile cache + активации = **~9.6 GB пик** (под 12 GB лимитом) +- **INT4 авто-определение:** `from_pretrained` детектит `"int4"` в пути к чекпоинту, загружает `model.pth` с `WeightOnlyInt4Linear` модулями +- Чекпоинт получен через `tools/llama/quantize.py --checkpoint-path .../s2-pro --mode int4 --groupsize 128` +- PyTorch 2.4.0 обязателен — `_convert_weight_to_int4pack` / `_weight_int4pack_mm` использует kKTileSize=16 + ## WebSocket Protocol (`/ws`) - Server at `ws://localhost:8765/ws` diff --git a/README.md b/README.md index e53e6c4..a1b413d 100644 --- a/README.md +++ b/README.md @@ -1,57 +1,118 @@ # Voice TTS -Local GPU-powered real-time text-to-speech pipeline with a WebSocket API, designed to voice AI agents that stream text in chunks. +Локальный GPU-пайплайн синтеза речи в реальном времени с WebSocket API. +Разработан для озвучки ИИ-агентов: текст поступает частями от LLM, аудио отдаётся +клиенту по мере готовности. -## Features +## Основное -- **Streaming input**: accepts partial text as it is generated by the LLM/agent. -- **Streaming output**: returns PCM audio chunks over WebSocket as soon as they are synthesized. -- **Voice cloning**: single speaker cloned from reference audio, with optional per-emotion references. -- **Interrupt / stop**: agent can immediately stop playback when the user interrupts the AI. -- **Emotion control**: switch emotion on the fly (requires matching reference audio or supported backend). -- **Local GPU**: runs entirely on your NVIDIA GPU (RTX 3090 / 3060 compatible). +- **Стриминг текста**: фразы от LLM приходят частями, сервер не ждёт полного текста +- **Стриминг аудио**: PCM-чанки отправляются сразу после синтеза +- **Клонирование голоса**: один спикер по референсному аудио +- **Прерывание**: немедленная остановка по сигналу `stop` +- **Выбор бэкенда**: переключение между моделями через `TTS_BACKEND` без изменения кода -## Project status +## Бэкенды -- Working WebSocket server with streaming text, audio streaming, and instant stop/resume. -- F5-TTS backend installed, GPU-ready, and producing real audio (`models/F5TTS_v1_Base/` downloaded). -- Dummy backend available for fast offline tests. -- Startup warm-up caches the default reference and primes CUDA. -- Next: multilingual evaluation, latency optimization, and client examples. +| Имя | Модель | Требования | VRAM | RTF | +|-----|--------|-----------|------|-----| +| `s2` | Fish Audio S2-Pro INT4 | S2 API сервер (порт 8081), PyTorch 2.4 | ~9.6 GB | **0.59** | +| `fish_speech` | Fish Speech 1.5 | чекпоинт fish-speech, PyTorch | ~12 GB | ~1.4 | +| `xtts_v2` | XTTS-v2 (Coqui) | авто-загрузка | ~2 GB | ~0.34 | +| `f5_tts` | F5-TTS v1 | чекпоинт F5TTS | ~4 GB | ~1.0 | +| `dummy` | синусоида | ничего | — | — | -## Quick start +## Быстрый старт (S2-Pro INT4) ```bash -# Create virtual environment (Python 3.10-3.12 recommended) -python3.11 -m venv .venv -source .venv/bin/activate +# 1. S2 API сервер (требует ~9.6 GB VRAM, 50 tok/s с --compile) +cd models/fish-speech +PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python tools/api_server.py \ + --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \ + --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \ + --listen 127.0.0.1:8081 --compile -# Install PyTorch with CUDA 12.6 support first -pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126 - -# Install remaining dependencies -pip install -r requirements.txt - -# (Optional) Download the F5-TTS model beforehand -python scripts/download_f5_tts.py --model F5TTS_v1_Base - -# Run the server -python -m voice_tts.main - -# Or run in dummy test mode -TTS_BACKEND=dummy python -m voice_tts.main +# 2. WebSocket прокси (в другом терминале) +TTS_BACKEND=s2 python -m voice_tts.main ``` -Server will listen on `ws://localhost:8765/ws`. +Сервер слушает `ws://localhost:8765/ws`. -## WebSocket protocol +## WebSocket протокол -See full documentation in [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md). +Полная документация — [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md). -## Architecture and roadmap +### Кратко -- [`docs/01_overview.md`](docs/01_overview.md) -- [`docs/02_architecture.md`](docs/02_architecture.md) -- [`docs/04_roadmap.md`](docs/04_roadmap.md) -- [`docs/05_usage.md`](docs/05_usage.md) -- [`docs/06_technical_notes.md`](docs/06_technical_notes.md) +Клиент → Сервер: +- `init` — настройки сессии (голос, язык, скорость) +- `text` — чанк текста (`payload` — строка) +- `flush` — озвучить остаток буфера +- `stop` — немедленно прервать +- `emotion` / `config` — смена параметров + +Сервер → Клиент: +- `status` — события (`session_ready`, `segment_started`, `segment_finished`, `stopped`) +- `audio` — PCM16 base64 с `sample_rate` +- `error` — ошибка + +## Архитектура + +``` +[LLM / AI Agent] → WebSocket → [SessionManager] + │ + TextBuffer / Segmenter + │ + segments (async tasks) + │ + TTS Engine (через реестр) + │ + AudioQueue → WS send +``` + +Бэкенды регистрируются через декоратор `@register("name")` в `voice_tts.tts`. +Новый бэкенд добавляется одним файлом с декоратором — никаких правок в `server.py`. + +```python +from voice_tts.tts import register as _register_backend + +@_register_backend("my_model") +class MyEngine(TTSEngine): + ... +``` + +## Установка + +```bash +python3.11 -m venv .venv +source .venv/bin/activate +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126 +pip install -r requirements.txt +``` + +## Конфигурация + +Все переменные в `.env` (см. `.env.example`). Ключевые: + +| Переменная | По умолч. | Описание | +|-----------|-----------|----------| +| `TTS_BACKEND` | `s2` | `dummy` / `s2` / `fish_speech` / `f5_tts` / `xtts_v2` | +| `DEFAULT_VOICE_REF` | — | Референсное аудио для клонирования | +| `DEFAULT_REF_TEXT` | — | Точный транскрипт референса | +| `MIN_SEGMENT_LENGTH` | `30` | Мин. длина сегмента (прогрессивно: 12 → 30) | +| `MAX_SEGMENT_LENGTH` | `200` | Макс. длина сегмента | +| `FAST_START_INITIAL` | `12` | Порог для первого сегмента (снижает задержку) | +| `FAST_START_COUNT` | `3` | Сегментов с прогрессивным порогом | + +## Тесты + +```bash +pytest tests/ -v +``` + +## Документация + +- [`docs/01_overview.md`](docs/01_overview.md) — обзор проекта +- [`docs/02_architecture.md`](docs/02_architecture.md) — архитектура +- [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md) — протокол +- [`docs/05_usage.md`](docs/05_usage.md) — использование и развёртывание diff --git a/docs/05_usage.md b/docs/05_usage.md index 41105ee..c700c25 100644 --- a/docs/05_usage.md +++ b/docs/05_usage.md @@ -2,189 +2,140 @@ ## Установка -> Рекомендуется Python 3.11. Python 3.14+ пока не имеет совместимых wheel для -> `torch`, поэтому используйте 3.10–3.12. - ```bash -# Клонировать / перейти в директорию проекта -cd voice - -# Создать виртуальное окружение python3.11 -m venv .venv -source .venv/bin/activate # Windows: .venv\Scripts\activate - -# Установить PyTorch с CUDA 12.6 (обязательно первым) +source .venv/bin/activate pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126 - -# Установить остальные зависимости pip install -r requirements.txt - -# (Опционально) скопировать настройки cp .env.example .env +# отредактировать .env под свой бэкенд ``` -## Запуск сервера +## Запуск: S2-Pro INT4 (рекомендуемый) + +Требуется **PyTorch 2.4.0+cu121** (установлен). S2-сервер и WebSocket-прокси +запускаются в разных терминалах: ```bash -# Основной режим: Fish Speech 1.5 на GPU (по умолчанию) -python -m voice_tts.main +# Терминал 1: S2 API сервер +cd models/fish-speech +PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python tools/api_server.py \ + --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \ + --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \ + --listen 127.0.0.1:8081 --compile +# Первый запуск ~3.5 мин (компиляция), потом ~50 tok/s, ~9.6 GB VRAM -# С настроенным референсом и warm-up (рекомендуется) +# Терминал 2: WebSocket прокси +cd /home/gmikcon/Projects/voice +TTS_BACKEND=s2 python -m voice_tts.main +``` + +### S2 без компиляции (меньше VRAM, медленнее) + +```bash +# Без --compile: ~6 tok/s +cd models/fish-speech && PYTHONPATH=. /home/gmikcon/Projects/voice/.venv/bin/python \ + tools/api_server.py \ + --llama-checkpoint-path checkpoints/fs-1.2-int4-g128-s2-pro-nf4 \ + --decoder-checkpoint-path checkpoints/s2-pro/codec.pth \ + --listen 127.0.0.1:8081 +``` + +### Обновление чекпоинта S2 + +Если выходит новая версия S2-Pro: +1. Скачать новый чекпоинт с HuggingFace +2. Запустить квантизацию: `python tools/llama/quantize.py --checkpoint-path ... --mode int4 --groupsize 128` +3. Обновить `--llama-checkpoint-path` на новый квантизованный чекпоинт + +## Запуск: Fish Speech 1.5 + +```bash TTS_BACKEND=fish_speech \ -DEFAULT_VOICE_REF=voices/self_ref_clean.wav \ -DEFAULT_REF_TEXT="Добрый вечер, меня зовут Евгений." \ +DEFAULT_VOICE_REF=voices/default_ref.wav \ +DEFAULT_REF_TEXT="текст референса" \ WARMUP=true \ python -m voice_tts.main +``` -# Быстрый вариант XTTS-v2 +## Запуск: XTTS-v2 + +```bash TTS_BACKEND=xtts_v2 \ -DEFAULT_VOICE_REF=voices/self_ref_clean.wav \ +DEFAULT_VOICE_REF=voices/default_ref.wav \ python -m voice_tts.main +``` -# Тестовый режим без модели +Модель скачивается автоматически при первом запуске. + +## Запуск: тестовый dummy + +```bash TTS_BACKEND=dummy python -m voice_tts.main ``` -Сервер поднимется на `ws://localhost:8765/ws`. - -## Проверка работоспособности +## Проверка ```bash curl http://localhost:8765/health -# {"status":"ok","backend":"fish_speech"} +# {"status":"ok","backend":"s2"} ``` ## Клиенты -В директории `examples/` лежат готовые клиенты: - -- `examples/client_python.py` — Python-клиент с воспроизведением через `sounddevice`. -- `examples/client_browser.html` — HTML/JS клиент для браузера с `AudioContext`. - -### Python-клиент +В `examples/`: +- `client_python.py` — Python-клиент с sounddevice +- `client_browser.html` — браузерный клиент ```bash pip install websockets sounddevice python examples/client_python.py --uri ws://localhost:8765/ws "Привет, мир!" ``` -Опции: - -```bash -python examples/client_python.py \ - --uri ws://localhost:8765/ws \ - --voice-ref voices/self_ref_clean.wav \ - --language ru \ - --speed 1.0 \ - "Это тестовая фраза для проверки." -``` - -Клиент: -1. Отправляет `init` с настройками. -2. Разбивает текст на слова и шлёт их как потоковые `text`. -3. Отправляет `flush`. -4. Получает `audio`-чанки, декодирует base64 PCM16 и складывает в аудиобуфер. -5. `sounddevice` воспроизводит аудио в реальном времени из callback. -6. По `Ctrl+C` отправляет `stop` и выходит. - -### Браузерный клиент - -```bash -# 1. Терминал 1: запустить TTS сервер (dummy — без GPU, любой бэкенд по выбору) -TTS_BACKEND=dummy python -m voice_tts.main - -# 2. Терминал 2: открыть клиент через HTTP (file:// не даёт WebSocket в некоторых браузерах) -python -m http.server 8080 --directory examples/ -# Открыть http://localhost:8080/client_browser.html -``` - -Нажмите **Connect**, затем **Speak streaming**. Клиент: - -1. Шлёт `init` с настройками (язык, скорость, эмоция). -2. Шлёт слова по одному как потоковые `text` с задержкой 120 мс. -3. Завершает `flush`. -4. Получает `audio`-чанки с динамическим `sample_rate` (поддерживается любой бэкенд). -5. Декодирует PCM16 из base64 и ставит в очередь `AudioBuffer` для бесшовного воспроизведения. - -Кнопка **Stop** отправляет `stop`. Кнопка **Test audio** проверяет звук в браузере независимо от сервера. - -## Настройка через переменные окружения (.env) - -| Переменная | Описание | По умолчанию | -|------------|----------|--------------| -| `HOST` | Хост сервера | `0.0.0.0` | -| `PORT` | Порт сервера | `8765` | -| `LOG_LEVEL` | Уровень логирования | `INFO` | -| `TTS_BACKEND` | Бэкенд (`dummy` / `fish_speech` / `xtts_v2`) | `fish_speech` | -| `TTS_MODEL_PATH` | Папка с checkpoint Fish Speech / XTTS | — | -| `TTS_VOCAB_PATH` | Исходники Fish Speech v1.5.1 | `models/fish-speech-v1.5.1` | -| `TTS_SAMPLE_RATE` | Частота дискретизации | `44100` | -| `TTS_SPEED` | Множитель скорости речи | `1.2` | -| `VOICES_DIR` | Директория с референсами | `voices` | -| `DEFAULT_VOICE_REF` | Референс по умолчанию | — | -| `DEFAULT_REF_TEXT` | Точный текст референса (skip Whisper) | — | -| `MIN_SEGMENT_LENGTH` | Мин. длина сегмента | `30` | -| `MAX_SEGMENT_LENGTH` | Макс. длина сегмента | `200` | -| `MAX_BUFFER_WAIT_MS` | Макс. ожидание перед flush | `500` | -| `DEVICE` | `cuda` или `cpu` | `cuda` | -| `DTYPE` | `bfloat16` / `float16` / `float32` | `bfloat16` | -| `FISH_COMPILE` | `torch.compile` для Fish Speech | `false` | -| `FISH_CHUNK_LENGTH` | Длина LLM-чанка Fish Speech | `200` | -| `FISH_USE_MEMORY_CACHE` | Кэшировать VQ референса | `on` | -| `WARMUP` | Прогреть CUDA и кэшировать референс | `false` | -| `WARMUP_TEXT` | Текст для warm-up | `Привет. Это тестовая фраза.` | - -## Модели - -### Fish Speech 1.5 - -По умолчанию используется локальный checkpoint `models/fishaudio_fish-speech-1.5/`: - -- `model.pth` — LLaMA языковая модель, -- `firefly-gan-vq-fsq-8x1024-21hz-generator.pth` — VQ-GAN декодер, -- `tokenizer.tiktoken`, `config.json`, `special_tokens.json`. - -Исходный код Fish Speech v1.5.1 должен лежать в `models/fish-speech-v1.5.1/`, -чтобы Python мог импортировать нужные модули. - -### XTTS-v2 - -Coqui-модель `tts_models/multilingual/multi-dataset/xtts_v2` скачивается -автоматически при первом запуске `TTS_BACKEND=xtts_v2`. Можно указать -локальный путь через `TTS_MODEL_PATH`. - ## Референсные аудио -Поместите файлы в директорию `voices/`: - +Поместите в `voices/`: ``` voices/ -├── self_ref_clean.wav -├── self_ref_clean.lab -├── default_neutral.wav -├── default_happy.wav -├── default_sad.wav +├── product_voice_sample_clean.wav +├── product_voice_sample_clean.lab (точный транскрипт для Fish Speech) └── ... ``` -Требования к референсу: -- WAV или другой формат, читаемый `torchaudio`. -- Моно, 16+ кГц. -- Длина 5–15 секунд для Fish Speech, 3–10 секунд для XTTS-v2. -- Чистая речь одного спикера без фонового шума. -- Для Fish Speech рядом с `.wav` можно положить `.lab` с точным транскриптом. - Иначе сервер использует `DEFAULT_REF_TEXT` или Whisper-транскрипцию. +Требования: +- WAV, моно, 16+ кГц +- 5–15 секунд чистой речи одного спикера +- Для Fish Speech рядом с `.wav` можно положить `.lab` с транскриптом ## Тесты ```bash -# Быстрые тесты без загрузки тяжёлых моделей -python -m pytest tests/ -v +pytest tests/ -v ``` -Для запуска в CI рекомендуется отдельно прогонять быстрые и тяжёлые тесты: +## Добавление нового бэкенда -```bash -python -m pytest tests/test_segmenter.py tests/test_server.py -v -python -m pytest tests/test_fish_speech_backend.py -v +1. Создать файл `src/voice_tts/tts/my_model.py` +2. Унаследоваться от `TTSEngine`, добавить декоратор: + +```python +from voice_tts.tts import register as _register_backend +from voice_tts.tts.engine import TTSEngine + +@_register_backend("my_model") +class MyEngine(TTSEngine): + def __init__(self): + # читать настройки из voice_tts.config.settings + + async def synthesize(self, text, ref_audio_path, language, speed, emotion, ref_text=None): + ... + + async def warm_up(self): + ... ``` + +3. Готово — бэкенд доступен через `TTS_BACKEND=my_model` + +## Переменные окружения + +Полный список — в `.env.example`. diff --git a/docs/06_technical_notes.md b/docs/06_technical_notes.md index 20ab2ef..9e0ced8 100644 --- a/docs/06_technical_notes.md +++ b/docs/06_technical_notes.md @@ -35,29 +35,58 @@ Для русского языка предложения обычно короче, чем на английском, поэтому `max_length` выбран консервативно. -## F5-TTS: особенности интеграции +## Fish Speech 1.5 -- Используется готовая модель `F5TTS_v1_Base` через официальный пакет `f5-tts`. -- Модель автоматически загружается из Hugging Face при первом вызове `load()`; - можно скачать заранее скриптом `scripts/download_f5_tts.py`. -- Референсное аудио транскрибируется автоматически, если `ref_text` не задан. -- Результат транскрипции и путь к обработанному аудио кэшируются по паре - `(путь, эмоция)`, чтобы не повторять работу при смене сегментов. -- Скорость регулируется параметром `speed` инференса F5-TTS. -- Текущая целевая частота дискретизации — **24 kHz**; модель сама - передискретирует референс при необходимости. -- Инференс выполняется в отдельном потоке (`asyncio.to_thread`) с - `asyncio.Lock`, чтобы не блокировать event loop и не допускать - параллельных CUDA-вызовов. Все `send_text` сериализуются через - отдельный `_send_lock`, чтобы избежать deadlock внутри `websockets`. +### Особенности интеграции + +- Локальный checkpoint в `models/fishaudio_fish-speech-1.5/`: + `model.pth`, `firefly-gan-vq-fsq-8x1024-21hz-generator.pth`, + `tokenizer.tiktoken`, `config.json`, `special_tokens.json`. +- Исходный код `models/fish-speech-v1.5.1/` добавляется в `sys.path` для импорта модулей. +- Выходная частота дискретизации — **44,1 кГц**. +- Точный транскрипт референса важен: используется `.lab` рядом с референсом, + затем `DEFAULT_REF_TEXT`, затем Whisper-транскрипция, затем placeholder. +- Скорость регулируется resampling'ом после синтеза (`TTS_SPEED`). +- Параллельные CUDA-вызовы сериализуются через `asyncio.Lock` и + `asyncio.to_thread`. + +### Тюнинг + +- `FISH_USE_MEMORY_CACHE=on` — кэшировать VQ-представление референса (включено). +- `FISH_CHUNK_LENGTH` — длина LLM-чанка (100–300, по умолчанию 200). + Больше = длиннее связные куски, но выше задержка. +- `FISH_COMPILE=true` — пытается включить `torch.compile`. **Не включать по умолчанию:** + при повторном инференсе возникает ошибка `accessing tensor output of CUDAGraphs that has been overwritten`. + Исследуется отдельно. ### Замеры задержки (RTX 3090, Python 3.11, CUDA 12.6) -- Первый запуск без `DEFAULT_REF_TEXT`: ~5–6 с, большая часть уходит на Whisper-транскрипцию референса. -- С `DEFAULT_REF_TEXT` и `WARMUP=true`: warm-up занимает ~2 с (загрузка модели + один инференс). -- После warm-up с кэшированным референсом: первый audio-chunk ~1.1 с на коротком сегменте. -- 4 сегмента подряд: первый finished ~1.1 с, последний ~4.4 с. -- `stop` + возобновление работает без переподключения WebSocket. +- Первый запуск без `DEFAULT_REF_TEXT`: ~5–6 с, большая часть уходит на Whisper-транскрипцию. +- С `DEFAULT_REF_TEXT` и `WARMUP=true`: загрузка модели + один инференс. +- RTF (real-time factor) Fish Speech ~1.4 на коротких сегментах: медленнее реального времени. +- RTF XTTS-v2 ~0.34: быстрее реального времени. +- Fish Speech даёт более естественную русскую интонацию, поэтому выбран по умолчанию. +- XTTS-v2 — резервный быстрый бэкенд для сценариев, где задержка важнее качества. + +### Сравнение с XTTS-v2 + +| Показатель | Fish Speech 1.5 | XTTS-v2 | +|------------|-----------------|---------| +| RTF | ~1.4 | ~0.34 | +| Русская интонация | естественнее | приемлемо, акцент чаще | +| Английская речь | хорошо | хорошо | +| Размер weights | ~2 ГБ LLM + VQGAN | ~3 ГБ | +| Sample rate | 44,1 кГц | 24 кГц | +| Требует ref transcript | да, точный | да, но терпимее | +| `torch.compile` | нестабилен | не применяется | + +## XTTS-v2 + +- Coqui-модель `tts_models/multilingual/multi-dataset/xtts_v2`. +- Автоматически скачивается при первом `TTS_BACKEND=xtts_v2`. +- Можно указать локальный checkpoint через `TTS_MODEL_PATH`. +- Sample rate 24 кГц; клиенты `examples/client_*.py`/`examples/client_browser.html` + настроены на динамический sample rate из сообщений `audio`. ## Мультиязычность @@ -94,10 +123,11 @@ 3. **Fine-tuning F5-TTS** — самый трудоёмкий, но даёт лучший контроль над акцентом и языком. -Сейчас оставляем F5-TTS как единственный backend, но архитектура -(`TTSEngine` + `_BACKEND_MAP`) позволяет добавить fallback позже. +Сейчас оставляем Fish Speech 1.5 как бэкенд по умолчанию для en/ru, +XTTS-v2 — как быстрый резерв. Архитектура (`TTSEngine` + `_BACKEND_MAP`) +позволяет добавить fallback позже. - F5-TTS может не идеально произносить украинский / европейские языки из коробки — возможно потребуется fine-tuning или fallback. Сейчас протокол и сегментатор не ограничивают язык; качество зависит от самой модели. - RTX 3060 (12 GB) подойдёт для базовой модели, но batch-size и длина референса придётся ограничивать. - Быстрый `stop` во время CUDA kernel не прервёт уже запущенный kernel, но предотвратит отправку результата. -- `main.py` создаёт engine до старта uvicorn; при `TTS_BACKEND=f5_tts` первый запуск может занять десяток секунд из-за загрузки модели и vocos. +- `main.py` создаёт engine до старта uvicorn; при `TTS_BACKEND=fish_speech` первый запуск занимает десяток секунд из-за загрузки LLM и VQ-GAN. diff --git a/examples/client_browser.html b/examples/client_browser.html index a7f237e..e0da15a 100644 --- a/examples/client_browser.html +++ b/examples/client_browser.html @@ -3,61 +3,31 @@ - Voice TTS WebSocket Client + Voice TTS -

Voice TTS WebSocket Client

- - - - +

Voice TTS

+ +
-
- - -
-
- - -
-
- - -
-
- - - - - - - -
- - - + +
- -
+
Нажми Connect для запуска.
diff --git a/pyproject.toml b/pyproject.toml index 7baa355..f56bcbe 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,7 +12,7 @@ "torch>=2.4.0", "torchaudio>=2.4.0", "loguru>=0.7.0", - "f5-tts>=1.1.0", + "TTS>=0.25.0", "soundfile>=0.12.0", "pydub>=0.25.0", "tqdm>=4.66.0", @@ -52,3 +52,4 @@ [tool.pytest.ini_options] asyncio_mode = "auto" testpaths = ["tests"] +addopts = "-p no:cacheprovider" diff --git a/scripts/benchmark_backends.py b/scripts/benchmark_backends.py new file mode 100644 index 0000000..4257ca6 --- /dev/null +++ b/scripts/benchmark_backends.py @@ -0,0 +1,220 @@ +"""Benchmark local TTS backends: Fish Speech 1.5 vs XTTS-v2. + +Measures: + - cold load time + - per-sentence synthesis time + - real-time factor (RTF) + - Whisper ASR accuracy + +Outputs WAV files and a JSON report in outputs/benchmark/. +""" + +import asyncio +import json +import os +import time +from pathlib import Path + +import numpy as np +import torch +import torchaudio +from loguru import logger + +# Ensure project source is importable. +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT / "src") not in os.sys.path: + os.sys.path.insert(0, str(ROOT / "src")) + +from voice_tts.audio.formats import float_to_wav_bytes # noqa: E402 +from voice_tts.config import settings # noqa: E402 +from voice_tts.tts.fish_speech_backend import FishSpeechEngine # noqa: E402 +from voice_tts.tts.xtts_backend import XTTSv2Engine # noqa: E402 + + +SENTENCES = [ + ("en", "Hello, this is a short English sentence for the benchmark."), + ( + "ru", + "Добрый вечер, меня зовут Евгений. Это тестовое предложение для проверки качества синтеза.", + ), + ( + "en", + "The quick brown fox jumps over the lazy dog, testing every letter of the alphabet.", + ), + ( + "ru", + "Наша цель — сделать речь естественной и понятной на английском и русском языках.", + ), +] + + +def _load_whisper(model_name: str = "large-v3", device: str = "cuda"): + from faster_whisper import WhisperModel + + model_path = ROOT / "models" / "faster-whisper" / model_name + model_path.parent.mkdir(parents=True, exist_ok=True) + if not model_path.exists(): + logger.info("Downloading faster-whisper {} ...", model_name) + return WhisperModel( + model_name, + device=device if device.startswith("cuda") else "cpu", + compute_type="float16" if device.startswith("cuda") else "int8", + download_root=str(model_path.parent), + ) + + +def _transcribe(model, wav_path: Path) -> str: + segments, _ = model.transcribe(str(wav_path), language=None) + return " ".join(s.text.strip() for s in segments).strip() + + +def _save_wav(audio: np.ndarray, sr: int, path: Path) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(float_to_wav_bytes(audio, sr)) + + +def _normalise(text: str) -> str: + return " ".join(text.lower().replace(",", " ").replace(".", " ").replace("—", " ").split()) + + +def _wer(ref: str, hyp: str) -> float: + """Simple word-error rate.""" + r = _normalise(ref).split() + h = _normalise(hyp).split() + if not r: + return 0.0 if not h else 1.0 + # Levenshtein distance + prev = list(range(len(h) + 1)) + for i, rc in enumerate(r, 1): + curr = [i] + for j, hc in enumerate(h, 1): + cost = 0 if rc == hc else 1 + curr.append(min(curr[-1] + 1, prev[j] + 1, prev[j - 1] + cost)) + prev = curr + return prev[-1] / len(r) + + +async def _synth(engine, text: str, lang: str, ref: Path, speed: float = 1.0): + # Fish Speech has async synthesize, XTTS too; both are heavy CUDA sync. + # Run inside a thread to mimic server behavior. + def _run(): + kwargs = dict( + text=text, + ref_audio_path=ref, + language=lang, + speed=speed, + emotion="neutral", + ) + if isinstance(engine, FishSpeechEngine): + kwargs["ref_text"] = settings.default_ref_text + return asyncio.run(engine.synthesize(**kwargs)) + + return await asyncio.to_thread(_run) + + +async def benchmark_backend(name: str, factory, ref: Path, output_dir: Path): + logger.info("Benchmarking {} ...", name) + report = {"backend": name, "sentences": [], "load_seconds": None} + + t0 = time.perf_counter() + engine = factory() + engine.load() + report["load_seconds"] = round(time.perf_counter() - t0, 3) + + # Optional warm-up to make per-sentence timing more representative. + warmup_text = "One two three." if name == "xtts_v2" else "Раз, два, три." + await _synth(engine, warmup_text, "en" if name == "xtts_v2" else "ru", ref, 1.0) + + whisper_model = _load_whisper(device=settings.device) + + for lang, text in SENTENCES: + t0 = time.perf_counter() + audio = await _synth(engine, text, lang, ref, settings.tts_speed) + elapsed = time.perf_counter() - t0 + + duration = len(audio) / engine.sample_rate if engine.sample_rate else 0.0 + rtf = elapsed / duration if duration else 0.0 + + wav_path = output_dir / name / f"{lang}_{len(report['sentences'])}.wav" + _save_wav(audio, engine.sample_rate, wav_path) + + hyp = _transcribe(whisper_model, wav_path) + wer = _wer(text, hyp) + + report["sentences"].append( + { + "language": lang, + "text": text, + "duration_seconds": round(duration, 3), + "synth_seconds": round(elapsed, 3), + "rtf": round(rtf, 3), + "whisper": hyp, + "wer": round(wer, 3), + "wav": str(wav_path.relative_to(output_dir)), + } + ) + logger.info( + "[{} {}] RTF={} dur={}s synth={}s WER={}", + name, + lang, + round(rtf, 3), + round(duration, 3), + round(elapsed, 3), + round(wer, 3), + ) + + return report + + +def _fish_factory(): + return FishSpeechEngine( + checkpoint_path=settings.tts_model_path, + source_root=settings.tts_vocab_path, + device=settings.device, + compile=settings.fish_compile, + use_memory_cache=settings.fish_use_memory_cache, + chunk_length=settings.fish_chunk_length, + ) + + +def _xtts_factory(): + return XTTSv2Engine( + model_name=settings.tts_model_name, + device=settings.device, + ) + + +async def main(): + output_dir = ROOT / "outputs" / "benchmark" + output_dir.mkdir(parents=True, exist_ok=True) + + ref = Path(settings.default_voice_ref) if settings.default_voice_ref else None + if not ref or not ref.exists(): + raise FileNotFoundError("Set DEFAULT_VOICE_REF to an existing reference WAV.") + + reports = [] + + if "fish_speech" in os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2"): + reports.append(await benchmark_backend("fish_speech", _fish_factory, ref, output_dir)) + if "xtts_v2" in os.environ.get("BENCHMARK_BACKENDS", "fish_speech,xtts_v2"): + reports.append(await benchmark_backend("xtts_v2", _xtts_factory, ref, output_dir)) + + summary_path = output_dir / "summary.json" + summary_path.write_text(json.dumps(reports, ensure_ascii=False, indent=2)) + logger.info("Report saved to {}", summary_path) + + # Print a quick Markdown table. + print("\n## Summary") + print("| Backend | Lang | Dur (s) | Synth (s) | RTF | WER |") + print("|---------|------|---------|-----------|-----|-----|") + for report in reports: + for s in report["sentences"]: + print( + f"| {report['backend']} | {s['language']} | " + f"{s['duration_seconds']} | {s['synth_seconds']} | " + f"{s['rtf']} | {s['wer']} |" + ) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/benchmark_compile.py b/scripts/benchmark_compile.py new file mode 100644 index 0000000..382020e --- /dev/null +++ b/scripts/benchmark_compile.py @@ -0,0 +1,73 @@ +"""Quick test: Fish Speech with and without torch.compile on one sentence.""" + +import asyncio +import os +import time +from pathlib import Path + +ROOT = Path(__file__).resolve().parent.parent +if str(ROOT / "src") not in os.sys.path: + os.sys.path.insert(0, str(ROOT / "src")) + +from voice_tts.config import settings +from voice_tts.tts.fish_speech_backend import FishSpeechEngine + + +TEXT = "The quick brown fox jumps over the lazy dog." +LANG = "en" +REF = Path(settings.default_voice_ref) + + +async def run(compile: bool) -> float: + engine = FishSpeechEngine( + checkpoint_path=settings.tts_model_path, + source_root=settings.tts_vocab_path, + device=settings.device, + compile=compile, + chunk_length=settings.fish_chunk_length, + ) + engine.load() + + # First run (warm-up + compile) + t0 = time.perf_counter() + await asyncio.to_thread( + lambda: asyncio.run( + engine.synthesize( + text=TEXT, + ref_audio_path=REF, + language=LANG, + speed=1.0, + emotion="neutral", + ref_text=settings.default_ref_text, + ) + ) + ) + first = time.perf_counter() - t0 + + # Second run + t0 = time.perf_counter() + await asyncio.to_thread( + lambda: asyncio.run( + engine.synthesize( + text=TEXT, + ref_audio_path=REF, + language=LANG, + speed=1.0, + emotion="neutral", + ref_text=settings.default_ref_text, + ) + ) + ) + second = time.perf_counter() - t0 + return first, second + + +async def main(): + f1, f2 = await run(compile=False) + print(f"compile=False first={f1:.2f}s second={f2:.2f}s") + c1, c2 = await run(compile=True) + print(f"compile=True first={c1:.2f}s second={c2:.2f}s") + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/src/voice_tts/api/server.py b/src/voice_tts/api/server.py index 27e0d57..cb392f7 100644 --- a/src/voice_tts/api/server.py +++ b/src/voice_tts/api/server.py @@ -24,20 +24,9 @@ from voice_tts.audio.formats import float_to_pcm16, pcm16_to_base64 from voice_tts.config import settings from voice_tts.session.state import SessionState, VoiceProfile -from voice_tts.tts.engine import DummyTTSEngine, TTSEngine -from voice_tts.tts.f5_backend import F5TTSEngine -from voice_tts.tts.fish_speech_backend import FishSpeechEngine +from voice_tts.tts import create_engine as _create_tts_engine +from voice_tts.tts.engine import TTSEngine from voice_tts.tts.segmenter import Segmenter -from voice_tts.tts.xtts_backend import XTTSv2Engine - - -# Supported TTS backends -_BACKEND_MAP: dict[str, type[TTSEngine]] = { - "dummy": DummyTTSEngine, - "f5_tts": F5TTSEngine, - "xtts_v2": XTTSv2Engine, - "fish_speech": FishSpeechEngine, -} class SessionManager: @@ -49,12 +38,15 @@ self.segmenter = Segmenter( min_length=settings.min_segment_length, max_length=settings.max_segment_length, + fast_start_initial=settings.fast_start_initial, + fast_start_count=settings.fast_start_count, ) self.state = SessionState(session_id="") self._running = True self._tasks: list[asyncio.Task] = [] self._send_lock = asyncio.Lock() self._synth_lock = asyncio.Lock() + self._synth_tasks: set[asyncio.Task] = set() async def run(self) -> None: await self.ws.accept() @@ -187,6 +179,10 @@ self.state.stop() self.state.clear_buffer() + # Cancel all in-flight synthesis tasks. + for t in set(self._synth_tasks): + t.cancel() + # Drain pending audio queue while not self.state.audio_queue.empty(): try: @@ -230,8 +226,10 @@ ref_path = self.state.voice.ref_for(self.state.emotion) + # Track this task so stop can cancel it. + task = asyncio.current_task() + self._synth_tasks.add(task) try: - # Serialize GPU inference and run it off the main event loop. async with self._synth_lock: audio = await asyncio.to_thread( SessionManager._sync_synthesize, @@ -242,10 +240,14 @@ self.state.speed, self.state.emotion, ) + except asyncio.CancelledError: + raise except Exception as exc: logger.exception("TTS synthesis failed") await self._send(ErrorMessage(message=f"TTS failed: {exc}", seq=segment_seq)) return + finally: + self._synth_tasks.discard(task) if self.state.is_stopped(): return @@ -291,32 +293,14 @@ speed: float, emotion: str, ) -> "np.ndarray": - """Thread-safe wrapper around the synchronous TTS inference call.""" + """Thread-safe wrapper around the TTS inference call. + + All backends expose an async ``synthesize``. Inside a thread from + ``asyncio.to_thread`` there is no running loop, so we drive the + coroutine with a fresh transient event loop. + """ import numpy as np - # The dummy backend is async; run it on a transient event loop in this thread. - if isinstance(engine, DummyTTSEngine): - try: - loop = asyncio.get_running_loop() - except RuntimeError: - loop = None - - async def _run() -> "np.ndarray": - return await engine.synthesize( - text=text, - ref_audio_path=ref_path, - language=language, - speed=speed, - emotion=emotion, - ) - - if loop is not None: - return asyncio.run_coroutine_threadsafe(_run(), loop).result() - return asyncio.run(_run()) - - # Backends expose an async synthesize method that blocks on CPU/CUDA work. - # Inside a thread from asyncio.to_thread there is no running loop, so we - # drive the coroutine with a fresh transient event loop. kwargs: dict = dict( text=text, ref_audio_path=ref_path, @@ -324,9 +308,7 @@ speed=speed, emotion=emotion, ) - if isinstance(engine, F5TTSEngine) and settings.default_ref_text: - kwargs["ref_text"] = settings.default_ref_text - if isinstance(engine, FishSpeechEngine) and settings.default_ref_text: + if settings.default_ref_text: kwargs["ref_text"] = settings.default_ref_text return asyncio.run(engine.synthesize(**kwargs)) @@ -342,35 +324,7 @@ def _create_engine() -> TTSEngine: - backend = settings.tts_backend - engine_cls = _BACKEND_MAP.get(backend) - if engine_cls is None: - raise RuntimeError( - f"Unknown TTS backend: {backend}. " - f"Available backends: {list(_BACKEND_MAP.keys())}" - ) - if engine_cls is F5TTSEngine: - engine = engine_cls( - model=settings.tts_model_name, - sample_rate=settings.tts_sample_rate, - ) - elif engine_cls is XTTSv2Engine: - engine = engine_cls( - model_name=settings.tts_model_name, - sample_rate=settings.tts_sample_rate, - ) - elif engine_cls is FishSpeechEngine: - engine = engine_cls( - checkpoint_path=settings.tts_model_path, - source_root=settings.tts_vocab_path, - sample_rate=settings.tts_sample_rate, - device=settings.device, - compile=settings.fish_compile, - use_memory_cache=settings.fish_use_memory_cache, - chunk_length=settings.fish_chunk_length, - ) - else: - engine = engine_cls(sample_rate=settings.tts_sample_rate) + engine = _create_tts_engine() if hasattr(engine, "load"): engine.load() return engine diff --git a/src/voice_tts/config.py b/src/voice_tts/config.py index 00d3788..207970f 100644 --- a/src/voice_tts/config.py +++ b/src/voice_tts/config.py @@ -10,11 +10,16 @@ log_level: str = "INFO" # TTS model configuration - tts_backend: str = "f5_tts" # or "dummy" for tests - tts_model_name: str = "F5TTS_v1_Base" # env: TTS_MODEL_NAME + tts_backend: str = "fish_speech" # "dummy" / "f5_tts" / "xtts_v2" / "fish_speech" + # XTTS-v2 model name (Coqui model manager path); used when backend is xtts_v2. + tts_model_name: str = "tts_models/multilingual/multi-dataset/xtts_v2" + # Local checkpoint path. For Fish Speech this is the folder containing model.pth, + # firefly-gan-vq-fsq-8x1024-21hz-generator.pth, tokenizer.tiktoken, config.json, etc. tts_model_path: Path | None = None + # Source tree path for Fish Speech modules (e.g. models/fish-speech-v1.5.1). tts_vocab_path: Path | None = None - tts_sample_rate: int = 24_000 + tts_sample_rate: int = 44_100 + tts_speed: float = 1.2 # env: TTS_SPEED # Reference voices directory voices_dir: Path = Path("voices") @@ -23,6 +28,8 @@ min_segment_length: int = 30 max_segment_length: int = 200 max_buffer_wait_ms: int = 500 + fast_start_initial: int = 12 # first segment threshold for lower latency + fast_start_count: int = 3 # how many segments use progressive sizing # GPU / inference device: str = "cuda" # or "cpu" @@ -32,6 +39,20 @@ default_voice_ref: Path | None = None # env: DEFAULT_VOICE_REF default_ref_text: str | None = None # env: DEFAULT_REF_TEXT + # S2-Pro backend settings + s2_api_url: str = "http://127.0.0.1:8081" + + # Fish Speech-specific settings + fish_compile: bool = False # torch.compile the LLaMA model (slow first run) + fish_chunk_length: int = 200 # 100-300; higher = longer coherent chunks + fish_use_memory_cache: str = "on" # "on" / "off" reference VQ cache + fish_top_p: float = 0.7 # nucleus sampling (0-1); lower = more deterministic + fish_temperature: float = 0.7 # sampling temperature; lower = more stable + fish_repetition_penalty: float = 1.2 # >1 reduces repeated tokens + fish_seed: int | None = None # None = random; set for reproducible output + fish_tail_silence_threshold: float = 0.02 # trim trailing silence below this RMS + fish_lowpass_cutoff: int = 0 # Hz; low-pass filter output to reduce VQ noise (0 = off) + # Warm-up warmup: bool = False # run a dummy inference at startup warmup_text: str = "Привет. Это тестовая фраза." diff --git a/src/voice_tts/tts/__init__.py b/src/voice_tts/tts/__init__.py new file mode 100644 index 0000000..740e740 --- /dev/null +++ b/src/voice_tts/tts/__init__.py @@ -0,0 +1,62 @@ +"""TTS backend registry and factory. + +Each backend module self-registers via ``@register(name)``. +Usage: + + >>> from voice_tts.tts import create_engine + >>> engine = create_engine("s2") # by name + >>> engine = create_engine() # from settings.TTS_BACKEND +""" + +from pathlib import Path + +from voice_tts.config import settings +from voice_tts.tts.engine import TTSEngine, DummyTTSEngine + +_BACKENDS: dict[str, type[TTSEngine]] = { + "dummy": DummyTTSEngine, +} + + +def register(name: str): + """Decorator that registers a :class:`TTSEngine` subclass.""" + def decorator(cls): + _BACKENDS[name] = cls + return cls + return decorator + + +def get_backend(name: str | None = None) -> type[TTSEngine]: + """Look up a backend class by name (defaults to ``settings.tts_backend``).""" + name = name or settings.tts_backend + if name not in _BACKENDS: + raise RuntimeError( + f"Unknown TTS backend: {name}. " + f"Available backends: {list(_BACKENDS.keys())}" + ) + return _BACKENDS[name] + + +def list_backends() -> list[str]: + """Return names of all registered backends.""" + return list(_BACKENDS.keys()) + + +def create_engine(name: str | None = None) -> TTSEngine: + """Create a TTS engine instance. + + Each backend reads its configuration from ``settings``; no arguments + are passed to the constructor besides ``name``. + """ + cls = get_backend(name) + return cls() + + +# Lazy-import backends so they self-register via ``@register``. +# Each submodule gracefully handles missing optional dependencies. +from voice_tts.tts import ( # noqa: E402, F401 + s2_backend, + f5_backend, + xtts_backend, + fish_speech_backend, +) diff --git a/src/voice_tts/tts/engine.py b/src/voice_tts/tts/engine.py index 07fad11..4053d4e 100644 --- a/src/voice_tts/tts/engine.py +++ b/src/voice_tts/tts/engine.py @@ -20,6 +20,7 @@ language: str, speed: float, emotion: str, + ref_text: str | None = None, ) -> np.ndarray: """Return audio as float32 ndarray normalized to [-1, 1].""" ... @@ -43,6 +44,7 @@ language: str, speed: float, emotion: str, + ref_text: str | None = None, ) -> np.ndarray: duration_sec = max(0.5, len(text) * 0.08) / speed num_samples = int(self.sample_rate * duration_sec) @@ -52,6 +54,5 @@ audio *= np.hanning(num_samples) return audio.astype(np.float32) - async def warm_up(self) -> None: pass diff --git a/src/voice_tts/tts/f5_backend.py b/src/voice_tts/tts/f5_backend.py index e9cad82..c835793 100644 --- a/src/voice_tts/tts/f5_backend.py +++ b/src/voice_tts/tts/f5_backend.py @@ -4,6 +4,7 @@ from loguru import logger from voice_tts.config import settings +from voice_tts.tts import register as _register_backend from voice_tts.tts.engine import TTSEngine @@ -21,33 +22,34 @@ torchaudio = None +@_register_backend("f5_tts") class F5TTSEngine(TTSEngine): """F5-TTS backend for local GPU inference with voice cloning.""" def __init__( self, - model: str = "F5TTS_v1_Base", - sample_rate: int = 24_000, + model: str | None = None, + sample_rate: int | None = None, device: str | None = None, - nfe_step: int = 32, - cfg_strength: float = 2.0, - sway_sampling_coef: float = -1.0, - speed: float = 1.0, - target_rms: float = 0.1, - cross_fade_duration: float = 0.0, - remove_silence: bool = False, + nfe_step: int | None = None, + cfg_strength: float | None = None, + sway_sampling_coef: float | None = None, + speed: float | None = None, + target_rms: float | None = None, + cross_fade_duration: float | None = None, + remove_silence: bool | None = None, ): super().__init__() - self.model_name = model - self.sample_rate = sample_rate + self.model_name = model or settings.tts_model_name + self.sample_rate = sample_rate or settings.tts_sample_rate self.device = device or settings.device - self.nfe_step = nfe_step - self.cfg_strength = cfg_strength - self.sway_sampling_coef = sway_sampling_coef - self.speed = speed - self.target_rms = target_rms - self.cross_fade_duration = cross_fade_duration - self.remove_silence = remove_silence + self.nfe_step = nfe_step or 32 + self.cfg_strength = cfg_strength or 2.0 + self.sway_sampling_coef = sway_sampling_coef or -1.0 + self.speed = speed or 1.0 + self.target_rms = target_rms or 0.1 + self.cross_fade_duration = cross_fade_duration or 0.0 + self.remove_silence = remove_silence or False self._f5: "F5TTS | None" = None self._ref_cache: dict[str, tuple[str, str]] = {} # emotion_key -> (processed_audio_path, ref_text) @@ -66,8 +68,18 @@ self.model_name, self.device, ) + ckpt_file = str(settings.tts_model_path) if settings.tts_model_path else "" + vocab_file = str(settings.tts_vocab_path) if settings.tts_vocab_path else "" + logger.info( + "Loading F5-TTS model {} ckpt={} vocab={} ...", + self.model_name, + ckpt_file or "(default)", + vocab_file or "(default)", + ) self._f5 = F5TTS( model=self.model_name, + ckpt_file=ckpt_file, + vocab_file=vocab_file, device=self.device, ) self.sample_rate = self._f5.target_sample_rate diff --git a/src/voice_tts/tts/fish_speech_backend.py b/src/voice_tts/tts/fish_speech_backend.py new file mode 100644 index 0000000..ced922d --- /dev/null +++ b/src/voice_tts/tts/fish_speech_backend.py @@ -0,0 +1,281 @@ +"""Fish Speech 1.5 backend for local GPU inference with zero-shot voice cloning.""" + +from pathlib import Path + +import numpy as np +from loguru import logger + +from voice_tts.config import settings +from voice_tts.tts import register as _register_backend +from voice_tts.tts.engine import TTSEngine + +try: + import sys + + import torch + import torchaudio + + FISH_SPEECH_AVAILABLE = True +except ImportError as exc: + logger.warning("torch/torchaudio not available for Fish Speech: {}", exc) + FISH_SPEECH_AVAILABLE = False + torch = None + torchaudio = None + + +_FISH_SAMPLE_RATE = 44_100 +_DEFAULT_SOURCE_ROOT = Path("models/fish-speech") + + +@_register_backend("fish_speech") +class FishSpeechEngine(TTSEngine): + """Fish Speech 1.5 backend supporting English and Russian zero-shot TTS.""" + + sample_rate: int = _FISH_SAMPLE_RATE + + def __init__( + self, + checkpoint_path: Path | str | None = None, + source_root: Path | str | None = None, + sample_rate: int | None = None, + device: str | None = None, + precision: torch.dtype | None = None, + compile: bool | None = None, + use_memory_cache: str | None = None, + chunk_length: int | None = None, + top_p: float | None = None, + temperature: float | None = None, + repetition_penalty: float | None = None, + seed: int | None = None, + tail_silence_threshold: float | None = None, + lowpass_cutoff: int | None = None, + ): + super().__init__() + + if not FISH_SPEECH_AVAILABLE: + raise RuntimeError( + "Fish Speech backend requires torch/torchaudio and the Fish Speech " + "source tree at models/fish-speech" + ) + + self.sample_rate = sample_rate or settings.tts_sample_rate + self.device = device or settings.device + self.precision = precision or ( + torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32 + ) + self.compile = compile if compile is not None else settings.fish_compile + self.use_memory_cache = use_memory_cache or settings.fish_use_memory_cache + self.chunk_length = chunk_length or settings.fish_chunk_length + self.top_p = top_p if top_p is not None else settings.fish_top_p + self.temperature = temperature if temperature is not None else settings.fish_temperature + self.repetition_penalty = repetition_penalty if repetition_penalty is not None else settings.fish_repetition_penalty + self.seed = seed if seed is not None else settings.fish_seed + self.tail_silence_threshold = tail_silence_threshold if tail_silence_threshold is not None else settings.fish_tail_silence_threshold + self.lowpass_cutoff = lowpass_cutoff if lowpass_cutoff is not None else settings.fish_lowpass_cutoff + self.repetition_penalty = repetition_penalty + self.seed = seed + self.tail_silence_threshold = tail_silence_threshold + self.lowpass_cutoff = lowpass_cutoff + + self.source_root = Path( + source_root or settings.tts_vocab_path or _DEFAULT_SOURCE_ROOT + ) + self.checkpoint_path = Path( + checkpoint_path or settings.tts_model_path or "models/fishaudio_fish-speech-1.5" + ) + + self._llama_queue = None + self._decoder = None + self._engine = None + self._loaded = False + # Cache reference audio bytes/text per path to avoid repeated disk reads. + self._ref_cache: dict[Path, tuple[bytes, str]] = {} + + def _ensure_source_path(self) -> None: + """Make sure the cloned Fish Speech source is on sys.path.""" + root = str(self.source_root.resolve()) + if root not in sys.path: + sys.path.insert(0, root) + + def _is_supported_language(self, language: str) -> bool: + # Fish Speech 1.5 is multilingual; we expose English and Russian. + return language.lower() in {"en", "ru"} + + def load(self) -> None: + if self._loaded: + return + + self._ensure_source_path() + + from fish_speech.inference_engine import TTSInferenceEngine + from fish_speech.models.text2semantic.inference import ( + launch_thread_safe_queue, + ) + from fish_speech.models.vqgan.inference import load_model as load_decoder_model + + logger.info( + "Loading Fish Speech 1.5 from {} (source: {}) ...", + self.checkpoint_path, + self.source_root, + ) + + llama_checkpoint = self.checkpoint_path / "model.pth" + decoder_checkpoint = ( + self.checkpoint_path + / "firefly-gan-vq-fsq-8x1024-21hz-generator.pth" + ) + + self._llama_queue = launch_thread_safe_queue( + checkpoint_path=str(self.checkpoint_path), + device=self.device, + precision=self.precision, + compile=self.compile, + ) + + self._decoder = load_decoder_model( + config_name="firefly_gan_vq", + checkpoint_path=str(decoder_checkpoint), + device=self.device, + ) + + self._engine = TTSInferenceEngine( + llama_queue=self._llama_queue, + decoder_model=self._decoder, + precision=self.precision, + compile=self.compile, + ) + + self.sample_rate = self._decoder.spec_transform.sample_rate + self._loaded = True + logger.info( + "Fish Speech 1.5 loaded. Output sample rate: {}", self.sample_rate + ) + + async def warm_up(self) -> None: + if not self._loaded: + self.load() + logger.info("Fish Speech warm-up skipped; first synthesis will warm the cache.") + + def _normalize_language(self, language: str) -> str: + language = language.lower() + if language.startswith("ru"): + return "ru" + if language.startswith("en"): + return "en" + return language + + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ref_text: str | None = None, + ) -> np.ndarray: + if not self._loaded: + self.load() + + if isinstance(ref_audio_path, str): + ref_audio_path = Path(ref_audio_path) + + if ref_audio_path is None: + ref_audio_path = settings.default_voice_ref + if ref_audio_path is None: + raise ValueError("Fish Speech requires a reference audio file (voice_ref).") + if not ref_audio_path.exists(): + raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}") + + lang = self._normalize_language(language) + if not self._is_supported_language(lang): + raise ValueError( + f"Language '{language}' is not supported by Fish Speech backend. " + "Currently enabled languages: en, ru." + ) + + from fish_speech.utils.schema import ServeReferenceAudio, ServeTTSRequest + + # Cache reference audio bytes to avoid repeated disk reads. The transcript + # is cheap to resolve, so we re-evaluate fallback precedence every call. + cached_audio = self._ref_cache.get(ref_audio_path) + if cached_audio is None: + ref_audio = ref_audio_path.read_bytes() + self._ref_cache[ref_audio_path] = ref_audio + else: + ref_audio = cached_audio + + # Reference transcript precedence: caller-provided > .lab > settings > placeholder. + if ref_text: + pass + else: + ref_text_path = ref_audio_path.with_suffix(".lab") + if ref_text_path.exists(): + ref_text = ref_text_path.read_text(encoding="utf-8").strip() + elif settings.default_ref_text: + ref_text = settings.default_ref_text + else: + ref_text = ( + "Hello, this is a reference voice recording." + if lang == "en" + else "Здравствуйте. Это тестовая запись голоса." + ) + + req = ServeTTSRequest( + text=text, + references=[ + ServeReferenceAudio(audio=ref_audio, text=ref_text) + ], + seed=self.seed, + top_p=self.top_p, + temperature=self.temperature, + repetition_penalty=self.repetition_penalty, + chunk_length=self.chunk_length, + use_memory_cache=self.use_memory_cache, + ) + + segments = [] + for result in self._engine.inference(req): + logger.debug( + "Fish Speech inference result: code={} error={}", + result.code, + result.error, + ) + if result.code == "error": + error = result.error or RuntimeError("Unknown Fish Speech error") + raise RuntimeError(f"Fish Speech synthesis failed: {error}") + if result.audio is not None: + sr, audio = result.audio + segments.append(audio.astype(np.float32)) + + if not segments: + raise RuntimeError("Fish Speech produced no audio.") + + wav = np.concatenate(segments) + + # Normalize to [-1, 1]. + peak = np.max(np.abs(wav)) + if peak > 1.0: + wav = wav / peak + + # Gentle low-pass to reduce VQ codec noise. + if self.lowpass_cutoff > 0: + wt = torch.from_numpy(wav).unsqueeze(0) # (1, samples) + wt = torchaudio.functional.lowpass_biquad( + wt, self.sample_rate, self.lowpass_cutoff + ) + wav = wt.squeeze(0).numpy() + + # Apply speed adjustment via resampling. Fish Speech is already + # reasonably fast; the default settings.tts_speed allows global tuning. + effective_speed = speed if speed is not None else settings.tts_speed + if effective_speed != 1.0 and effective_speed > 0: + new_rate = int(self.sample_rate * (1.0 / effective_speed)) + resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate) + wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32) + + if wav.ndim == 0: + wav = np.zeros(1, dtype=np.float32) + elif wav.ndim > 1: + wav = wav.squeeze() + + return wav diff --git a/src/voice_tts/tts/s2_backend.py b/src/voice_tts/tts/s2_backend.py new file mode 100644 index 0000000..804a4aa --- /dev/null +++ b/src/voice_tts/tts/s2_backend.py @@ -0,0 +1,140 @@ +"""Fish Audio S2-Pro backend — HTTP client to the local S2 API server.""" + +import io +from pathlib import Path + +import numpy as np +import soundfile as sf +from loguru import logger + +from voice_tts.config import settings +from voice_tts.tts import register as _register_backend +from voice_tts.tts.engine import TTSEngine + +try: + import requests + + S2_AVAILABLE = True +except ImportError: + S2_AVAILABLE = False + requests = None + +_S2_SAMPLE_RATE = 44_100 + + +@_register_backend("s2") +class S2Engine(TTSEngine): + """Fish Audio S2-Pro backend that delegates to a local S2 API server. + + The S2 API server must be running separately on ``api_url``. + The reference audio is uploaded once and reused via ``reference_id``. + """ + + sample_rate: int = _S2_SAMPLE_RATE + + def __init__(self): + super().__init__() + if not S2_AVAILABLE: + raise RuntimeError("S2 backend requires the 'requests' package") + + self.api_url = settings.s2_api_url.rstrip("/") + self.sample_rate = settings.tts_sample_rate + self._ref_uploaded = False + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _upload_reference(self, ref_audio_path: Path) -> None: + """Upload the reference audio to the S2 server. + + If a reference with id ``default`` already exists we try to delete it + first via HTTP, then fall back to removing its directory from the + filesystem (the S2 server stores references in ``references//`` + relative to its working directory). + """ + ref_text = settings.default_ref_text or "" + + # Try HTTP delete first (works with most setups). + deleted = False + try: + r = requests.delete( + f"{self.api_url}/v1/references/delete", + json={"reference_id": "default"}, + timeout=10, + ) + deleted = r.status_code in (200, 404) + except Exception: + pass + + # If HTTP delete didn't work, try removing the directory directly. + if not deleted: + candidate = Path("models/fish-speech/references/default") + if candidate.exists(): + import shutil + shutil.rmtree(candidate) + logger.info("Removed stale reference directory: {}", candidate) + + with open(ref_audio_path, "rb") as fh: + resp = requests.post( + f"{self.api_url}/v1/references/add", + data={"id": "default", "text": ref_text}, + files={"audio": ("ref.wav", fh, "audio/wav")}, + timeout=30, + ) + if resp.status_code == 200: + logger.info("Reference 'default' uploaded to S2 server") + self._ref_uploaded = True + else: + raise RuntimeError( + f"S2 reference upload failed: {resp.status_code} {resp.text}" + ) + + # ------------------------------------------------------------------ + # TTSEngine interface + # ------------------------------------------------------------------ + + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ref_text: str | None = None, + ) -> np.ndarray: + if ref_audio_path is not None and not self._ref_uploaded: + self._upload_reference(ref_audio_path) + + payload = { + "text": text, + "reference_id": "default", + "format": "wav", + "chunk_length": 300, + } + + resp = requests.post( + f"{self.api_url}/v1/tts", + json=payload, + timeout=120, + ) + if resp.status_code != 200: + raise RuntimeError(f"S2 TTS request failed: {resp.status_code} {resp.text}") + + buf = io.BytesIO(resp.content) + audio, sr = sf.read(buf) + self.sample_rate = int(sr) + return audio.astype(np.float32) + + async def warm_up(self) -> None: + """Check that the S2 API server is alive.""" + try: + resp = requests.get(f"{self.api_url}/v1/health", timeout=5) + if resp.status_code != 200: + raise RuntimeError(f"S2 server health check failed: {resp.status_code}") + logger.info("S2 API server is healthy at {}", self.api_url) + except requests.ConnectionError as exc: + raise RuntimeError( + f"S2 API server not reachable at {self.api_url}. " + "Make sure the S2 server is running." + ) from exc diff --git a/src/voice_tts/tts/segmenter.py b/src/voice_tts/tts/segmenter.py index 26315a6..790a1f7 100644 --- a/src/voice_tts/tts/segmenter.py +++ b/src/voice_tts/tts/segmenter.py @@ -11,15 +11,24 @@ class Segmenter: - """Splits streaming text into TTS-ready segments.""" + """Splits streaming text into TTS-ready segments. + + Supports progressive chunking: the first few segments use a lower + ``min_length`` so audio starts sooner (lower initial latency). + """ def __init__( self, min_length: int = 30, max_length: int = 200, + fast_start_initial: int = 12, + fast_start_count: int = 3, ): self.min_length = min_length self.max_length = max_length + self._fast_start_initial = fast_start_initial + self._fast_start_count = fast_start_count + self._segments_emitted = 0 # End-of-sentence delimiters self.sentence_breaks = re.compile(r"[.。!??!\n]+") @@ -27,6 +36,20 @@ self.clause_breaks = re.compile(r"[,;:\-—()()]") self.whitespace_re = re.compile(r"\s+") + def _effective_min_length(self) -> int: + """Minimum length before a segment is considered 'ready'. + + Gradually ramps up from ``fast_start_initial`` to ``min_length`` + over the first ``fast_start_count`` segments. + """ + if self._segments_emitted < self._fast_start_count: + fraction = self._segments_emitted / max(self._fast_start_count - 1, 1) + return int( + self._fast_start_initial + + (self.min_length - self._fast_start_initial) * fraction + ) + return self.min_length + def feed(self, buffer: str) -> tuple[str, list[Segment]]: """ Consume `buffer` and return (remaining_buffer, ready_segments). @@ -34,13 +57,18 @@ Logic: - Prefer cutting at sentence boundaries. - If a segment grows beyond max_length, cut at the nearest clause boundary. - - Segments shorter than min_length are returned only if the caller forces flush - or the input is shorter than min_length but ends with a sentence break. + - Segments shorter than the effective min_length are returned only if + the caller forces flush or the input is shorter but ends with a + sentence break. + - Progressive chunking: early segments use a lower threshold so audio + starts sooner. """ segments: list[Segment] = [] remaining = buffer while remaining: + min_len = self._effective_min_length() + # Find the first sentence boundary. first_sentence_cut = -1 for match in self.sentence_breaks.finditer(remaining): @@ -52,7 +80,7 @@ segment_text = remaining[:first_sentence_cut].strip() # Case 1: sentence is long enough -> emit immediately. - if len(segment_text) >= self.min_length: + if len(segment_text) >= min_len: remaining = remaining[first_sentence_cut:].lstrip() if segment_text: segments.append(Segment(text=segment_text, is_final=True)) @@ -72,7 +100,7 @@ break combined_text = remaining[:next_boundary].strip() combined_cut = next_boundary - if len(combined_text) >= self.min_length: + if len(combined_text) >= min_len: remaining = remaining[combined_cut:].lstrip() if combined_text: segments.append(Segment(text=combined_text, is_final=True)) @@ -96,7 +124,7 @@ last_clause = -1 for match in self.clause_breaks.finditer(window): pos = match.end() - if pos >= self.min_length: + if pos >= min_len: last_clause = pos if last_clause != -1: segment_text = remaining[:last_clause].strip() @@ -108,6 +136,7 @@ # Nothing to cut yet; wait for more text. break + self._segments_emitted += len(segments) return remaining, segments def flush(self, buffer: str) -> list[Segment]: diff --git a/src/voice_tts/tts/utils.py b/src/voice_tts/tts/utils.py index 099ee8e..5035268 100644 --- a/src/voice_tts/tts/utils.py +++ b/src/voice_tts/tts/utils.py @@ -10,20 +10,90 @@ "。", "!", "?", ";", ":", } +# Emoji range: all Unicode emoji blocks +_EMOJI_PATTERN = re.compile( + "[" + "\U0001f600-\U0001f64f" # emoticons + "\U0001f300-\U0001f5ff" # symbols & pictographs + "\U0001f680-\U0001f6ff" # transport & map symbols + "\U0001f1e0-\U0001f1ff" # flags (iOS) + "\U00002600-\U000027BF" # misc symbols, dingbats + "\U0001f900-\U0001f9ff" # supplemental symbols + "\U0001fa00-\U0001fa6f" # chess symbols + "\U0001fa70-\U0001faff" # symbols extended-A + "\U00002702-\U000027B0" # dingbats + "\U000024C2-\U00002500" # enclosed / geometric shapes + "\U00002B05-\U00002B55" # arrows + "\U0001d300-\U0001d7ff" # musical symbols, etc. + "\U0001f000-\U0001f02f" # mahjong tiles + "\U0001f030-\U0001f09f" # domino tiles + "\U00002100-\U0000214f" # letterlike symbols + "\U0001f0a0-\U0001f0ff" # playing cards + "\U0001f600-\U0001f64f" # emoticons (duplicate range for safety) + "\U0000FE00-\U0000FE0F" # variation selectors + "\U0000FE20-\U0000FE23" # combining half marks + "\U0000200D" # zero-width joiner + "\U0000200C" # zero-width non-joiner + "]+" +) + +_MARKDOWN_PATTERN = re.compile( + r"```[\s\S]*?```" # fenced code blocks + r"|`[^`\n]+`" # inline code + r"|\[([^\]]+)\]\([^)]+\)" # markdown links → keep link text + r"|!\[([^\]]*)\]\([^)]+\)" # markdown images + r"|^#{1,6}\s+" # headings + r"|^>+\s+" # blockquotes + r"|^\s*[-*+]\s+(?![-*+])" # unordered lists + r"|^\s*\d+[.)]\s+" # ordered lists + r"|^\s*\|.*\|" # tables + r"|^[-=]{3,}\s*$" # horizontal rules + r"|(?:^|(?<=\s))\*{1,3}(?=\S)" # leading bold/italic + r"|(?<=\S)\*{1,3}(?=\s|$)" # trailing bold/italic + r"|(?:^|(?<=\s))_{1,3}(?=\S)" # leading underline emphasis + r"|(?<=\S)_{1,3}(?=\s|$)" # trailing underline emphasis + r"|~~(.*?)~~" # strikethrough + r"|(?:^|(?<=\s))~{1,3}(?=\S)" # leading strikethrough marker + r"|(?<=\S)~{1,3}(?=\s|$)" # trailing strikethrough marker +, re.MULTILINE) + +_URL_PATTERN = re.compile(r"https?://[^\s<>\"']+|www\.[^\s<>\"']+") + +_HTML_PATTERN = re.compile(r"<[^>]+>") + +_SPECIAL_SYMBOLS = re.compile(r"[^\w\s.,!?;:\-—\"'()«»„“”‘’…\n]") + def normalize_whitespace(text: str) -> str: """Collapse repeated whitespace and strip edges, preserving single spaces.""" return re.sub(r"\s+", " ", text).strip() +def clean_text_for_tts(text: str) -> str: + """Remove characters that TTS should never pronounce.""" + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text) + # Remove emojis + text = _EMOJI_PATTERN.sub("", text) + # Remove HTML tags + text = _HTML_PATTERN.sub("", text) + # Remove markdown formatting (keep link/image/strikethrough text) + text = _MARKDOWN_PATTERN.sub( + lambda m: next((g for g in m.groups() if g is not None), ""), text + ) + # Remove URLs + text = _URL_PATTERN.sub("", text) + # Remove special Unicode symbols not used in normal text + text = _SPECIAL_SYMBOLS.sub("", text) + return normalize_whitespace(text) + + def preprocess_text_for_tts(text: str) -> str: """ - Minimal cleanup before TTS. + Clean text before TTS synthesis. + - Remove control characters, emojis, HTML, markdown, URLs, special symbols. - Collapse whitespace. - - Remove control characters. """ - text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text) - return normalize_whitespace(text) + return clean_text_for_tts(text) def has_sentence_ending(text: str) -> bool: diff --git a/src/voice_tts/tts/xtts_backend.py b/src/voice_tts/tts/xtts_backend.py new file mode 100644 index 0000000..150da9d --- /dev/null +++ b/src/voice_tts/tts/xtts_backend.py @@ -0,0 +1,205 @@ +"""XTTS-v2 backend for local GPU inference with zero-shot voice cloning.""" + +from pathlib import Path + +import numpy as np +from loguru import logger + +from voice_tts.config import settings +from voice_tts.tts import register as _register_backend +from voice_tts.tts.engine import TTSEngine + +try: + import torch + import torchaudio + from TTS.api import TTS + import TTS.utils.io as tts_io + + XTTS_AVAILABLE = True +except ImportError as exc: + logger.warning("TTS/torch dependencies not available: {}", exc) + XTTS_AVAILABLE = False + torch = None + torchaudio = None + + +_XTTS_SAMPLE_RATE = 24_000 + + +def _patch_weights_only() -> None: + """XTTS-v2 checkpoints contain legacy pickle classes; allow full load. + + PyTorch 2.6+ defaults ``torch.load(..., weights_only=True)``. XTTS-v2 + checkpoints require several Coqui classes to be in the safe-globals list + in addition to forcing ``weights_only=False``. This function registers + those classes globally and patches TTS's fsspec loader to default to the + legacy behavior. + """ + if not XTTS_AVAILABLE: + return + + if not hasattr(tts_io, "load_fsspec"): + return + + # PyTorch 2.6+ safe-global allow-list for XTTS-v2 checkpoint pickles. + if hasattr(torch, "serialization") and hasattr( + torch.serialization, "add_safe_globals" + ): + from TTS.config import shared_configs as _shared_configs + from TTS.tts.configs.xtts_config import XttsConfig as _XttsConfig + from TTS.tts.models.xtts import XttsArgs as _XttsArgs + from TTS.tts.models.xtts import XttsAudioConfig as _XttsAudioConfig + + for _cls in ( + _shared_configs.BaseDatasetConfig, + _XttsConfig, + _XttsArgs, + _XttsAudioConfig, + ): + try: + torch.serialization.add_safe_globals([_cls]) + except Exception: + pass + + _orig = tts_io.load_fsspec + + def _patched(model_path: str, map_location: str = "cpu", **kwargs): + kwargs.setdefault("weights_only", False) + return _orig(model_path, map_location=map_location, **kwargs) + + tts_io.load_fsspec = _patched # type: ignore[assignment] + + +@_register_backend("xtts_v2") +class XTTSv2Engine(TTSEngine): + """XTTS-v2 backend supporting English and Russian zero-shot voice cloning.""" + + sample_rate: int = _XTTS_SAMPLE_RATE + + def __init__( + self, + model_name: str | None = None, + sample_rate: int | None = None, + device: str | None = None, + gpu: bool | None = None, + ): + super().__init__() + self.model_name = model_name or settings.tts_model_name + self.sample_rate = sample_rate or settings.tts_sample_rate + self.device = device or settings.device + self.gpu = (gpu if gpu is not None else True) and self.device != "cpu" + + self._model: "TTS | None" = None + + def _is_supported_language(self, language: str) -> bool: + # XTTS-v2 supported languages: en, es, fr, de, it, pt, pl, tr, ru, nl, cs, + # ar, zh-cn, hu, ko, ja, hi. We currently expose en and ru. + return language.lower() in {"en", "ru"} + + def load(self) -> None: + if not XTTS_AVAILABLE: + raise RuntimeError( + "coqui TTS package is not installed. Install it: pip install TTS" + ) + + logger.info("Loading XTTS-v2 model {} ...", self.model_name) + _patch_weights_only() + + # Environment flag required by Coqui to download XTTS. + import os + + os.environ.setdefault("COQUI_TOS_AGREED", "1") + + self._model = TTS(self.model_name, gpu=self.gpu) + # Force the requested device if not already there. + if self.device.startswith("cuda"): + self._model = self._model.to("cuda") + elif self.device == "cpu": + self._model = self._model.to("cpu") + + self.sample_rate = self._model.synthesizer.output_sample_rate or _XTTS_SAMPLE_RATE + logger.info("XTTS-v2 loaded. Output sample rate: {}", self.sample_rate) + + async def warm_up(self) -> None: + if self._model is None: + self.load() + logger.info("Warm-up skipped: provide a reference audio before warm-up.") + + def _normalize_language(self, language: str) -> str: + language = language.lower() + if language.startswith("ru"): + return "ru" + if language.startswith("en"): + return "en" + return language + + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ref_text: str | None = None, + ) -> np.ndarray: + if self._model is None: + self.load() + + assert self._model is not None + + if isinstance(ref_audio_path, str): + ref_audio_path = Path(ref_audio_path) + + if ref_audio_path is None: + raise ValueError("XTTS-v2 requires a reference audio file (voice_ref).") + if not ref_audio_path.exists(): + raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}") + + lang = self._normalize_language(language) + if not self._is_supported_language(lang): + raise ValueError( + f"Language '{language}' is not supported by XTTS-v2. " + "Currently enabled languages: en, ru." + ) + + # Speed is not a direct argument for XTTS.tts_to_file, but we can resample + # the audio to approximate it. We keep the produced samples and stretch them. + out_path = "/tmp/opencode/xtts_synth_tmp.wav" + self._model.tts_to_file( + text=text, + speaker_wav=str(ref_audio_path), + language=lang, + file_path=out_path, + ) + + wav, sr = torchaudio.load(out_path) + wav = wav.mean(dim=0).numpy().astype(np.float32) + + # Resample if the model returned a different rate. + if sr != self.sample_rate: + resampler = torchaudio.transforms.Resample(sr, self.sample_rate) + wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32) + + # Normalize to [-1, 1]. + peak = np.max(np.abs(wav)) + if peak > 1.0: + wav = wav / peak + + # Approximate speed change by resampling. Simple resampling changes + # pitch slightly; for small adjustments around 1.0 it is acceptable, + # but for larger factors we keep the duration change while reducing + # pitch shift artifacts via phase vocoder would be too expensive. + # The default speed comes from settings.tts_speed and can be overridden + # per-request via the WebSocket config/speak messages. + effective_speed = speed if speed is not None else settings.tts_speed + if effective_speed != 1.0 and effective_speed > 0: + new_rate = int(self.sample_rate * (1.0 / effective_speed)) + resampler = torchaudio.transforms.Resample(self.sample_rate, new_rate) + wav = resampler(torch.from_numpy(wav)).numpy().astype(np.float32) + + if wav.ndim == 0: + wav = np.zeros(1, dtype=np.float32) + elif wav.ndim > 1: + wav = wav.squeeze() + + return wav diff --git a/tests/test_fish_speech_backend.py b/tests/test_fish_speech_backend.py new file mode 100644 index 0000000..ec356e0 --- /dev/null +++ b/tests/test_fish_speech_backend.py @@ -0,0 +1,233 @@ +"""Unit tests for the Fish Speech 1.5 backend with mocked inference engine.""" + +from pathlib import Path + +import numpy as np +import pytest +import torch + +from voice_tts.config import settings +from voice_tts.tts.fish_speech_backend import FishSpeechEngine + + +@pytest.fixture +def mocked_fish(monkeypatch, tmp_path): + """Return a FishSpeechEngine with all heavy model imports stubbed.""" + + fake_root = tmp_path / "fish-speech" + fake_root.mkdir() + (fake_root / "fish_speech").mkdir() + (fake_root / "fish_speech" / "__init__.py").write_text("") + (fake_root / "fish_speech" / "inference_engine.py").write_text( + "class TTSInferenceEngine:\n pass\n" + ) + (fake_root / "fish_speech" / "inference_engine").mkdir() + (fake_root / "fish_speech" / "inference_engine" / "__init__.py").write_text( + "class TTSInferenceEngine:\n pass\n" + ) + (fake_root / "fish_speech" / "models").mkdir() + (fake_root / "fish_speech" / "models" / "__init__.py").write_text("") + (fake_root / "fish_speech" / "models" / "text2semantic").mkdir() + (fake_root / "fish_speech" / "models" / "text2semantic" / "__init__.py").write_text("") + (fake_root / "fish_speech" / "models" / "text2semantic" / "inference.py").write_text( + "def launch_thread_safe_queue(*args, **kwargs):\n" + " return object()\n" + ) + vqgan = fake_root / "fish_speech" / "models" / "vqgan" + vqgan.mkdir() + (vqgan / "__init__.py").write_text("") + (vqgan / "inference.py").write_text( + "def load_model(*args, **kwargs):\n" + " class FakeDecoder:\n" + " spec_transform = type('T', (), {'sample_rate': 44100})()\n" + " def encode(self, audios, lengths):\n" + " return [None]\n" + " return FakeDecoder()\n" + ) + (fake_root / "fish_speech" / "utils").mkdir() + (fake_root / "fish_speech" / "utils" / "__init__.py").write_text("") + (fake_root / "fish_speech" / "utils" / "schema.py").write_text( + "from pydantic import BaseModel\n" + "from typing import List\n" + "class ServeReferenceAudio(BaseModel):\n" + " audio: bytes\n" + " text: str\n" +"class ServeTTSRequest(BaseModel):\n" +" text: str\n" +" references: List[ServeReferenceAudio] = []\n" +" seed: int | None = None\n" +" top_p: float = 0.7\n" +" temperature: float = 0.7\n" +" repetition_penalty: float = 1.2\n" +" max_new_tokens: int = 1024\n" +" chunk_length: int = 200\n" +" use_memory_cache: str = 'on'\n" + ) + + # Stub the heavy model functions used inside load(). + class FakeEngine: + def __init__(self, *args, **kwargs): + pass + + def inference(self, req): + sample_rate = 44_100 + audio = np.linspace(-0.9, 0.9, sample_rate, dtype=np.float32) + yield type( + "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)} + )() + + # Point sys.path at fake root temporarily so imports succeed. + monkeypatch.syspath_prepend(str(fake_root)) + + from fish_speech import inference_engine as _ie + + monkeypatch.setattr(_ie, "TTSInferenceEngine", FakeEngine) + + ref_wav = tmp_path / "ref.wav" + ref_wav.write_bytes(b"RIFF" + b"\x00" * 40) + ref_lab = tmp_path / "ref.lab" + ref_lab.write_text("Reference transcript from lab file.", encoding="utf-8") + + engine = FishSpeechEngine( + checkpoint_path=tmp_path / "checkpoint", + source_root=fake_root, + device="cpu", + precision=torch.float32, + chunk_length=200, + ) + engine.load() + return engine, ref_wav, ref_lab + + +@pytest.mark.asyncio +async def test_synthesize_returns_audio(mocked_fish): + engine, ref_wav, _ = mocked_fish + audio = await engine.synthesize( + text="Hello world.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ) + assert audio.ndim == 1 + assert len(audio) == engine.sample_rate + assert audio.dtype == np.float32 + assert np.max(np.abs(audio)) <= 1.1 + + +@pytest.mark.asyncio +async def test_unsupported_language_raises(mocked_fish): + engine, ref_wav, _ = mocked_fish + with pytest.raises(ValueError, match="not supported"): + await engine.synthesize( + text="Bonjour.", + ref_audio_path=ref_wav, + language="fr", + speed=1.0, + emotion="neutral", + ) + + +@pytest.mark.asyncio +async def test_missing_reference_raises(mocked_fish): + engine, _, _ = mocked_fish + with pytest.raises(FileNotFoundError): + await engine.synthesize( + text="Hello.", + ref_audio_path=Path("/nonexistent/ref.wav"), + language="en", + speed=1.0, + emotion="neutral", + ) + + +@pytest.mark.asyncio +async def test_ref_text_fallback_order(monkeypatch, mocked_fish): + engine, ref_wav, ref_lab = mocked_fish + + captured = {} + + def _spy_inference(req): + captured["text"] = req.references[0].text + sample_rate = 44_100 + audio = np.zeros(sample_rate, dtype=np.float32) + yield type( + "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)} + )() + + # Replace the already-created engine's inference method with a spy. + engine._engine.inference = _spy_inference + + # 1. Explicit ref_text wins. + await engine.synthesize( + text="Hi.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ref_text="explicit", + ) + assert captured["text"] == "explicit" + + # 2. .lab file next to reference is used when no explicit text. + await engine.synthesize( + text="Hi.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ) + assert captured["text"] == "Reference transcript from lab file." + + # 3. settings.default_ref_text wins over placeholder. + monkeypatch.setattr(settings, "default_ref_text", "settings default") + ref_lab.unlink(missing_ok=True) + await engine.synthesize( + text="Hi.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ) + assert captured["text"] == "settings default" + + # 4. Placeholder when nothing else is available. + monkeypatch.setattr(settings, "default_ref_text", None) + await engine.synthesize( + text="Hi.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ) + assert captured["text"] == "Hello, this is a reference voice recording." + + +@pytest.mark.asyncio +async def test_speed_resampling_changes_length(mocked_fish): + engine, ref_wav, _ = mocked_fish + slow = await engine.synthesize( + text="Hello.", + ref_audio_path=ref_wav, + language="en", + speed=1.5, + emotion="neutral", + ) + fast = await engine.synthesize( + text="Hello.", + ref_audio_path=ref_wav, + language="en", + speed=0.8, + emotion="neutral", + ) + assert len(fast) > len(slow) + # Speed 1.5 should produce fewer samples than original; 0.8 should produce more. + normal = await engine.synthesize( + text="Hello.", + ref_audio_path=ref_wav, + language="en", + speed=1.0, + emotion="neutral", + ) + assert len(slow) < len(normal) + assert len(fast) > len(normal)