"""WebSocket message protocol models."""

from typing import Any, Literal
from pydantic import BaseModel, Field


class ClientMessage(BaseModel):
    """Shared shape of every message the agent sends over the WebSocket.

    Concrete subclasses narrow ``type`` to a single ``Literal`` tag.
    """

    # Message tag; each subclass pins it to one literal value.
    type: str
    # Optional sequence number attached by the sender.
    seq: int | None = None


class InitMessage(ClientMessage):
    """Open a session, or reconfigure an existing one.

    Sets the voice reference(s), language, speaking speed and emotion
    that the session should use.
    """

    type: Literal["init"] = "init"
    session_id: str | None = None
    voice_ref: str | None = None
    # NOTE(review): presumably a name -> voice-reference mapping used
    # instead of/alongside ``voice_ref``; confirm with the handler.
    voice_refs: dict[str, str] | None = None
    language: str = "ru"
    # Speaking speed, validated to the inclusive range [0.5, 2.0].
    speed: float = Field(default=1.0, ge=0.5, le=2.0)
    emotion: str = "neutral"


class TextMessage(ClientMessage):
    """One piece of text streamed from the LLM/agent."""

    type: Literal["text"] = "text"
    # The text fragment carried by this message.
    payload: str
    # Optional per-chunk emotion; None when the sender does not set one.
    emotion: str | None = None


class FlushMessage(ClientMessage):
    """Request that all currently buffered text be spoken right away."""

    # No extra fields: the tag alone expresses the request.
    type: Literal["flush"] = "flush"


class StopMessage(ClientMessage):
    """Interrupt and discard everything pending.

    That covers the text buffer, the generation in progress and the
    audio queue.
    """

    type: Literal["stop"] = "stop"
    # Free-form reason string; defaults to "interrupt".
    reason: str = "interrupt"


class EmotionMessage(ClientMessage):
    """Switch the emotion applied to segments that follow this message."""

    type: Literal["emotion"] = "emotion"
    # Required: name of the emotion to switch to.
    emotion: str


class ConfigMessage(ClientMessage):
    """Adjust runtime settings (speed, language) during a session.

    Both settings are optional. NOTE(review): presumably ``None`` means
    "leave unchanged"; confirm with the message handler.
    """

    type: Literal["config"] = "config"
    # When given, validated to the inclusive range [0.5, 2.0].
    speed: float | None = Field(default=None, ge=0.5, le=2.0)
    language: str | None = None


# All message types the agent may send to the server.
# NOTE(review): this is a plain (non-discriminated) union; consider a
# ``type``-discriminated union for clearer validation errors, keeping in
# mind that it would require ``type`` to be present in the input.
ClientMessageUnion = (
    InitMessage
    | TextMessage
    | FlushMessage
    | StopMessage
    | EmotionMessage
    | ConfigMessage
)


# ---------------------------------------------------------------------------
# Server -> client messages
# ---------------------------------------------------------------------------


class ServerMessage(BaseModel):
    """Shared shape of every message the server sends to the client.

    Concrete subclasses narrow ``type`` to a single ``Literal`` tag.
    """

    # Message tag; each subclass pins it to one literal value.
    type: str
    # Optional sequence number attached to the outgoing message.
    seq: int | None = None


class AudioMessage(ServerMessage):
    """One audio chunk, carried as base64-encoded PCM.

    ``sample_rate`` must be positive and ``channels`` at least 1. A zero or
    negative value cannot describe a playable PCM stream, so such values are
    rejected when the message is constructed instead of being sent to the
    client.
    """

    type: Literal["audio"] = "audio"
    # Sample encoding; "pcm_s16le" (signed 16-bit little-endian) is the only
    # value this model accepts.
    format: Literal["pcm_s16le"] = "pcm_s16le"
    # Samples per second; required and must be > 0.
    sample_rate: int = Field(gt=0)
    # Channel count; defaults to mono and must be >= 1.
    channels: int = Field(default=1, ge=1)
    # base64-encoded PCM bytes.
    data: str


class StatusMessage(ServerMessage):
    """Report a session or segment lifecycle event to the client."""

    type: Literal["status"] = "status"
    # The lifecycle event being reported; restricted to this fixed set.
    event: Literal[
        "session_ready",
        "segment_started",
        "segment_finished",
        "finished",
        "stopped",
        "error",
    ]
    # Optional free-form explanation accompanying the event.
    reason: str | None = None
    # Optional additional structured data for the event.
    extra: dict[str, Any] | None = None


class ErrorMessage(ServerMessage):
    """Tell the client that an error occurred."""

    type: Literal["error"] = "error"
    # Text describing the error.
    message: str


# All message types the server may send to the client.
ServerMessageUnion = (
    AudioMessage
    | StatusMessage
    | ErrorMessage
)
