"""Simple Python WebSocket client for the Voice TTS server.

Streams text in chunks, receives base64 PCM audio and plays it via sounddevice.
Install dependencies:
    pip install websockets sounddevice numpy
"""

from __future__ import annotations

import argparse
import asyncio
import base64
import json
import sys
import threading
from typing import Any

import numpy as np
import sounddevice as sd
import websockets


class VoiceTTSClient:
    """WebSocket client that speaks LLM/agent output in real time."""

    def __init__(
        self,
        uri: str = "ws://localhost:8765/ws",
        voice_ref: str | None = None,
        voice_refs: dict[str, str] | None = None,
        language: str = "ru",
        speed: float = 1.0,
        emotion: str = "neutral",
        sample_rate: int = 24_000,
        block_size: int = 2048,
    ):
        self.uri = uri
        self.voice_ref = voice_ref
        self.voice_refs = voice_refs or {}
        self.language = language
        self.speed = speed
        self.emotion = emotion
        self.sample_rate = sample_rate
        self.block_size = block_size

        self._ws: websockets.WebSocketClientProtocol | None = None
        self._seq = 0
        self._audio_buffer: bytearray = bytearray()
        self._current_segment_seq: int | None = None
        self._stream: sd.RawOutputStream | None = None
        self._lock = asyncio.Lock()

    def _next_seq(self) -> int:
        self._seq += 1
        return self._seq

    def _send_dict(self, payload: dict[str, Any]) -> None:
        if self._ws is None:
            raise RuntimeError("WebSocket is not connected")
        asyncio.create_task(self._ws.send(json.dumps(payload, ensure_ascii=False)))

    async def connect(self) -> None:
        self._ws = await websockets.connect(self.uri)

        # Start audio output stream. sounddevice resamples if necessary.
        self._stream = sd.RawOutputStream(
            samplerate=self.sample_rate,
            channels=1,
            dtype="int16",
            blocksize=self.block_size,
            callback=self._audio_callback,
        )
        self._stream.start()

        init_msg = {
            "type": "init",
            "seq": self._next_seq(),
            "session_id": "python-client",
            "language": self.language,
            "speed": self.speed,
            "emotion": self.emotion,
        }
        if self.voice_ref:
            init_msg["voice_ref"] = self.voice_ref
        if self.voice_refs:
            init_msg["voice_refs"] = self.voice_refs

        await self._ws.send(json.dumps(init_msg, ensure_ascii=False))

    def _audio_callback(self, outdata: np.ndarray, frames: int, _time, _status) -> None:
        """Pull audio bytes from the buffer into the sounddevice stream."""
        needed = frames * 2  # int16 = 2 bytes
        available = len(self._audio_buffer)
        if available >= needed:
            chunk = bytes(self._audio_buffer[:needed])
            self._audio_buffer = self._audio_buffer[needed:]
        else:
            chunk = bytes(self._audio_buffer) + b"\x00" * (needed - available)
            self._audio_buffer = bytearray()
        outdata[:] = np.frombuffer(chunk, dtype=np.int16).reshape(-1, 1)

    async def speak_text(self, text: str, chunk_delay: float = 0.15) -> None:
        """Simulate streaming text by sending it word-by-word."""
        if self._ws is None:
            raise RuntimeError("Call connect() first")

        words = text.split()
        for i, word in enumerate(words):
            payload = word + (" " if i < len(words) - 1 else "")
            await self._ws.send(
                json.dumps(
                    {"type": "text", "payload": payload, "seq": self._next_seq()},
                    ensure_ascii=False,
                )
            )
            await asyncio.sleep(chunk_delay)

        await self._ws.send(json.dumps({"type": "flush", "seq": self._next_seq()}))

    async def stop(self, reason: str = "interrupt") -> None:
        if self._ws is None:
            return
        await self._ws.send(
            json.dumps({"type": "stop", "reason": reason, "seq": self._next_seq()})
        )
        async with self._lock:
            self._audio_buffer = bytearray()

    async def run(self, text: str) -> None:
        await self.connect()
        assert self._ws is not None
        try:
            receive_task = asyncio.create_task(self._receive_loop())
            await self.speak_text(text)
            await receive_task
        except websockets.exceptions.ConnectionClosed:
            pass
        finally:
            await self.close()

    async def _receive_loop(self) -> None:
        assert self._ws is not None
        finished_events = {"stopped", "finished"}
        while True:
            try:
                raw = await self._ws.recv()
            except websockets.exceptions.ConnectionClosed:
                break

            msg = json.loads(raw)
            msg_type = msg.get("type")

            if msg_type == "audio":
                pcm = base64.b64decode(msg["data"])
                async with self._lock:
                    self._audio_buffer.extend(pcm)

            elif msg_type == "status":
                event = msg.get("event")
                if event in finished_events:
                    # Wait for the audio buffer to drain before exiting.
                    await self._drain()
                    break
                print(f"[status] {event} seq={msg.get('seq')}")

            elif msg_type == "error":
                print(f"[error] {msg.get('message')}", file=sys.stderr)

    async def _drain(self) -> None:
        """Wait until the local audio buffer has been played."""
        while True:
            async with self._lock:
                if len(self._audio_buffer) == 0:
                    break
            await asyncio.sleep(0.05)
        # Give sounddevice a little extra time to finish its current block.
        await asyncio.sleep(0.2)

    async def close(self) -> None:
        if self._stream is not None:
            self._stream.stop()
            self._stream.close()
            self._stream = None
        if self._ws is not None:
            await self._ws.close()
            self._ws = None


def main() -> None:
    """Parse command-line options and speak the given text once.

    The positional words are joined with single spaces into the phrase to
    speak; a built-in sample phrase is used when none are given.  Ctrl-C
    during the run prints ``Interrupted`` instead of a traceback.
    """
    cli = argparse.ArgumentParser(description="Voice TTS WebSocket client")
    # (name, add_argument keyword options), registered in this order.
    argument_specs = (
        ("--uri", {"default": "ws://localhost:8765/ws"}),
        ("--voice-ref", {"default": None}),
        ("--language", {"default": "ru"}),
        ("--speed", {"type": float, "default": 1.0}),
        ("--emotion", {"default": "neutral"}),
        ("--sample-rate", {"type": int, "default": 24_000}),
        ("text", {"nargs": "*", "default": ["Привет. Это тестовая фраза."]}),
    )
    for name, options in argument_specs:
        cli.add_argument(name, **options)
    opts = cli.parse_args()

    phrase = " ".join(opts.text)
    tts = VoiceTTSClient(
        uri=opts.uri,
        voice_ref=opts.voice_ref,
        language=opts.language,
        speed=opts.speed,
        emotion=opts.emotion,
        sample_rate=opts.sample_rate,
    )

    try:
        asyncio.run(tts.run(phrase))
    except KeyboardInterrupt:
        print("Interrupted")

# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == "__main__":
    main()
