diff --git a/.env.example b/.env.example new file mode 100644 index 0000000..bc9729f --- /dev/null +++ b/.env.example @@ -0,0 +1,20 @@ +# Local TTS pipeline configuration example +# Copy this file to .env and adjust values. + +HOST=0.0.0.0 +PORT=8765 +LOG_LEVEL=INFO + +TTS_BACKEND=f5_tts +# TTS_MODEL_PATH=models/f5-tts/model.pt +# TTS_VOCAB_PATH=models/f5-tts/vocab.txt +TTS_SAMPLE_RATE=24000 + +VOICES_DIR=voices + +MIN_SEGMENT_LENGTH=30 +MAX_SEGMENT_LENGTH=200 +MAX_BUFFER_WAIT_MS=500 + +DEVICE=cuda +DTYPE=bfloat16 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..b59e4c4 --- /dev/null +++ b/.gitignore @@ -0,0 +1,13 @@ +.env +.venv/ +__pycache__/ +*.pyc +models/* +voices/*.wav +voices/*.mp3 +!voices/.gitkeep +.pytest_cache/ +*.egg-info/ +dist/ +build/ +.DS_Store diff --git a/README.md b/README.md index 2071957..e53e6c4 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,57 @@ -voice -=============== +# Voice TTS + +Local GPU-powered real-time text-to-speech pipeline with a WebSocket API, designed to voice AI agents that stream text in chunks. + +## Features + +- **Streaming input**: accepts partial text as it is generated by the LLM/agent. +- **Streaming output**: returns PCM audio chunks over WebSocket as soon as they are synthesized. +- **Voice cloning**: single speaker cloned from reference audio, with optional per-emotion references. +- **Interrupt / stop**: agent can immediately stop playback when the user interrupts the AI. +- **Emotion control**: switch emotion on the fly (requires matching reference audio or supported backend). +- **Local GPU**: runs entirely on your NVIDIA GPU (RTX 3090 / 3060 compatible). + +## Project status + +- Working WebSocket server with streaming text, audio streaming, and instant stop/resume. +- F5-TTS backend installed, GPU-ready, and producing real audio (`models/F5TTS_v1_Base/` downloaded). +- Dummy backend available for fast offline tests. +- Startup warm-up caches the default reference and primes CUDA. +- Next: multilingual evaluation, latency optimization, and client examples. + +## Quick start + +```bash +# Create virtual environment (Python 3.10-3.12 recommended) +python3.11 -m venv .venv +source .venv/bin/activate + +# Install PyTorch with CUDA 12.6 support first +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126 + +# Install remaining dependencies +pip install -r requirements.txt + +# (Optional) Download the F5-TTS model beforehand +python scripts/download_f5_tts.py --model F5TTS_v1_Base + +# Run the server +python -m voice_tts.main + +# Or run in dummy test mode +TTS_BACKEND=dummy python -m voice_tts.main +``` + +Server will listen on `ws://localhost:8765/ws`. + +## WebSocket protocol + +See full documentation in [`docs/03_websocket_protocol.md`](docs/03_websocket_protocol.md). + +## Architecture and roadmap + +- [`docs/01_overview.md`](docs/01_overview.md) +- [`docs/02_architecture.md`](docs/02_architecture.md) +- [`docs/04_roadmap.md`](docs/04_roadmap.md) +- [`docs/05_usage.md`](docs/05_usage.md) +- [`docs/06_technical_notes.md`](docs/06_technical_notes.md) diff --git a/docs/01_overview.md b/docs/01_overview.md new file mode 100644 index 0000000..57591e9 --- /dev/null +++ b/docs/01_overview.md @@ -0,0 +1,79 @@ +# Обзор проекта Voice TTS + +## Цель + +Построить локальный GPU-пайплайн преобразования текста в речь (TTS) с интонацией, управляемый через WebSocket API. Система предназначена для озвучки ИИ-агента: текст поступает от LLM кусками в процессе генерации, а аудио отдаётся клиенту в реальном времени. + +## Ключевые требования + +- **Локальная работа на GPU**: весь инференс выполняется на видеокарте пользователя (RTX 3090 сейчас, RTX 3060 в перспективе). +- **Стриминг текста**: фразы от LLM приходят частями, сервер не ждёт полного текста. +- **Стриминг аудио**: синтезированные фрагменты отправляются клиенту сразу после готовности. +- **Клонирование голоса**: один спикер, задаваемый референсным аудио. +- **Эмоции**: возможность переключать эмоцию агента по сигналу от ИИ. +- **Прерывание**: агент может немедленно остановить текущее вещание, если его перебили. +- **Мультиязычность**: приоритет — русский; поддерживаются также английский, украинский, испанский, немецкий, французский. +- **Клиент**: Python-агент, сам управляющий дальнейшей маршрутизацией аудиопотока. + +## Решения, принятые на этапе планирования + +### Стек + +| Компонент | Выбор | +|-----------|-------| +| Язык | Python | +| WebSocket сервер | FastAPI + `uvicorn[standard]` | +| TTS | F5-TTS (основной), с возможным fallback на MeloTTS | +| Аудио | torchaudio, numpy | +| Конфиг | pydantic-settings | +| Логи | loguru | + +### Почему F5-TTS + +- Высокое качество и естественная интонация. +- Быстрый инференс, подходит для реалтайма. +- Поддержка клонирования голоса по референсу (zero-shot). +- Мультиязычность из коробки (en, ru, zh), с потенциалом для европейских языков. +- Помещается в VRAM RTX 3060 (12 GB). + +### Формат аудио + +- **PCM 16-bit mono, 24 kHz**, упакованный в base64. +- Низкая задержка, простая декодировка на клиенте. +- Opus-энкодинг можно добавить позже как опцию. + +## Архитектура (верхний уровень) + +``` +[Python AI Agent] + | + | text chunks over WebSocket + v +[FastAPI WebSocket Server] + | + | buffered text + v +[Text Segmenter] + | + | TTS-ready segments + v +[TTS Queue] + | + v +[GPU TTS Worker] <-- F5-TTS + | + | audio ndarray + v +[Audio Output Queue] + | + | base64 PCM chunks + v +[Python AI Agent -> audio player / sink] +``` + +## Состояние репозитория + +- Реализован серверный каркас с WebSocket API, сегментатором, сессией и управлением прерыванием. +- Подключён **F5-TTS** как основной бэкенд: модель скачивается в `models/`, загружается на GPU и готова к инференсу. +- **Dummy TTS** остаётся для тестов (`TTS_BACKEND=dummy`). +- Следующий этап — подготовка референсных аудио, тёплый старт и замеры реальной латентности. diff --git a/docs/02_architecture.md b/docs/02_architecture.md new file mode 100644 index 0000000..1feb65f --- /dev/null +++ b/docs/02_architecture.md @@ -0,0 +1,101 @@ +# Архитектура системы + +## Компоненты + +### 1. WebSocket API (`src/voice_tts/api/`) + +- `server.py` — FastAPI-приложение, endpoint `/ws`, управление сессией. +- `protocol.py` — Pydantic-модели всех входящих и исходящих сообщений. + +Отвечает за: +- Приём подключений от Python-агента. +- Парсинг сообщений. +- Жизненный цикл сессии и её graceful shutdown. + +### 2. Сессия (`src/voice_tts/session/`) + +- `state.py` — `SessionState` и `VoiceProfile`. + +Хранит: +- Текстовый буфер. +- Текущий язык, скорость, эмоцию. +- Референсные аудио для спикера и эмоций. +- Очередь аудио-выхода. +- Флаг остановки (`stop_event`). + +### 3. Сегментатор (`src/voice_tts/tts/segmenter.py`) + +`Segmenter` превращает накопленный текст в сегменты для синтеза. + +Правила: +- Разбиение по концам предложений (`.`, `。`, `!`, `?`, `;`, `:` и newline). +- Если сегмент превышает `max_length`, разрез по запятой/тире/скобкам. +- Минимальная длина сегмента (`min_length`) не применяется при принудительном `flush`. + +### 4. TTS Engine (`src/voice_tts/tts/`) + +- `engine.py` — абстрактный базовый класс `TTSEngine` и `DummyTTSEngine` для тестов. +- `f5_backend.py` — реализация на F5-TTS, включая кэширование транскрибированных референсов. + +Интерфейс: + +```python +async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, +) -> np.ndarray: + ... +``` + +### 5. Аудио (`src/voice_tts/audio/`) + +- `formats.py` — преобразование float32 -> PCM16 -> base64, генерация WAV-заголовка. + +### 6. Конфигурация (`src/voice_tts/config.py`) + +`pydantic-settings`, загружает переменные из `.env`. + +## Потоки выполнения + +``` +WebSocket receiver (async) + | + v +SessionManager + | + +---> Text buffer / Segmenter + | | + | v + | _synthesize_segment (async task per segment) + | | + | v + | TTS Engine (GPU-bound, выполняется в отдельной задаче) + | | + | v + | audio_queue.put((segment_seq, pcm)) + | + +---> _audio_sender (async worker) + | + +---> WebSocket send(audio base64) + | + +---> WebSocket send(status: segment_finished) +``` + +Важно: все обращения к CUDA-инференсу сериализованы через один TTS worker, чтобы избежать contention и OOM. + +## Управление прерыванием + +Каждая сессия содержит `asyncio.Event stop_event`. + +При получении сообщения `stop`: +1. Устанавливается `stop_event`. +2. Очищается текстовый буфер. +3. Опустошается аудио-очередь. +4. Любая текущая/последующая генерация отменяется. +5. Клиенту отправляется `status: stopped`. + +Новый текст после `stop` автоматически сбрасывает `stop_event` и начинает новую фразу. diff --git a/docs/03_websocket_protocol.md b/docs/03_websocket_protocol.md new file mode 100644 index 0000000..284bf77 --- /dev/null +++ b/docs/03_websocket_protocol.md @@ -0,0 +1,167 @@ +# WebSocket протокол + +## Канал + +`ws://HOST:PORT/ws` (по умолчанию `ws://0.0.0.0:8765/ws`). + +Сообщения — JSON в текстовых WebSocket-фреймах. + +## Сообщения от клиента (Python AI Agent) к серверу + +### `init` — инициализация сессии + +```json +{ + "type": "init", + "session_id": "uuid-or-name", + "voice_ref": "voices/default_neutral.wav", + "voice_refs": { + "neutral": "voices/default_neutral.wav", + "happy": "voices/default_happy.wav", + "sad": "voices/default_sad.wav" + }, + "language": "ru", + "speed": 1.0, + "emotion": "neutral", + "seq": 1 +} +``` + +- `voice_ref` — основной референс голоса. +- `voice_refs` — словарь референсов для каждой эмоции (опционально). +- `language` — код языка (`ru`, `en`, `ua`, `es`, `de`, `fr`, ...). +- `speed` — скорость речи, 0.5–2.0. +- `emotion` — текущая эмоция. + +### `text` — чанк текста от LLM + +```json +{ + "type": "text", + "payload": "Привет, ", + "emotion": "happy", + "seq": 2 +} +``` + +- `payload` — очередная часть сгенерированного текста. +- `emotion` — переопределить эмоцию для этого и последующих сегментов (опционально). + +### `flush` — озвучить всё, что в буфере + +```json +{ + "type": "flush", + "seq": 3 +} +``` + +Используется, когда LLM закончил генерацию и остался незавершённый текст. + +### `stop` — немедленно прервать вещание + +```json +{ + "type": "stop", + "reason": "interrupt", + "seq": 4 +} +``` + +Очищает буфер, аудио-очередь и отменяет текущую генерацию. + +### `emotion` — сменить эмоцию + +```json +{ + "type": "emotion", + "emotion": "sad", + "seq": 5 +} +``` + +Сервер подберёт референс из `voice_refs[emotion]` или основной `voice_ref`. + +### `config` — изменить параметры + +```json +{ + "type": "config", + "speed": 1.1, + "language": "en", + "seq": 6 +} +``` + +## Сообщения от сервера к клиенту + +### `status` — события жизненного цикла + +```json +{ + "type": "status", + "event": "session_ready", + "seq": 1 +} +``` + +Возможные `event`: + +- `session_ready` — сессия инициализирована. +- `segment_started` — начался синтез сегмента. +- `segment_finished` — сегмент синтезирован и аудио-чанк отправлен клиенту. +- `finished` — все сегменты из буфера обработаны. +- `stopped` — вещание прервано командой `stop`. +- `error` — произошла ошибка. + +### `audio` — аудио-чанк + +```json +{ + "type": "audio", + "format": "pcm_s16le", + "sample_rate": 24000, + "channels": 1, + "data": "", + "seq": 7 +} +``` + +### `error` — ошибка + +```json +{ + "type": "error", + "message": "Invalid message: ...", + "seq": null +} +``` + +## Порядок сообщений в типичном сценарии + +``` +Agent -> Server: init +Server -> Agent: status: session_ready + +Agent -> Server: text "Привет" +Agent -> Server: text ", как " +Agent -> Server: text "дела?" +Server -> Agent: status: segment_started +Server -> Agent: audio +Server -> Agent: status: segment_finished + +Agent -> Server: stop +Server -> Agent: status: stopped + +Agent -> Server: text "Новая фраза" +... +``` + +> Важно: `segment_finished` отправляется **после** `audio`, чтобы клиент, ожидающий +> завершения сегмента, уже получил аудио-данные и мог продолжить воспроизведение. + +## Примечания + +- `seq` — клиентский счётчик, который сервер копирует в `status` и использует как + идентификатор сегмента для `audio` (один `seq` на весь сегмент, не на каждый + фрейм). Это упрощает сопоставление ответов с запросами. diff --git a/docs/04_roadmap.md b/docs/04_roadmap.md new file mode 100644 index 0000000..d1c703a --- /dev/null +++ b/docs/04_roadmap.md @@ -0,0 +1,63 @@ +# План реализации + +## Этап 1 — Каркас сервера ✅ + +- [x] Создать структуру проекта. +- [x] Настроить `pyproject.toml`, `requirements.txt`, `.env.example`. +- [x] Реализовать WebSocket сервер на FastAPI. +- [x] Описать Pydantic-модели протокола. +- [x] Реализовать сегментатор текста. +- [x] Реализовать абстракцию TTS engine + dummy backend. +- [x] Реализовать форматирование аудио (PCM16/base64). +- [x] Реализовать сессию с управлением буфером, очередью и прерыванием. +- [x] Добавить базовые тесты сегментатора. +- [x] Документировать архитектуру и протокол. + +## Этап 2 — Интеграция F5-TTS + +- [x] Установить и настроить зависимости F5-TTS в `pyproject.toml` / `requirements.txt`. +- [x] Создать `src/voice_tts/tts/f5_backend.py`. +- [x] Подключить F5-TTS в `_BACKEND_MAP` в `server.py`. +- [x] Поддержать референсное аудио для клонирования голоса. +- [x] Автоматическая транскрипция референса, если `ref_text` не задан. +- [x] Загрузить модель `F5TTS_v1_Base` в `models/F5TTS_v1_Base/`. +- [x] Проверить загрузку модели на GPU. +- [x] Проверить настоящий инференс F5-TTS на GPU. +- [x] Добавить warm-up при старте сервера. +- [x] Замерить реальную латентность на GPU. +- [ ] Реализовать предобработку текста и токенизацию для мультиязычности. + +## Этап 3 — Эмоции и голос + +- [ ] Определить набор поддерживаемых эмоций (`neutral`, `happy`, `sad`, `angry`, `surprised`, `whisper`). +- [ ] Подготовить/сгенерировать референсы для каждой эмоции от одного спикера. +- [ ] При смене эмоции автоматически выбирать соответствующий референс. +- [ ] (Опционально) Добавить текстовые emotion prompts для усиления интонации. + +## Этап 4 — Оптимизации + +- [ ] Предварительная подготовка следующего сегмента (pre-fetch) параллельно с текущей генерацией. +- [ ] Batching коротких сегментов, если они накопились. +- [ ] Кроссфейд между аудио-сегментами для бесшовного воспроизведения. +- [ ] Кэширование эмбеддингов референсных аудио. +- [ ] Профилирование и замер Time-To-First-Byte (TTFB) и задержки между сегментами. + +## Этап 5 — Тестирование и совместимость + +- [x] Запустить и замерить производительность на RTX 3090. +- [ ] Проверить работу на RTX 3060 (VRAM/скорость). +- [ ] Покрыть ключевые сценарии тестами: + - стриминг текста, + - flush, + - stop/прерывание, + - смена эмоции, + - мультиязычные фразы. +- [ ] Написать простого Python-клиента для демонстрации. +- [ ] Добавить в README инструкции по развёртыванию. + +## Этап 6 — Продвинутые возможности + +- [ ] Опциональный Opus-кодек для экономии трафика. +- [ ] REST endpoint для синхронного TTS (для простых случаев). +- [ ] Автоматический выбор устройства (`cuda` / `cpu`) и dtype (`bfloat16` / `float16`). +- [ ] Graceful degradation: fallback на CPU, если GPU недоступна. diff --git a/docs/05_usage.md b/docs/05_usage.md new file mode 100644 index 0000000..e87731d --- /dev/null +++ b/docs/05_usage.md @@ -0,0 +1,143 @@ +# Использование и развёртывание + +## Установка + +> Рекомендуется Python 3.11. Python 3.14+ пока не имеет совместимых wheel для +> `torch` / `f5-tts`, поэтому используйте 3.10–3.12. + +```bash +# Клонировать / перейти в директорию проекта +cd voice + +# Создать виртуальное окружение +python3.11 -m venv .venv +source .venv/bin/activate # Windows: .venv\Scripts\activate + +# Установить PyTorch с CUDA 12.6 (обязательно первым) +pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu126 + +# Установить остальные зависимости +pip install -r requirements.txt + +# (Опционально) скопировать настройки +cp .env.example .env +``` + +## Запуск сервера + +```bash +# Основной режим: F5-TTS на GPU +python -m voice_tts.main + +# Тестовый режим без модели +TTS_BACKEND=dummy python -m voice_tts.main +``` + +Сервер поднимется на `ws://localhost:8765/ws`. + +## Проверка работоспособности + +```bash +curl http://localhost:8765/health +# {"status":"ok","backend":"f5_tts"} +``` + +## Пример Python-клиента + +```python +import asyncio +import base64 +import json +import websockets + + +async def main(): + uri = "ws://localhost:8765/ws" + async with websockets.connect(uri) as ws: + await ws.send(json.dumps({ + "type": "init", + "session_id": "demo", + "voice_ref": "voices/default_neutral.wav", + "language": "ru", + "speed": 1.0, + "emotion": "neutral", + "seq": 1, + })) + + for chunk in ["Привет, ", "как ", "дела?"]: + await ws.send(json.dumps({ + "type": "text", + "payload": chunk, + "seq": 2, + })) + await asyncio.sleep(0.2) + + await ws.send(json.dumps({"type": "flush", "seq": 3})) + + while True: + msg = json.loads(await ws.recv()) + print(msg) + if msg["type"] == "audio": + pcm = base64.b64decode(msg["data"]) + # отправить pcm на воспроизведение + if msg["type"] == "status" and msg["event"] == "stopped": + break + + +if __name__ == "__main__": + asyncio.run(main()) +``` + +## Настройка через переменные окружения (.env) + +| Переменная | Описание | По умолчанию | +|------------|----------|--------------| +| `HOST` | Хост сервера | `0.0.0.0` | +| `PORT` | Порт сервера | `8765` | +| `LOG_LEVEL` | Уровень логирования | `INFO` | +| `TTS_BACKEND` | Бэкенд (`dummy` / `f5_tts`) | `f5_tts` | +| `TTS_SAMPLE_RATE` | Частота дискретизации | `24000` | +| `VOICES_DIR` | Директория с референсами | `voices` | +| `MIN_SEGMENT_LENGTH` | Мин. длина сегмента | `30` | +| `MAX_SEGMENT_LENGTH` | Макс. длина сегмента | `200` | +| `MAX_BUFFER_WAIT_MS` | Макс. ожидание перед flush | `500` | +| `DEVICE` | `cuda` или `cpu` | `cuda` | +| `DTYPE` | `bfloat16` / `float16` | `bfloat16` | + +## Загрузка модели + +Если `TTS_BACKEND=f5_tts` (по умолчанию), при первом старте сервер автоматически +скачает нужный checkpoint из Hugging Face в кэш. Чтобы скачать модель +заранее: + +```bash +python scripts/download_f5_tts.py --model F5TTS_v1_Base +``` + +Поддерживаемые варианты: `F5TTS_v1_Base`, `F5TTS_Base`, `E2TTS_Base`. +Модель сохраняется в `models/F5TTS_v1_Base/`. + +## Тесты + +```bash +# Быстрые тесты без загрузки F5-TTS +TTS_BACKEND=dummy python -m pytest tests/ -v +``` + +## Референсные аудио + +Поместите файлы в директорию `voices/`: + +``` +voices/ +├── default_neutral.wav +├── default_happy.wav +├── default_sad.wav +└── ... +``` + +Требования к референсу: +- WAV или другой формат, читаемый `torchaudio`. +- Моно, 16+ кГц. +- Длина 3–10 секунд (для F5-TTS). +- Чистая речь одного спикера без фонового шума. diff --git a/docs/06_technical_notes.md b/docs/06_technical_notes.md new file mode 100644 index 0000000..5cb8339 --- /dev/null +++ b/docs/06_technical_notes.md @@ -0,0 +1,66 @@ +# Технические заметки + +## Почему base64 PCM вместо WAV/Opus + +- **Низкая задержка**: не нужно ждать формирования заголовка WAV или энкодинга Opus. +- **Простота клиента**: Python-агент может напрямую скормить PCM в `pyaudio` / `sounddevice` / `alsaaudio`. +- **Компромисс**: трафик больше, чем с Opus, но в локальной сети это некритично. +- Opus можно добавить позже как опциональный формат. + +## Почему один TTS worker + +- CUDA context не любит параллельные вызовы из разных потоков. +- Один worker с `asyncio.Queue` гарантирует последовательный доступ к GPU и предсказуемое потребление VRAM. +- Если в будущем понадобится масштабирование — можно запустить несколько независимых инстансов. + +## Управление остановкой + +`stop_event` — центральный механизм: + +- `SessionState.stop()` устанавливает флаг. +- Все длительные операции (TTS, отправка аудио) могут его проверять. +- После `stop` новое сообщение `text` автоматически сбрасывает флаг и начинает новую фразу. + +## Сегментация + +Цель — найти баланс между: +- задержкой (короткие сегменты синтезируются быстрее), +- качеством (TTS лучше звучит на целых предложениях), +- реалтаймом (не ждём слишком долго). + +Параметры по умолчанию: +- `min_segment_length = 30` +- `max_segment_length = 200` +- `max_buffer_wait_ms = 500` + +Для русского языка предложения обычно короче, чем на английском, поэтому `max_length` выбран консервативно. + +## F5-TTS: особенности интеграции + +- Используется готовая модель `F5TTS_v1_Base` через официальный пакет `f5-tts`. +- Модель автоматически загружается из Hugging Face при первом вызове `load()`; + можно скачать заранее скриптом `scripts/download_f5_tts.py`. +- Референсное аудио транскрибируется автоматически, если `ref_text` не задан. +- Результат транскрипции и путь к обработанному аудио кэшируются по паре + `(путь, эмоция)`, чтобы не повторять работу при смене сегментов. +- Скорость регулируется параметром `speed` инференса F5-TTS. +- Текущая целевая частота дискретизации — **24 kHz**; модель сама + передискретирует референс при необходимости. +- Инференс выполняется в отдельном потоке (`asyncio.to_thread`) с + `asyncio.Lock`, чтобы не блокировать event loop и не допускать + параллельных CUDA-вызовов. Все `send_text` сериализуются через + отдельный `_send_lock`, чтобы избежать deadlock внутри `websockets`. + +### Замеры задержки (RTX 3090, Python 3.11, CUDA 12.6) + +- Первый запуск без warm-up: ~5–6 с, большая часть уходит на Whisper-транскрипцию референса. +- После warm-up с кэшированным референсом: первый audio-chunk ~1.1 с на коротком сегменте. +- 4 сегмента подряд: первый finished ~1.1 с, последний ~4.4 с. +- `stop` + возобновление работает без переподключения WebSocket. + +## Потенциальные проблемы + +- F5-TTS может не идеально произносить украинский / европейские языки из коробки — возможно потребуется fine-tuning или fallback. Сейчас протокол и сегментатор не ограничивают язык; качество зависит от самой модели. +- RTX 3060 (12 GB) подойдёт для базовой модели, но batch-size и длина референса придётся ограничивать. +- Быстрый `stop` во время CUDA kernel не прервёт уже запущенный kernel, но предотвратит отправку результата. +- `main.py` создаёт engine до старта uvicorn; при `TTS_BACKEND=f5_tts` первый запуск может занять десяток секунд из-за загрузки модели и vocos. diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..761e770 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "voice-tts" +version = "0.1.0" +description = "Local GPU-powered real-time TTS pipeline with WebSocket API for AI agents" +requires-python = ">=3.10" +dependencies = [ + "fastapi>=0.115.0", + "uvicorn[standard]>=0.32.0", + "pydantic>=2.9.0", + "pydantic-settings>=2.6.0", + "numpy>=1.26.0", + "torch>=2.4.0", + "torchaudio>=2.4.0", + "loguru>=0.7.0", + "f5-tts>=1.1.0", + "soundfile>=0.12.0", + "pydub>=0.25.0", + "tqdm>=4.66.0", + "omegaconf>=2.3.0", + "hydra-core>=1.3.0", + "safetensors>=0.4.0", + "transformers>=4.45.0", + "huggingface-hub>=0.25.0", + "vocos>=0.1.0", + "matplotlib>=3.9.0", + "unidecode>=1.3.0", + "tomli>=2.0.0 ; python_version < '3.11'", + "cached-path>=1.6.0", + "accelerate>=0.34.0", + "sentencepiece>=0.2.0", + "bitsandbytes>=0.44.0", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.3.0", + "pytest-asyncio>=0.24.0", + "httpx>=0.27.0", + "websockets>=13.0", +] + +[project.scripts] +voice-tts = "voice_tts.main:main" + +[tool.setuptools.packages.find] +where = ["src"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +testpaths = ["tests"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..0fc131a --- /dev/null +++ b/requirements.txt @@ -0,0 +1,33 @@ +# Core server +fastapi>=0.115.0 +uvicorn[standard]>=0.32.0 +pydantic>=2.9.0 +pydantic-settings>=2.6.0 +numpy>=1.26.0 +loguru>=0.7.0 +websockets>=14.0 +httpx>=0.27.0 + +# PyTorch with CUDA 12.x +--index-url https://download.pytorch.org/whl/cu126 +torch>=2.4.0 +torchaudio>=2.4.0 + +# F5-TTS and its dependencies +f5-tts>=1.1.0 +soundfile>=0.12.0 +pydub>=0.25.0 +tqdm>=4.66.0 +omegaconf>=2.3.0 +hydra-core>=1.3.0 +safetensors>=0.4.0 +transformers>=4.45.0 +huggingface-hub>=0.25.0 +vocos>=0.1.0 +matplotlib>=3.9.0 +unidecode>=1.3.0 +tomli>=2.0.0; python_version < "3.11" +cached-path>=1.6.0 +accelerate>=0.34.0 +sentencepiece>=0.2.0 +bitsandbytes>=0.44.0 diff --git a/scripts/download_f5_tts.py b/scripts/download_f5_tts.py new file mode 100644 index 0000000..9f43947 --- /dev/null +++ b/scripts/download_f5_tts.py @@ -0,0 +1,62 @@ +"""Download F5-TTS model files from Hugging Face to the local models/ directory.""" + +from pathlib import Path + +from huggingface_hub import hf_hub_download +from loguru import logger + + +MODELS_DIR = Path("models") +REPO_ID = "SWivid/F5-TTS" + + +def download_model( + model_name: str = "F5TTS_v1_Base", + repo_id: str = REPO_ID, + local_dir: Path = MODELS_DIR, +) -> None: + """Download the model checkpoint and vocab for the requested F5-TTS variant.""" + local_dir.mkdir(parents=True, exist_ok=True) + + if model_name == "F5TTS_v1_Base": + filename = "F5TTS_v1_Base/model_1250000.safetensors" + elif model_name == "F5TTS_Base": + filename = "F5TTS_Base/model_1200000.safetensors" + elif model_name == "E2TTS_Base": + filename = "E2TTS_Base/model_1200000.safetensors" + else: + raise ValueError(f"Unsupported model: {model_name}") + + logger.info("Downloading {} from {} ...", filename, repo_id) + path = hf_hub_download( + repo_id=repo_id, + filename=filename, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + logger.info("Model saved to {}", path) + + vocab_filename = f"{model_name}/vocab.txt" + logger.info("Downloading vocab {} from {} ...", vocab_filename, repo_id) + vocab_path = hf_hub_download( + repo_id=repo_id, + filename=vocab_filename, + local_dir=local_dir, + local_dir_use_symlinks=False, + ) + logger.info("Vocab saved to {}", vocab_path) + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Download F5-TTS model files") + parser.add_argument( + "--model", + default="F5TTS_v1_Base", + choices=["F5TTS_v1_Base", "F5TTS_Base", "E2TTS_Base"], + ) + parser.add_argument("--local-dir", default=str(MODELS_DIR)) + args = parser.parse_args() + + download_model(args.model, local_dir=Path(args.local_dir)) diff --git a/src/voice_tts/api/protocol.py b/src/voice_tts/api/protocol.py new file mode 100644 index 0000000..c5fdf33 --- /dev/null +++ b/src/voice_tts/api/protocol.py @@ -0,0 +1,110 @@ +"""WebSocket message protocol models.""" + +from typing import Any, Literal +from pydantic import BaseModel, Field + + +class ClientMessage(BaseModel): + """Base class for incoming WebSocket messages from the agent.""" + + type: str + seq: int | None = None + + +class InitMessage(ClientMessage): + """Initialize or reconfigure a session: voice reference, language, speed, emotion.""" + + type: Literal["init"] = "init" + session_id: str | None = None + voice_ref: str | None = None + voice_refs: dict[str, str] | None = None + language: str = "ru" + speed: float = Field(1.0, ge=0.5, le=2.0) + emotion: str = "neutral" + + +class TextMessage(ClientMessage): + """Streaming text chunk from the LLM/agent.""" + + type: Literal["text"] = "text" + payload: str + emotion: str | None = None + + +class FlushMessage(ClientMessage): + """Force-speak everything currently buffered.""" + + type: Literal["flush"] = "flush" + + +class StopMessage(ClientMessage): + """Interrupt and clear everything: buffer, current generation, audio queue.""" + + type: Literal["stop"] = "stop" + reason: str = "interrupt" + + +class EmotionMessage(ClientMessage): + """Change emotion for upcoming segments.""" + + type: Literal["emotion"] = "emotion" + emotion: str + + +class ConfigMessage(ClientMessage): + """Change runtime parameters (speed, etc.).""" + + type: Literal["config"] = "config" + speed: float | None = Field(None, ge=0.5, le=2.0) + language: str | None = None + + +ClientMessageUnion = InitMessage | TextMessage | FlushMessage | StopMessage | EmotionMessage | ConfigMessage + + +# --------------------------------------------------------------------------- +# Server -> client messages +# --------------------------------------------------------------------------- + + +class ServerMessage(BaseModel): + """Base class for outgoing WebSocket messages.""" + + type: str + seq: int | None = None + + +class AudioMessage(ServerMessage): + """Audio chunk encoded as base64 PCM.""" + + type: Literal["audio"] = "audio" + format: Literal["pcm_s16le"] = "pcm_s16le" + sample_rate: int + channels: int = 1 + data: str # base64 + + +class StatusMessage(ServerMessage): + """Lifecycle/status events.""" + + type: Literal["status"] = "status" + event: Literal[ + "session_ready", + "segment_started", + "segment_finished", + "finished", + "stopped", + "error", + ] + reason: str | None = None + extra: dict[str, Any] | None = None + + +class ErrorMessage(ServerMessage): + """Error notification.""" + + type: Literal["error"] = "error" + message: str + + +ServerMessageUnion = AudioMessage | StatusMessage | ErrorMessage diff --git a/src/voice_tts/api/server.py b/src/voice_tts/api/server.py new file mode 100644 index 0000000..da783a6 --- /dev/null +++ b/src/voice_tts/api/server.py @@ -0,0 +1,386 @@ +"""WebSocket server and session lifecycle.""" + +import asyncio +from contextlib import asynccontextmanager +import json +from pathlib import Path + +from fastapi import FastAPI, WebSocket, WebSocketDisconnect +from loguru import logger + +from voice_tts.api.protocol import ( + AudioMessage, + ClientMessageUnion, + ConfigMessage, + EmotionMessage, + ErrorMessage, + FlushMessage, + InitMessage, + ServerMessageUnion, + StatusMessage, + StopMessage, + TextMessage, +) +from voice_tts.audio.formats import float_to_pcm16, pcm16_to_base64 +from voice_tts.config import settings +from voice_tts.session.state import SessionState, VoiceProfile +from voice_tts.tts.engine import DummyTTSEngine, TTSEngine +from voice_tts.tts.f5_backend import F5TTSEngine +from voice_tts.tts.segmenter import Segmenter + + +# Supported TTS backends +_BACKEND_MAP: dict[str, type[TTSEngine]] = { + "dummy": DummyTTSEngine, + "f5_tts": F5TTSEngine, +} + + +class SessionManager: + """Manages the lifecycle of a single WebSocket TTS session.""" + + def __init__(self, websocket: WebSocket, engine: TTSEngine): + self.ws = websocket + self.engine = engine + self.segmenter = Segmenter( + min_length=settings.min_segment_length, + max_length=settings.max_segment_length, + ) + self.state = SessionState(session_id="") + self._running = True + self._tasks: list[asyncio.Task] = [] + self._send_lock = asyncio.Lock() + self._synth_lock = asyncio.Lock() + + async def run(self) -> None: + await self.ws.accept() + logger.info("WebSocket connection accepted") + + # Start audio sender worker + sender_task = asyncio.create_task(self._audio_sender()) + self._tasks.append(sender_task) + + try: + while self._running: + raw = await self.ws.receive_text() + try: + data = json.loads(raw) + msg = self._parse_message(data) + except Exception as exc: + await self._send(ErrorMessage(message=f"Invalid message: {exc}")) + continue + + await self._handle_message(msg) + except WebSocketDisconnect: + logger.info("Client disconnected") + finally: + self._stop_all() + for task in self._tasks: + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + + def _parse_message(self, data: dict) -> ClientMessageUnion: + msg_type = data.get("type") + if msg_type == "init": + return InitMessage(**data) + if msg_type == "text": + return TextMessage(**data) + if msg_type == "flush": + return FlushMessage(**data) + if msg_type == "stop": + return StopMessage(**data) + if msg_type == "emotion": + return EmotionMessage(**data) + if msg_type == "config": + return ConfigMessage(**data) + raise ValueError(f"Unknown message type: {msg_type}") + + async def _handle_message(self, msg: ClientMessageUnion) -> None: + if isinstance(msg, InitMessage): + await self._handle_init(msg) + elif isinstance(msg, TextMessage): + await self._handle_text(msg) + elif isinstance(msg, FlushMessage): + await self._handle_flush(msg) + elif isinstance(msg, StopMessage): + await self._handle_stop(msg) + elif isinstance(msg, EmotionMessage): + await self._handle_emotion(msg) + elif isinstance(msg, ConfigMessage): + await self._handle_config(msg) + + async def _handle_init(self, msg: InitMessage) -> None: + self.state.session_id = msg.session_id or "default" + self.state.language = msg.language + self.state.speed = msg.speed + self.state.emotion = msg.emotion + self.state.clear_buffer() + self.state.reset_stop() + + voice = VoiceProfile() + if msg.voice_ref: + voice.default_ref = Path(msg.voice_ref) + if msg.voice_refs: + voice.emotion_refs = { + emotion: Path(path) for emotion, path in msg.voice_refs.items() + } + self.state.voice = voice + + await self._send(StatusMessage(event="session_ready", seq=msg.seq)) + logger.info( + "Session initialized: id={} language={} speed={} emotion={}", + self.state.session_id, + self.state.language, + self.state.speed, + self.state.emotion, + ) + + async def _handle_text(self, msg: TextMessage) -> None: + if self.state.is_stopped(): + # After a stop, new text implicitly resets the stop flag + self.state.reset_stop() + + self.state.text_buffer += msg.payload + if msg.emotion: + self.state.emotion = msg.emotion + + from voice_tts.tts.utils import preprocess_text_for_tts + + self.state.text_buffer = preprocess_text_for_tts(self.state.text_buffer) + remaining, segments = self.segmenter.feed(self.state.text_buffer) + self.state.text_buffer = remaining + + for segment in segments: + asyncio.create_task(self._synthesize_segment(segment.text, msg.seq)) + + # If this text chunk itself ends a sentence and the buffer is too short, + # we might still be holding it. Force a flush so short complete sentences + # (e.g. "дела?") are not stuck waiting for more input. + if _looks_like_complete_phrase(msg.payload) and self.state.text_buffer.strip(): + segments = self.segmenter.flush(self.state.text_buffer) + self.state.text_buffer = "" + for segment in segments: + asyncio.create_task(self._synthesize_segment(segment.text, msg.seq)) + + async def _handle_flush(self, msg: FlushMessage) -> None: + segments = self.segmenter.flush(self.state.text_buffer) + self.state.text_buffer = "" + for segment in segments: + asyncio.create_task(self._synthesize_segment(segment.text, msg.seq)) + + async def _handle_stop(self, msg: StopMessage) -> None: + logger.info("Stop requested: {}", msg.reason) + self.state.stop() + self.state.clear_buffer() + + # Drain pending audio queue + while not self.state.audio_queue.empty(): + try: + self.state.audio_queue.get_nowait() + self.state.audio_queue.task_done() + except asyncio.QueueEmpty: + break + + await self._send(StatusMessage(event="stopped", reason=msg.reason, seq=msg.seq)) + # Don't kill the connection: agent may continue the session after stop. + + async def _handle_emotion(self, msg: EmotionMessage) -> None: + self.state.emotion = msg.emotion + await self._send( + StatusMessage( + event="config_updated", + extra={"emotion": msg.emotion}, + seq=msg.seq, + ) + ) + + async def _handle_config(self, msg: ConfigMessage) -> None: + if msg.speed is not None: + self.state.speed = msg.speed + if msg.language is not None: + self.state.language = msg.language + await self._send( + StatusMessage( + event="config_updated", + extra={"speed": self.state.speed, "language": self.state.language}, + seq=msg.seq, + ) + ) + + async def _synthesize_segment(self, text: str, seq: int | None) -> None: + if self.state.is_stopped(): + return + + segment_seq = self.state.next_segment_seq() + await self._send(StatusMessage(event="segment_started", seq=segment_seq)) + + ref_path = self.state.voice.ref_for(self.state.emotion) + + try: + # Serialize GPU inference and run it off the main event loop. + async with self._synth_lock: + audio = await asyncio.to_thread( + SessionManager._sync_synthesize, + self.engine, + text, + ref_path, + self.state.language, + self.state.speed, + self.state.emotion, + ) + except Exception as exc: + logger.exception("TTS synthesis failed") + await self._send(ErrorMessage(message=f"TTS failed: {exc}", seq=segment_seq)) + return + + if self.state.is_stopped(): + return + + pcm = float_to_pcm16(audio) + await self.state.audio_queue.put((segment_seq, pcm)) + + async def _audio_sender(self) -> None: + while self._running: + try: + segment_seq, pcm = await asyncio.wait_for(self.state.audio_queue.get(), timeout=0.5) + except asyncio.TimeoutError: + continue + + if self.state.is_stopped(): + self.state.audio_queue.task_done() + continue + + await self._send( + AudioMessage( + sample_rate=self.engine.sample_rate, + data=pcm16_to_base64(pcm), + seq=segment_seq, + ) + ) + await self._send(StatusMessage(event="segment_finished", seq=segment_seq)) + self.state.audio_queue.task_done() + + async def _send(self, msg: ServerMessageUnion) -> None: + async with self._send_lock: + try: + await self.ws.send_text(msg.model_dump_json(exclude_none=True)) + except Exception as exc: + logger.warning("Failed to send WebSocket message: {}", exc) + self._running = False + + @staticmethod + def _sync_synthesize( + engine: "TTSEngine", + text: str, + ref_path: Path | None, + language: str, + speed: float, + emotion: str, + ) -> "np.ndarray": + """Thread-safe wrapper around the synchronous TTS inference call.""" + import numpy as np + + # The dummy backend is async; run it on a transient event loop in this thread. + if isinstance(engine, DummyTTSEngine): + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = None + + async def _run() -> "np.ndarray": + return await engine.synthesize( + text=text, + ref_audio_path=ref_path, + language=language, + speed=speed, + emotion=emotion, + ) + + if loop is not None: + return asyncio.run_coroutine_threadsafe(_run(), loop).result() + return asyncio.run(_run()) + + # F5-TTS exposes an async synthesize method that blocks on CPU/CUDA work. + # Inside a thread from asyncio.to_thread there is no running loop, so we + # drive the coroutine with a fresh transient event loop. + return asyncio.run( + engine.synthesize( + text=text, + ref_audio_path=ref_path, + language=language, + speed=speed, + emotion=emotion, + ) + ) + + def _stop_all(self) -> None: + self._running = False + self.state.stop() + + +def _looks_like_complete_phrase(text: str) -> bool: + """Heuristic: chunk ends with strong punctuation -> sentence is complete.""" + text = text.rstrip() + return bool(text) and text[-1] in ".。!??!" + + +def _create_engine() -> TTSEngine: + backend = settings.tts_backend + engine_cls = _BACKEND_MAP.get(backend) + if engine_cls is None: + raise RuntimeError( + f"Unknown TTS backend: {backend}. " + f"Available backends: {list(_BACKEND_MAP.keys())}" + ) + engine = engine_cls(sample_rate=settings.tts_sample_rate) + if hasattr(engine, "load"): + engine.load() + return engine + + +def build_app(engine: TTSEngine | None = None) -> FastAPI: + tts_engine = engine or _create_engine() + + @asynccontextmanager + async def _lifespan(app: FastAPI) -> None: + if settings.warmup and settings.default_voice_ref is not None: + try: + logger.info( + "Warming up {} with reference {} ...", + settings.tts_backend, + settings.default_voice_ref, + ) + await asyncio.to_thread( + SessionManager._sync_synthesize, + tts_engine, + settings.warmup_text, + settings.default_voice_ref, + "en", + 1.0, + "neutral", + ) + logger.info("Warm-up complete.") + except Exception as exc: + logger.warning("Warm-up failed (continuing anyway): {}", exc) + else: + logger.info("Warm-up skipped.") + yield + + app = FastAPI(title="Voice TTS", version="0.1.0", lifespan=_lifespan) + + @app.websocket("/ws") + async def websocket_endpoint(websocket: WebSocket) -> None: + session = SessionManager(websocket, tts_engine) + await session.run() + + @app.get("/health") + async def health() -> dict: + return {"status": "ok", "backend": settings.tts_backend} + + return app + + +app = build_app() diff --git a/src/voice_tts/audio/formats.py b/src/voice_tts/audio/formats.py new file mode 100644 index 0000000..e9999cb --- /dev/null +++ b/src/voice_tts/audio/formats.py @@ -0,0 +1,44 @@ +"""Audio formatting utilities.""" + +import base64 +import io + +import numpy as np + + +def float_to_pcm16(audio: np.ndarray) -> bytes: + """Convert float32 audio in [-1, 1] to little-endian 16-bit PCM bytes.""" + clipped = np.clip(audio, -1.0, 1.0) + pcm = (clipped * 32767).astype(np.int16) + return pcm.tobytes() + + +def pcm16_to_base64(pcm_bytes: bytes) -> str: + return base64.b64encode(pcm_bytes).decode("ascii") + + +def float_to_wav_bytes(audio: np.ndarray, sample_rate: int) -> bytes: + """Wrap float audio into a WAV file in memory.""" + pcm = float_to_pcm16(audio) + header = _build_wav_header(len(pcm), sample_rate, 1, 16) + return header + pcm + + +def _build_wav_header(data_size: int, sample_rate: int, channels: int, bits_per_sample: int) -> bytes: + byte_rate = sample_rate * channels * bits_per_sample // 8 + block_align = channels * bits_per_sample // 8 + + header = b"RIFF" + header += (36 + data_size).to_bytes(4, "little") + header += b"WAVE" + header += b"fmt " + header += (16).to_bytes(4, "little") # subchunk size + header += (1).to_bytes(2, "little") # PCM + header += channels.to_bytes(2, "little") + header += sample_rate.to_bytes(4, "little") + header += byte_rate.to_bytes(4, "little") + header += block_align.to_bytes(2, "little") + header += bits_per_sample.to_bytes(2, "little") + header += b"data" + header += data_size.to_bytes(4, "little") + return header diff --git a/src/voice_tts/config.py b/src/voice_tts/config.py new file mode 100644 index 0000000..7890b9e --- /dev/null +++ b/src/voice_tts/config.py @@ -0,0 +1,40 @@ +from pathlib import Path +from pydantic_settings import BaseSettings + + +class Settings(BaseSettings): + """Application configuration loaded from environment variables.""" + + host: str = "0.0.0.0" + port: int = 8765 + log_level: str = "INFO" + + # TTS model configuration + tts_backend: str = "f5_tts" # or "dummy" for tests + tts_model_path: Path | None = None + tts_vocab_path: Path | None = None + tts_sample_rate: int = 24_000 + + # Reference voices directory + voices_dir: Path = Path("voices") + default_voice_ref: Path | None = None # env: DEFAULT_VOICE_REF + + # Segmentation thresholds + min_segment_length: int = 30 + max_segment_length: int = 200 + max_buffer_wait_ms: int = 500 + + # GPU / inference + device: str = "cuda" # or "cpu" + dtype: str = "bfloat16" + + # Warm-up + warmup: bool = False # run a dummy inference at startup + warmup_text: str = "Hello world." + + class Config: + env_file = ".env" + env_file_encoding = "utf-8" + + +settings = Settings() diff --git a/src/voice_tts/main.py b/src/voice_tts/main.py new file mode 100644 index 0000000..f10f8f7 --- /dev/null +++ b/src/voice_tts/main.py @@ -0,0 +1,19 @@ +"""Application entrypoint.""" + +import uvicorn + +from voice_tts.api.server import app +from voice_tts.config import settings + + +def main() -> None: + uvicorn.run( + app, + host=settings.host, + port=settings.port, + log_level=settings.log_level.lower(), + ) + + +if __name__ == "__main__": + main() diff --git a/src/voice_tts/session/state.py b/src/voice_tts/session/state.py new file mode 100644 index 0000000..5470a2f --- /dev/null +++ b/src/voice_tts/session/state.py @@ -0,0 +1,60 @@ +"""Per-session state: buffers, queues, stop control, voice/emotion settings.""" + +import asyncio +from dataclasses import dataclass, field +from pathlib import Path + + +@dataclass +class VoiceProfile: + """Reference audio paths per emotion for a single speaker.""" + + default_ref: Path | None = None + emotion_refs: dict[str, Path] = field(default_factory=dict) + + def ref_for(self, emotion: str) -> Path | None: + return self.emotion_refs.get(emotion) or self.default_ref + + +@dataclass +class SessionState: + """Mutable state for one WebSocket connection.""" + + session_id: str + language: str = "ru" + speed: float = 1.0 + emotion: str = "neutral" + voice: VoiceProfile = field(default_factory=VoiceProfile) + + # Text buffer + text_buffer: str = "" + + # Stop / interrupt control + stop_event: asyncio.Event = field(default_factory=asyncio.Event) + + # Audio output queue (filled by TTS worker, drained by WebSocket sender) + audio_queue: asyncio.Queue[bytes] = field(default_factory=asyncio.Queue) + + # Sequence counter for outgoing messages + out_seq: int = 0 + segment_seq: int = 0 + + def next_out_seq(self) -> int: + self.out_seq += 1 + return self.out_seq + + def next_segment_seq(self) -> int: + self.segment_seq += 1 + return self.segment_seq + + def reset_stop(self) -> None: + self.stop_event.clear() + + def stop(self) -> None: + self.stop_event.set() + + def is_stopped(self) -> bool: + return self.stop_event.is_set() + + def clear_buffer(self) -> None: + self.text_buffer = "" diff --git a/src/voice_tts/tts/engine.py b/src/voice_tts/tts/engine.py new file mode 100644 index 0000000..07fad11 --- /dev/null +++ b/src/voice_tts/tts/engine.py @@ -0,0 +1,57 @@ +"""TTS engine abstraction and dummy backend.""" + +import asyncio +from abc import ABC, abstractmethod +from pathlib import Path + +import numpy as np + + +class TTSEngine(ABC): + """Base interface for a TTS backend.""" + + sample_rate: int = 24_000 + + @abstractmethod + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ) -> np.ndarray: + """Return audio as float32 ndarray normalized to [-1, 1].""" + ... + + @abstractmethod + async def warm_up(self) -> None: + """Optional warm-up inference.""" + ... + + +class DummyTTSEngine(TTSEngine): + """Generates a silent/sine beep segment for testing without a GPU model.""" + + def __init__(self, sample_rate: int = 24_000): + self.sample_rate = sample_rate + + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ) -> np.ndarray: + duration_sec = max(0.5, len(text) * 0.08) / speed + num_samples = int(self.sample_rate * duration_sec) + t = np.linspace(0, duration_sec, num_samples, endpoint=False) + # 440 Hz tone with slight fade to avoid clicks + audio = 0.3 * np.sin(2 * np.pi * 440 * t) + audio *= np.hanning(num_samples) + return audio.astype(np.float32) + + + async def warm_up(self) -> None: + pass diff --git a/src/voice_tts/tts/f5_backend.py b/src/voice_tts/tts/f5_backend.py new file mode 100644 index 0000000..f1f78f9 --- /dev/null +++ b/src/voice_tts/tts/f5_backend.py @@ -0,0 +1,158 @@ +from pathlib import Path + +import numpy as np +from loguru import logger + +from voice_tts.config import settings +from voice_tts.tts.engine import TTSEngine + + +try: + import torch + import torchaudio + from f5_tts.api import F5TTS + from f5_tts.infer.utils_infer import preprocess_ref_audio_text + + F5_AVAILABLE = True +except ImportError as exc: + logger.warning("f5-tts/torch dependencies not available: {}", exc) + F5_AVAILABLE = False + torch = None + torchaudio = None + + +class F5TTSEngine(TTSEngine): + """F5-TTS backend for local GPU inference with voice cloning.""" + + def __init__( + self, + model: str = "F5TTS_v1_Base", + sample_rate: int = 24_000, + device: str | None = None, + nfe_step: int = 32, + cfg_strength: float = 2.0, + sway_sampling_coef: float = -1.0, + speed: float = 1.0, + target_rms: float = 0.1, + cross_fade_duration: float = 0.0, + remove_silence: bool = False, + ): + super().__init__() + self.model_name = model + self.sample_rate = sample_rate + self.device = device or settings.device + self.nfe_step = nfe_step + self.cfg_strength = cfg_strength + self.sway_sampling_coef = sway_sampling_coef + self.speed = speed + self.target_rms = target_rms + self.cross_fade_duration = cross_fade_duration + self.remove_silence = remove_silence + + self._f5: "F5TTS | None" = None + self._ref_cache: dict[str, tuple[str, str]] = {} # emotion_key -> (processed_audio_path, ref_text) + + def _get_key(self, ref_path: Path | None, emotion: str) -> str: + path_str = str(ref_path) if ref_path else "" + return f"{emotion}::{path_str}" + + def load(self) -> None: + if not F5_AVAILABLE: + raise RuntimeError( + "f5-tts/torch package is not installed. Install it: pip install f5-tts torch torchaudio" + ) + logger.info( + "Loading F5-TTS model {} on device {} ...", + self.model_name, + self.device, + ) + self._f5 = F5TTS( + model=self.model_name, + device=self.device, + ) + self.sample_rate = self._f5.target_sample_rate + logger.info("F5-TTS loaded. Target sample rate: {}", self.sample_rate) + + async def warm_up(self) -> None: + if self._f5 is None: + self.load() + logger.info("Warming up F5-TTS ...") + logger.info("Warm-up skipped: provide a reference audio before warm-up.") + + def _ensure_reference( + self, + ref_audio_path: Path | None, + emotion: str, + ) -> tuple[str, str]: + if not F5_AVAILABLE: + raise RuntimeError("f5-tts/torch is not installed") + if ref_audio_path is None: + raise ValueError("F5-TTS requires a reference audio file (voice_ref).") + + if isinstance(ref_audio_path, str): + ref_audio_path = Path(ref_audio_path) + + key = self._get_key(ref_audio_path, emotion) + if key in self._ref_cache: + return self._ref_cache[key] + + if not ref_audio_path.exists(): + raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}") + + processed_audio, ref_text = preprocess_ref_audio_text( + str(ref_audio_path), + "", # empty ref_text triggers automatic transcription + ) + self._ref_cache[key] = (processed_audio, ref_text) + logger.info( + "Reference cached for emotion={} path={} text={}", + emotion, + ref_audio_path, + ref_text, + ) + return processed_audio, ref_text + + async def synthesize( + self, + text: str, + ref_audio_path: Path | None, + language: str, + speed: float, + emotion: str, + ) -> np.ndarray: + if self._f5 is None: + self.load() + + assert self._f5 is not None + + if isinstance(ref_audio_path, str): + ref_audio_path = Path(ref_audio_path) + + processed_audio, ref_text = self._ensure_reference(ref_audio_path, emotion) + + wav, sr, _spec = self._f5.infer( + ref_file=processed_audio, + ref_text=ref_text, + gen_text=text, + nfe_step=self.nfe_step, + cfg_strength=self.cfg_strength, + sway_sampling_coef=self.sway_sampling_coef, + speed=speed, + cross_fade_duration=self.cross_fade_duration, + target_rms=self.target_rms, + remove_silence=self.remove_silence, + ) + + if wav is None: + raise RuntimeError("F5-TTS produced no audio") + + # Ensure shape and dtype + if isinstance(wav, torch.Tensor): + wav = wav.squeeze().cpu().numpy() + wav = np.asarray(wav, dtype=np.float32) + if wav.ndim == 0: + wav = np.zeros(1, dtype=np.float32) + elif wav.ndim > 1: + wav = wav.squeeze() + + return wav diff --git a/src/voice_tts/tts/segmenter.py b/src/voice_tts/tts/segmenter.py new file mode 100644 index 0000000..26315a6 --- /dev/null +++ b/src/voice_tts/tts/segmenter.py @@ -0,0 +1,118 @@ +"""Lightweight streaming text segmenter.""" + +import re +from dataclasses import dataclass + + +@dataclass +class Segment: + text: str + is_final: bool = False + + +class Segmenter: + """Splits streaming text into TTS-ready segments.""" + + def __init__( + self, + min_length: int = 30, + max_length: int = 200, + ): + self.min_length = min_length + self.max_length = max_length + + # End-of-sentence delimiters + self.sentence_breaks = re.compile(r"[.。!??!\n]+") + # Clause delimiters for long segments + self.clause_breaks = re.compile(r"[,;:\-—()()]") + self.whitespace_re = re.compile(r"\s+") + + def feed(self, buffer: str) -> tuple[str, list[Segment]]: + """ + Consume `buffer` and return (remaining_buffer, ready_segments). + + Logic: + - Prefer cutting at sentence boundaries. + - If a segment grows beyond max_length, cut at the nearest clause boundary. + - Segments shorter than min_length are returned only if the caller forces flush + or the input is shorter than min_length but ends with a sentence break. + """ + segments: list[Segment] = [] + remaining = buffer + + while remaining: + # Find the first sentence boundary. + first_sentence_cut = -1 + for match in self.sentence_breaks.finditer(remaining): + first_sentence_cut = match.end() + break + + # If there is a complete sentence, decide whether to emit it. + if first_sentence_cut != -1: + segment_text = remaining[:first_sentence_cut].strip() + + # Case 1: sentence is long enough -> emit immediately. + if len(segment_text) >= self.min_length: + remaining = remaining[first_sentence_cut:].lstrip() + if segment_text: + segments.append(Segment(text=segment_text, is_final=True)) + continue + + # Case 2: sentence is short. Combine it with subsequent sentences + # until the combined chunk is long enough, or there are no more + # complete sentences ahead. + combined_cut = first_sentence_cut + emitted_combined = False + while True: + next_boundary = -1 + for match in self.sentence_breaks.finditer(remaining, combined_cut): + next_boundary = match.end() + break + if next_boundary == -1: + break + combined_text = remaining[:next_boundary].strip() + combined_cut = next_boundary + if len(combined_text) >= self.min_length: + remaining = remaining[combined_cut:].lstrip() + if combined_text: + segments.append(Segment(text=combined_text, is_final=True)) + emitted_combined = True + break + + if emitted_combined: + continue + + # Case 3: short sentence and no further complete sentence ahead. + # Flush it so the user doesn't wait indefinitely. + remaining = remaining[first_sentence_cut:].lstrip() + if segment_text: + segments.append(Segment(text=segment_text, is_final=True)) + continue + + # No usable sentence boundary yet. + # Consider max-length clause cut. + if len(remaining) >= self.max_length: + window = remaining[: self.max_length] + last_clause = -1 + for match in self.clause_breaks.finditer(window): + pos = match.end() + if pos >= self.min_length: + last_clause = pos + if last_clause != -1: + segment_text = remaining[:last_clause].strip() + remaining = remaining[last_clause:].lstrip() + if segment_text: + segments.append(Segment(text=segment_text, is_final=True)) + continue + + # Nothing to cut yet; wait for more text. + break + + return remaining, segments + + def flush(self, buffer: str) -> list[Segment]: + """Force-convert all remaining text to a final segment.""" + text = buffer.strip() + if not text: + return [] + return [Segment(text=text, is_final=True)] diff --git a/src/voice_tts/tts/utils.py b/src/voice_tts/tts/utils.py new file mode 100644 index 0000000..099ee8e --- /dev/null +++ b/src/voice_tts/tts/utils.py @@ -0,0 +1,43 @@ +"""Helpers for text preprocessing and reference voice management.""" + +import re +from pathlib import Path + + +# Common sentence-ending punctuation for multiple languages. +SENTENCE_ENDINGS = { + ".", "!", "?", ";", ":", + "。", "!", "?", ";", ":", +} + + +def normalize_whitespace(text: str) -> str: + """Collapse repeated whitespace and strip edges, preserving single spaces.""" + return re.sub(r"\s+", " ", text).strip() + + +def preprocess_text_for_tts(text: str) -> str: + """ + Minimal cleanup before TTS. + - Collapse whitespace. + - Remove control characters. + """ + text = re.sub(r"[\x00-\x08\x0b\x0c\x0e-\x1f\x7f]", "", text) + return normalize_whitespace(text) + + +def has_sentence_ending(text: str) -> bool: + """Check whether the text ends with a sentence-ending punctuation.""" + stripped = text.rstrip() + return any(stripped.endswith(p) for p in SENTENCE_ENDINGS) + + +def validate_reference_audio(path: Path) -> None: + """Raise a clear error if the reference audio file is missing or unsupported.""" + if not path.exists(): + raise FileNotFoundError( + f"Reference audio not found: {path}. " + f"Place a WAV/MP3 file under {path.parent}/ and retry." + ) + if not path.is_file(): + raise ValueError(f"Reference audio path is not a file: {path}") diff --git a/tests/test_segmenter.py b/tests/test_segmenter.py new file mode 100644 index 0000000..a7991f7 --- /dev/null +++ b/tests/test_segmenter.py @@ -0,0 +1,39 @@ +"""Tests for WebSocket protocol and segmenter.""" + +import pytest + +from voice_tts.tts.segmenter import Segmenter + + +def test_segmenter_sentence_split(): + seg = Segmenter(min_length=10, max_length=200) + buffer = "Привет, мир! Как дела? Это тестовый сегмент для проверки работы сегментатора." + remaining, segments = seg.feed(buffer) + # First sentence is long enough and should be emitted immediately. + assert len(segments) >= 1 + assert segments[0].text == "Привет, мир!" + # The short second sentence is accumulated with the third one until min_length is met. + assert remaining == "" + assert len(segments) == 2 + assert segments[1].text == "Как дела? Это тестовый сегмент для проверки работы сегментатора." + + +def test_segmenter_max_length_clause_split(): + seg = Segmenter(min_length=10, max_length=50) + buffer = ( + "Это очень длинное предложение без точки, которое должно быть разрезано " + "по запятой или другому разделителю, потому что иначе оно слишком длинное" + ) + remaining, segments = seg.feed(buffer) + assert segments + for s in segments: + assert len(s.text) <= seg.max_length + 5 # small tolerance + + +def test_segmenter_flush(): + seg = Segmenter(min_length=100, max_length=200) + remaining, _ = seg.feed("Короткий текст") + assert remaining == "Короткий текст" + flushed = seg.flush(remaining) + assert len(flushed) == 1 + assert flushed[0].text == "Короткий текст" diff --git a/tests/test_server.py b/tests/test_server.py new file mode 100644 index 0000000..d88109c --- /dev/null +++ b/tests/test_server.py @@ -0,0 +1,92 @@ +"""End-to-end WebSocket tests for the TTS server.""" + +import asyncio +import json + +import pytest +import pytest_asyncio +import websockets +from httpx import AsyncClient + +from voice_tts.api.server import build_app +from voice_tts.tts.engine import DummyTTSEngine + + +@pytest_asyncio.fixture +async def dummy_server(): + """Run a local Uvicorn server with the dummy TTS backend.""" + import uvicorn + + engine = DummyTTSEngine(sample_rate=24_000) + app = build_app(engine) + config = uvicorn.Config(app, host="127.0.0.1", port=9876, log_level="warning") + server = uvicorn.Server(config) + task = asyncio.create_task(server.serve()) + # Wait until server is ready + for _ in range(50): + if server.started: + break + await asyncio.sleep(0.05) + yield "ws://127.0.0.1:9876/ws", "http://127.0.0.1:9876/health" + server.should_exit = True + await task + + +@pytest.mark.asyncio +async def test_health(dummy_server): + ws_url, health_url = dummy_server + async with AsyncClient() as client: + response = await client.get(health_url) + assert response.status_code == 200 + data = response.json() + assert data["status"] == "ok" + + +@pytest.mark.asyncio +async def test_websocket_streaming_and_stop(dummy_server): + ws_url, _ = dummy_server + received_audio = 0 + received_statuses = [] + + async with websockets.connect(ws_url) as ws: + await ws.send(json.dumps({ + "type": "init", + "session_id": "test", + "seq": 1, + })) + msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5)) + assert msg["type"] == "status" + assert msg["event"] == "session_ready" + + for i, chunk in enumerate(["Привет, ", "как ", "дела?"]): + await ws.send(json.dumps({"type": "text", "payload": chunk, "seq": 2 + i})) + # Give the server a moment to process each chunk sequentially, + # mirroring how a streaming LLM would emit tokens. + await asyncio.sleep(0.05) + + await ws.send(json.dumps({"type": "flush", "seq": 5})) + + # Collect messages until segment_finished and at least one audio chunk. + deadline = asyncio.get_event_loop().time() + 5 + while asyncio.get_event_loop().time() < deadline: + msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5)) + received_statuses.append((msg["type"], msg.get("event"))) + if msg["type"] == "audio": + received_audio += 1 + if msg["type"] == "status" and msg.get("event") == "segment_finished": + break + + # Interrupt + await ws.send(json.dumps({"type": "stop", "reason": "interrupt", "seq": 6})) + + # Drain any audio/status messages already in flight before the stop is processed. + for _ in range(20): + msg = json.loads(await asyncio.wait_for(ws.recv(), timeout=5)) + if msg["type"] == "status" and msg.get("event") == "stopped": + break + else: + pytest.fail("Did not receive stopped status") + + assert received_audio >= 1 + assert ("status", "segment_started") in received_statuses + assert ("status", "segment_finished") in received_statuses diff --git a/voices/.gitkeep b/voices/.gitkeep new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/voices/.gitkeep