"""WebSocket server and session lifecycle."""
import asyncio
from contextlib import asynccontextmanager
import json
from pathlib import Path
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from loguru import logger
from voice_tts.api.protocol import (
AudioMessage,
ClientMessageUnion,
ConfigMessage,
EmotionMessage,
ErrorMessage,
FlushMessage,
InitMessage,
ServerMessageUnion,
StatusMessage,
StopMessage,
TextMessage,
)
from voice_tts.audio.formats import float_to_pcm16, pcm16_to_base64
from voice_tts.config import settings
from voice_tts.session.state import SessionState, VoiceProfile
from voice_tts.tts.engine import DummyTTSEngine, TTSEngine
from voice_tts.tts.f5_backend import F5TTSEngine
from voice_tts.tts.segmenter import Segmenter
# Registry of selectable TTS backends, keyed by the name configured in
# ``settings.tts_backend`` (looked up by ``_create_engine``).
_BACKEND_MAP: dict[str, type[TTSEngine]] = dict(
    dummy=DummyTTSEngine,
    f5_tts=F5TTSEngine,
)
class SessionManager:
    """Manages the lifecycle of a single WebSocket TTS session.

    One instance is created per WebSocket connection. Incoming JSON text
    frames are parsed into protocol messages and dispatched to the
    ``_handle_*`` methods. Text is split into segments by a ``Segmenter``;
    each segment is synthesized in a background task (inference serialized
    by ``_synth_lock`` and executed in a worker thread), and the resulting
    PCM audio is put on ``state.audio_queue``. A single ``_audio_sender``
    task drains that queue and streams audio back to the client.
    """

    def __init__(self, websocket: WebSocket, engine: TTSEngine):
        self.ws = websocket
        self.engine = engine
        self.segmenter = Segmenter(
            min_length=settings.min_segment_length,
            max_length=settings.max_segment_length,
        )
        self.state = SessionState(session_id="")
        # Cleared on shutdown or when a send fails; the loops poll it to exit.
        self._running = True
        # Strong references to every background task (audio sender and
        # in-flight synthesis). The event loop only keeps weak references to
        # tasks, so un-referenced fire-and-forget tasks may be garbage
        # collected mid-execution; holding them here also lets run() cancel
        # them on teardown. Finished tasks remove themselves (see _track).
        self._tasks: set[asyncio.Task] = set()
        # Serializes writes to the WebSocket from concurrent tasks.
        self._send_lock = asyncio.Lock()
        # Serializes TTS inference: one segment is synthesized at a time.
        self._synth_lock = asyncio.Lock()
        # Incremented on every stop. Synthesis tasks remember the generation
        # they were spawned in and drop their work if it has changed, so audio
        # started before a stop cannot leak out after new text resets the
        # stop flag.
        self._generation = 0

    async def run(self) -> None:
        """Accept the connection and process client messages until it ends.

        Starts the audio sender task, then reads text frames in a loop.
        Frames that fail JSON decoding or message parsing are answered with
        an ``ErrorMessage`` and skipped. A ``WebSocketDisconnect`` ends the
        session normally; other exceptions propagate after cleanup. In every
        case all background tasks are cancelled and awaited before returning.
        """
        await self.ws.accept()
        logger.info("WebSocket connection accepted")
        # Start audio sender worker
        self._track(asyncio.create_task(self._audio_sender()))
        try:
            while self._running:
                raw = await self.ws.receive_text()
                try:
                    data = json.loads(raw)
                    msg = self._parse_message(data)
                except Exception as exc:
                    await self._send(ErrorMessage(message=f"Invalid message: {exc}"))
                    continue
                await self._handle_message(msg)
        except WebSocketDisconnect:
            logger.info("Client disconnected")
        finally:
            self._stop_all()
            # Snapshot first: done-callbacks mutate self._tasks as tasks finish.
            tasks = list(self._tasks)
            for task in tasks:
                task.cancel()
            # return_exceptions=True: a task that failed with a real error must
            # neither abort cleanup of the others nor escape this finally block.
            await asyncio.gather(*tasks, return_exceptions=True)

    def _track(self, task: asyncio.Task) -> asyncio.Task:
        """Hold a strong reference to *task* until it completes, then drop it."""
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    def _spawn_synthesis(self, segments, seq: int | None) -> None:
        """Start one tracked synthesis task per segment.

        *segments* are the items returned by ``Segmenter.feed``/``flush``
        (each exposes ``.text``). The current stop generation is captured
        now, at spawn time, so a stop arriving before a task first runs still
        invalidates it.
        """
        generation = self._generation
        for segment in segments:
            self._track(
                asyncio.create_task(
                    self._synthesize_segment(segment.text, seq, generation)
                )
            )

    def _is_stale(self, generation: int) -> bool:
        """Return True if synthesis spawned in *generation* should be dropped."""
        return self.state.is_stopped() or generation != self._generation

    def _parse_message(self, data: dict) -> ClientMessageUnion:
        """Build the protocol message selected by ``data["type"]``.

        Raises:
            ValueError: for an unknown ``type``. Model construction errors
                from the message classes propagate unchanged.
        """
        msg_type = data.get("type")
        if msg_type == "init":
            return InitMessage(**data)
        if msg_type == "text":
            return TextMessage(**data)
        if msg_type == "flush":
            return FlushMessage(**data)
        if msg_type == "stop":
            return StopMessage(**data)
        if msg_type == "emotion":
            return EmotionMessage(**data)
        if msg_type == "config":
            return ConfigMessage(**data)
        raise ValueError(f"Unknown message type: {msg_type}")

    async def _handle_message(self, msg: ClientMessageUnion) -> None:
        """Dispatch a parsed message to its ``_handle_*`` method."""
        if isinstance(msg, InitMessage):
            await self._handle_init(msg)
        elif isinstance(msg, TextMessage):
            await self._handle_text(msg)
        elif isinstance(msg, FlushMessage):
            await self._handle_flush(msg)
        elif isinstance(msg, StopMessage):
            await self._handle_stop(msg)
        elif isinstance(msg, EmotionMessage):
            await self._handle_emotion(msg)
        elif isinstance(msg, ConfigMessage):
            await self._handle_config(msg)

    async def _handle_init(self, msg: InitMessage) -> None:
        """(Re)initialize session settings and voice, then report ``session_ready``."""
        self.state.session_id = msg.session_id or "default"
        self.state.language = msg.language
        self.state.speed = msg.speed
        self.state.emotion = msg.emotion
        self.state.clear_buffer()
        self.state.reset_stop()
        voice = VoiceProfile()
        if msg.voice_ref:
            voice.default_ref = Path(msg.voice_ref)
        if msg.voice_refs:
            voice.emotion_refs = {
                emotion: Path(path) for emotion, path in msg.voice_refs.items()
            }
        self.state.voice = voice
        await self._send(StatusMessage(event="session_ready", seq=msg.seq))
        logger.info(
            "Session initialized: id={} language={} speed={} emotion={}",
            self.state.session_id,
            self.state.language,
            self.state.speed,
            self.state.emotion,
        )

    async def _handle_text(self, msg: TextMessage) -> None:
        """Append streamed text to the buffer and synthesize completed segments."""
        if self.state.is_stopped():
            # After a stop, new text implicitly resets the stop flag
            self.state.reset_stop()
        self.state.text_buffer += msg.payload
        if msg.emotion:
            self.state.emotion = msg.emotion
        from voice_tts.tts.utils import preprocess_text_for_tts
        self.state.text_buffer = preprocess_text_for_tts(self.state.text_buffer)
        remaining, segments = self.segmenter.feed(self.state.text_buffer)
        self.state.text_buffer = remaining
        self._spawn_synthesis(segments, msg.seq)
        # If this text chunk itself ends a sentence and the buffer is too short,
        # we might still be holding it. Force a flush so short complete sentences
        # (e.g. "дела?") are not stuck waiting for more input.
        if _looks_like_complete_phrase(msg.payload) and self.state.text_buffer.strip():
            segments = self.segmenter.flush(self.state.text_buffer)
            self.state.text_buffer = ""
            self._spawn_synthesis(segments, msg.seq)

    async def _handle_flush(self, msg: FlushMessage) -> None:
        """Synthesize whatever text remains in the buffer."""
        segments = self.segmenter.flush(self.state.text_buffer)
        self.state.text_buffer = ""
        self._spawn_synthesis(segments, msg.seq)

    async def _handle_stop(self, msg: StopMessage) -> None:
        """Abort pending output: drop buffered text, queued and in-flight audio."""
        logger.info("Stop requested: {}", msg.reason)
        self.state.stop()
        # Invalidate every synthesis spawned before this stop, including tasks
        # still waiting for (or holding) the synthesis lock; without this they
        # would enqueue audio once later text resets the stop flag.
        self._generation += 1
        self.state.clear_buffer()
        # Drain pending audio queue
        while not self.state.audio_queue.empty():
            try:
                self.state.audio_queue.get_nowait()
                self.state.audio_queue.task_done()
            except asyncio.QueueEmpty:
                break
        await self._send(StatusMessage(event="stopped", reason=msg.reason, seq=msg.seq))
        # Don't kill the connection: agent may continue the session after stop.

    async def _handle_emotion(self, msg: EmotionMessage) -> None:
        """Switch the active emotion and acknowledge with ``config_updated``."""
        self.state.emotion = msg.emotion
        await self._send(
            StatusMessage(
                event="config_updated",
                extra={"emotion": msg.emotion},
                seq=msg.seq,
            )
        )

    async def _handle_config(self, msg: ConfigMessage) -> None:
        """Update speed and/or language (fields left as None are unchanged)."""
        if msg.speed is not None:
            self.state.speed = msg.speed
        if msg.language is not None:
            self.state.language = msg.language
        await self._send(
            StatusMessage(
                event="config_updated",
                extra={"speed": self.state.speed, "language": self.state.language},
                seq=msg.seq,
            )
        )

    async def _synthesize_segment(
        self, text: str, seq: int | None, generation: int | None = None
    ) -> None:
        """Synthesize one text segment and enqueue its PCM16 audio.

        Args:
            text: Segment text to speak.
            seq: Sequence number of the client message that produced the
                segment (not currently used; audio is tagged with a
                per-session segment sequence from ``state.next_segment_seq``).
            generation: Stop generation captured when the segment was
                spawned; defaults to the current one. If a stop has happened
                since, the segment is dropped.

        Sends ``segment_started`` before inference. On inference failure an
        ``ErrorMessage`` is sent instead of audio.
        """
        if generation is None:
            generation = self._generation
        if self._is_stale(generation):
            return
        segment_seq = self.state.next_segment_seq()
        await self._send(StatusMessage(event="segment_started", seq=segment_seq))
        try:
            # Serialize GPU inference and run it off the main event loop.
            async with self._synth_lock:
                # A stop may have arrived while this task waited for the lock;
                # skip the inference entirely rather than discard it afterwards.
                if self._is_stale(generation):
                    return
                # Read the voice parameters together so the reference audio
                # always matches the emotion passed to the engine.
                emotion = self.state.emotion
                ref_path = self.state.voice.ref_for(emotion)
                audio = await asyncio.to_thread(
                    SessionManager._sync_synthesize,
                    self.engine,
                    text,
                    ref_path,
                    self.state.language,
                    self.state.speed,
                    emotion,
                )
        except Exception as exc:
            logger.exception("TTS synthesis failed")
            await self._send(ErrorMessage(message=f"TTS failed: {exc}", seq=segment_seq))
            return
        if self._is_stale(generation):
            return
        pcm = float_to_pcm16(audio)
        await self.state.audio_queue.put((segment_seq, pcm))

    async def _audio_sender(self) -> None:
        """Stream queued audio to the client while the session is running.

        Waits at most 0.5 s per queue read so it re-checks ``_running``
        regularly. Items dequeued while the session is stopped are discarded.
        Each delivered segment is followed by a ``segment_finished`` status.
        """
        while self._running:
            try:
                segment_seq, pcm = await asyncio.wait_for(self.state.audio_queue.get(), timeout=0.5)
            except asyncio.TimeoutError:
                continue
            if self.state.is_stopped():
                self.state.audio_queue.task_done()
                continue
            await self._send(
                AudioMessage(
                    sample_rate=self.engine.sample_rate,
                    data=pcm16_to_base64(pcm),
                    seq=segment_seq,
                )
            )
            await self._send(StatusMessage(event="segment_finished", seq=segment_seq))
            self.state.audio_queue.task_done()

    async def _send(self, msg: ServerMessageUnion) -> None:
        """Serialize *msg* to JSON (omitting None fields) and send it.

        Send failures are logged, not raised; they mark the session as no
        longer running so the background loops wind down.
        """
        async with self._send_lock:
            try:
                await self.ws.send_text(msg.model_dump_json(exclude_none=True))
            except Exception as exc:
                logger.warning("Failed to send WebSocket message: {}", exc)
                self._running = False

    @staticmethod
    def _sync_synthesize(
        engine: "TTSEngine",
        text: str,
        ref_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
    ) -> "np.ndarray":
        """Run one TTS inference call to completion in the current thread.

        ``engine.synthesize`` is a coroutine function for every backend, so it
        is driven on a fresh event loop with ``asyncio.run``. This is meant to
        be called via ``asyncio.to_thread``: worker threads have no running
        loop. Called from a thread whose loop is running, ``asyncio.run``
        raises ``RuntimeError``. That is deliberate: the previous
        ``run_coroutine_threadsafe(...).result()`` fallback deadlocked there,
        because the thread blocked waiting on its own loop.

        For ``F5TTSEngine``, ``settings.default_ref_text`` (when set) is
        forwarded as ``ref_text``.
        """
        import numpy as np  # noqa: F401  (referenced only by the return annotation)

        kwargs: dict = dict(
            text=text,
            ref_audio_path=ref_path,
            language=language,
            speed=speed,
            emotion=emotion,
        )
        if isinstance(engine, F5TTSEngine) and settings.default_ref_text:
            kwargs["ref_text"] = settings.default_ref_text
        return asyncio.run(engine.synthesize(**kwargs))

    def _stop_all(self) -> None:
        """Mark the session as finished and stopped (used during teardown)."""
        self._running = False
        self.state.stop()
def _looks_like_complete_phrase(text: str) -> bool:
"""Heuristic: chunk ends with strong punctuation -> sentence is complete."""
text = text.rstrip()
return bool(text) and text[-1] in ".。!??!"
def _create_engine() -> TTSEngine:
    """Instantiate the backend named by ``settings.tts_backend``.

    The engine class is looked up in ``_BACKEND_MAP`` and constructed with
    ``sample_rate=settings.tts_sample_rate``. If the instance has a ``load``
    attribute, it is called before the engine is returned.

    Raises:
        RuntimeError: if the configured backend is not in ``_BACKEND_MAP``.
    """
    name = settings.tts_backend
    if name not in _BACKEND_MAP:
        known = list(_BACKEND_MAP)
        raise RuntimeError(
            f"Unknown TTS backend: {name}. "
            f"Available backends: {known}"
        )
    instance = _BACKEND_MAP[name](sample_rate=settings.tts_sample_rate)
    if hasattr(instance, "load"):
        instance.load()
    return instance
def build_app(engine: TTSEngine | None = None) -> FastAPI:
    """Build the FastAPI application that serves the TTS WebSocket endpoint.

    Args:
        engine: Pre-built engine to serve. When ``None``, one is created via
            ``_create_engine()`` from the current settings.

    Returns:
        An app exposing ``/ws`` (one ``SessionManager`` per connection, all
        sharing the same engine) and ``/health``. Its lifespan runs a single
        warm-up synthesis when ``settings.warmup`` is enabled and
        ``settings.default_voice_ref`` is set. Warm-up failures are logged
        and do not prevent startup.
    """
    from collections.abc import AsyncIterator

    # Explicit None check: an engine object must never be replaced just
    # because it happens to evaluate as falsy.
    tts_engine = engine if engine is not None else _create_engine()

    @asynccontextmanager
    async def _lifespan(app: FastAPI) -> AsyncIterator[None]:
        if settings.warmup and settings.default_voice_ref is not None:
            try:
                logger.info(
                    "Warming up {} with reference {} ...",
                    settings.tts_backend,
                    settings.default_voice_ref,
                )
                # Same thread-offloaded path as real sessions; language, speed
                # and emotion are fixed warm-up values.
                await asyncio.to_thread(
                    SessionManager._sync_synthesize,
                    tts_engine,
                    settings.warmup_text,
                    settings.default_voice_ref,
                    "en",
                    1.0,
                    "neutral",
                )
                logger.info("Warm-up complete.")
            except Exception as exc:
                logger.warning("Warm-up failed (continuing anyway): {}", exc)
        else:
            logger.info("Warm-up skipped.")
        yield

    app = FastAPI(title="Voice TTS", version="0.1.0", lifespan=_lifespan)

    @app.websocket("/ws")
    async def websocket_endpoint(websocket: WebSocket) -> None:
        session = SessionManager(websocket, tts_engine)
        await session.run()

    @app.get("/health")
    async def health() -> dict:
        return {"status": "ok", "backend": settings.tts_backend}

    return app
# Module-level ASGI application. Because it is built at import time, importing
# this module constructs the configured TTS engine (and calls its ``load()``,
# if it has one) via build_app() -> _create_engine().
app = build_app()