Newer
Older
voice / src / voice_tts / tts / segmenter.py
"""Lightweight streaming text segmenter."""

import re
from dataclasses import dataclass


@dataclass
class Segment:
    text: str
    is_final: bool = False


class Segmenter:
    """Splits streaming text into TTS-ready segments.

    Supports progressive chunking: the first few segments use a lower
    ``min_length`` so audio starts sooner (lower initial latency).
    """

    def __init__(
        self,
        min_length: int = 30,
        max_length: int = 200,
        fast_start_initial: int = 12,
        fast_start_count: int = 3,
    ):
        self.min_length = min_length
        self.max_length = max_length
        self._fast_start_initial = fast_start_initial
        self._fast_start_count = fast_start_count
        self._segments_emitted = 0

        # End-of-sentence delimiters
        self.sentence_breaks = re.compile(r"[.。!??!\n]+")
        # Clause delimiters for long segments
        self.clause_breaks = re.compile(r"[,;:\-—()()]")
        self.whitespace_re = re.compile(r"\s+")

    def _effective_min_length(self) -> int:
        """Minimum length before a segment is considered 'ready'.

        Gradually ramps up from ``fast_start_initial`` to ``min_length``
        over the first ``fast_start_count`` segments.
        """
        if self._segments_emitted < self._fast_start_count:
            fraction = self._segments_emitted / max(self._fast_start_count - 1, 1)
            return int(
                self._fast_start_initial
                + (self.min_length - self._fast_start_initial) * fraction
            )
        return self.min_length

    def feed(self, buffer: str) -> tuple[str, list[Segment]]:
        """
        Consume `buffer` and return (remaining_buffer, ready_segments).

        Logic:
        - Prefer cutting at sentence boundaries.
        - If a segment grows beyond max_length, cut at the nearest clause boundary.
        - Segments shorter than the effective min_length are returned only if
          the caller forces flush or the input is shorter but ends with a
          sentence break.
        - Progressive chunking: early segments use a lower threshold so audio
          starts sooner.
        """
        segments: list[Segment] = []
        remaining = buffer

        while remaining:
            min_len = self._effective_min_length()

            # Find the first sentence boundary.
            first_sentence_cut = -1
            for match in self.sentence_breaks.finditer(remaining):
                first_sentence_cut = match.end()
                break

            # If there is a complete sentence, decide whether to emit it.
            if first_sentence_cut != -1:
                segment_text = remaining[:first_sentence_cut].strip()

                # Case 1: sentence is long enough -> emit immediately.
                if len(segment_text) >= min_len:
                    remaining = remaining[first_sentence_cut:].lstrip()
                    if segment_text:
                        segments.append(Segment(text=segment_text, is_final=True))
                    continue

                # Case 2: sentence is short. Combine it with subsequent sentences
                # until the combined chunk is long enough, or there are no more
                # complete sentences ahead.
                combined_cut = first_sentence_cut
                emitted_combined = False
                while True:
                    next_boundary = -1
                    for match in self.sentence_breaks.finditer(remaining, combined_cut):
                        next_boundary = match.end()
                        break
                    if next_boundary == -1:
                        break
                    combined_text = remaining[:next_boundary].strip()
                    combined_cut = next_boundary
                    if len(combined_text) >= min_len:
                        remaining = remaining[combined_cut:].lstrip()
                        if combined_text:
                            segments.append(Segment(text=combined_text, is_final=True))
                        emitted_combined = True
                        break

                if emitted_combined:
                    continue

                # Case 3: short sentence and no further complete sentence ahead.
                # Flush it so the user doesn't wait indefinitely.
                remaining = remaining[first_sentence_cut:].lstrip()
                if segment_text:
                    segments.append(Segment(text=segment_text, is_final=True))
                continue

            # No usable sentence boundary yet.
            # Consider max-length clause cut.
            if len(remaining) >= self.max_length:
                window = remaining[: self.max_length]
                last_clause = -1
                for match in self.clause_breaks.finditer(window):
                    pos = match.end()
                    if pos >= min_len:
                        last_clause = pos
                if last_clause != -1:
                    segment_text = remaining[:last_clause].strip()
                    remaining = remaining[last_clause:].lstrip()
                    if segment_text:
                        segments.append(Segment(text=segment_text, is_final=True))
                    continue

            # Nothing to cut yet; wait for more text.
            break

        self._segments_emitted += len(segments)
        return remaining, segments

    def flush(self, buffer: str) -> list[Segment]:
        """Force-convert all remaining text to a final segment."""
        text = buffer.strip()
        if not text:
            return []
        return [Segment(text=text, is_final=True)]