from pathlib import Path

import numpy as np
from loguru import logger

from voice_tts.config import settings
from voice_tts.tts import register as _register_backend
from voice_tts.tts.engine import TTSEngine


# Optional heavy dependencies. ImportError is caught so this module can still be
# imported when f5-tts/torch are missing; F5_AVAILABLE records the outcome and
# is checked by F5TTSEngine before any of these names are used.
try:
    import torch
    import torchaudio
    from f5_tts.api import F5TTS
    from f5_tts.infer.utils_infer import preprocess_ref_audio_text

    F5_AVAILABLE = True
except ImportError as exc:
    logger.warning("f5-tts/torch dependencies not available: {}", exc)
    F5_AVAILABLE = False
    # Bind the module names to None so later references resolve instead of
    # raising NameError.
    # NOTE(review): F5TTS and preprocess_ref_audio_text are NOT bound in this
    # branch; callers must check F5_AVAILABLE before touching them.
    torch = None
    torchaudio = None


@_register_backend("f5_tts")
class F5TTSEngine(TTSEngine):
    """F5-TTS backend for local GPU inference with voice cloning.

    The ``F5TTS`` model is loaded lazily on first use (or explicitly via
    :meth:`load`). Preprocessed reference audio is cached per instance, keyed
    by emotion, reference path and requested reference text.
    """

    def __init__(
        self,
        model: str | None = None,
        sample_rate: int | None = None,
        device: str | None = None,
        nfe_step: int | None = None,
        cfg_strength: float | None = None,
        sway_sampling_coef: float | None = None,
        speed: float | None = None,
        target_rms: float | None = None,
        cross_fade_duration: float | None = None,
        remove_silence: bool | None = None,
    ):
        """Store configuration; the model itself is not loaded here.

        Args:
            model: F5-TTS model name; falls back to ``settings.tts_model_name``.
            sample_rate: Initial output sample rate; falls back to
                ``settings.tts_sample_rate``. Overwritten by the model's
                ``target_sample_rate`` once :meth:`load` runs.
            device: Device string passed to ``F5TTS``; falls back to
                ``settings.device``.
            nfe_step: Forwarded to ``F5TTS.infer``; default 32.
            cfg_strength: Forwarded to ``F5TTS.infer``; default 2.0.
            sway_sampling_coef: Forwarded to ``F5TTS.infer``; default -1.0.
            speed: Fallback speed used by :meth:`synthesize` when its own
                ``speed`` argument is ``None``; default 1.0.
            target_rms: Forwarded to ``F5TTS.infer``; default 0.1.
            cross_fade_duration: Forwarded to ``F5TTS.infer``; default 0.0.
            remove_silence: Forwarded to ``F5TTS.infer``; default False.
        """
        super().__init__()
        # An empty name/device or a zero sample rate is never meaningful, so any
        # falsy value falls back to the configured setting.
        self.model_name = model or settings.tts_model_name
        self.sample_rate = sample_rate or settings.tts_sample_rate
        self.device = device or settings.device
        # For the numeric tuning knobs, 0 / 0.0 / False may be deliberate
        # choices (e.g. cfg_strength=0.0 or sway_sampling_coef=0.0). Only
        # ``None`` means "use the default"; ``x or default`` would silently
        # override such values.
        self.nfe_step = self._or_default(nfe_step, 32)
        self.cfg_strength = self._or_default(cfg_strength, 2.0)
        self.sway_sampling_coef = self._or_default(sway_sampling_coef, -1.0)
        self.speed = self._or_default(speed, 1.0)
        self.target_rms = self._or_default(target_rms, 0.1)
        self.cross_fade_duration = self._or_default(cross_fade_duration, 0.0)
        self.remove_silence = self._or_default(remove_silence, False)

        # Loaded model; None until load() succeeds.
        self._f5: "F5TTS | None" = None
        # Cache key (see _get_key) -> (processed_audio_path, ref_text).
        self._ref_cache: dict[str, tuple[str, str]] = {}

    @staticmethod
    def _or_default(value, fallback):
        """Return *value* unless it is ``None``, in which case return *fallback*."""
        return fallback if value is None else value

    def _get_key(self, ref_path: Path | None, emotion: str, ref_text: str = "") -> str:
        """Build the reference-cache key for *emotion*, *ref_path* and *ref_text*.

        If no reference text is given, the key keeps the original
        ``"{emotion}::{path}"`` format. Otherwise the text is appended, so
        different transcripts for the same file do not share a cache entry.
        """
        path_str = str(ref_path) if ref_path else ""
        key = f"{emotion}::{path_str}"
        return f"{key}::{ref_text}" if ref_text else key

    def load(self) -> None:
        """Instantiate the F5-TTS model and adopt its target sample rate.

        Raises:
            RuntimeError: If the f5-tts/torch dependencies failed to import.
        """
        if not F5_AVAILABLE:
            raise RuntimeError(
                "f5-tts/torch package is not installed. Install it: pip install f5-tts torch torchaudio"
            )
        # An empty string is passed when no path is configured; the log below
        # labels that case "(default)".
        ckpt_file = str(settings.tts_model_path) if settings.tts_model_path else ""
        vocab_file = str(settings.tts_vocab_path) if settings.tts_vocab_path else ""
        logger.info(
            "Loading F5-TTS model {} on device {} ckpt={} vocab={} ...",
            self.model_name,
            self.device,
            ckpt_file or "(default)",
            vocab_file or "(default)",
        )
        self._f5 = F5TTS(
            model=self.model_name,
            ckpt_file=ckpt_file,
            vocab_file=vocab_file,
            device=self.device,
        )
        # The model dictates the output rate, overriding any configured value.
        self.sample_rate = self._f5.target_sample_rate
        logger.info("F5-TTS loaded. Target sample rate: {}", self.sample_rate)

    async def warm_up(self) -> None:
        """Ensure the model is loaded.

        No inference pass runs here, because warming up would need a
        reference audio.
        """
        if self._f5 is None:
            self.load()
        logger.info("Warming up F5-TTS ...")
        logger.info("Warm-up skipped: provide a reference audio before warm-up.")

    def _ensure_reference(
        self,
        ref_audio_path: Path | None,
        emotion: str,
        ref_text_override: str | None = None,
    ) -> tuple[str, str]:
        """Return ``(processed_audio_path, ref_text)`` for a reference clip.

        The result is cached per (emotion, path, requested reference text).
        The requested text is *ref_text_override*, else
        ``settings.default_ref_text``, else ``""``. It is handed to
        ``preprocess_ref_audio_text``, and whatever text that function
        returns is cached and returned.

        Raises:
            RuntimeError: If the f5-tts/torch dependencies are unavailable.
            ValueError: If *ref_audio_path* is ``None``.
            FileNotFoundError: If the reference file does not exist (checked
                only on a cache miss).
        """
        if not F5_AVAILABLE:
            raise RuntimeError("f5-tts/torch is not installed")
        if ref_audio_path is None:
            raise ValueError("F5-TTS requires a reference audio file (voice_ref).")

        if isinstance(ref_audio_path, str):
            ref_audio_path = Path(ref_audio_path)

        # Resolve the requested text *before* looking up the cache. Otherwise a
        # later call with a different override would get the first call's text.
        requested_text = ref_text_override or settings.default_ref_text or ""
        key = self._get_key(ref_audio_path, emotion, requested_text)
        cached = self._ref_cache.get(key)
        if cached is not None:
            return cached

        if not ref_audio_path.exists():
            raise FileNotFoundError(f"Reference audio not found: {ref_audio_path}")

        processed_audio, ref_text = preprocess_ref_audio_text(
            str(ref_audio_path),
            requested_text,
        )
        self._ref_cache[key] = (processed_audio, ref_text)
        logger.info(
            "Reference cached for emotion={} path={} text={}",
            emotion,
            ref_audio_path,
            ref_text,
        )
        return processed_audio, ref_text

    async def synthesize(
        self,
        text: str,
        ref_audio_path: Path | None,
        language: str,
        speed: float,
        emotion: str,
        ref_text: str | None = None,
    ) -> np.ndarray:
        """Synthesize *text* in the voice of the reference audio.

        Args:
            text: Text to generate.
            ref_audio_path: Reference clip (``Path`` or ``str``); required.
            language: Accepted for interface compatibility; not used by this
                backend.
            speed: Speed forwarded to ``F5TTS.infer``. Falls back to the
                instance's configured ``speed`` when ``None``.
            emotion: Part of the reference-cache key.
            ref_text: Optional transcript of the reference clip.

        Returns:
            A float32 numpy array, squeezed of singleton dimensions.

        Raises:
            RuntimeError: If dependencies are missing or no audio is produced.
            ValueError / FileNotFoundError: From reference preparation.
        """
        if self._f5 is None:
            self.load()

        # Type narrowing only: load() either sets _f5 or raises.
        assert self._f5 is not None

        if isinstance(ref_audio_path, str):
            ref_audio_path = Path(ref_audio_path)

        processed_audio, resolved_ref_text = self._ensure_reference(
            ref_audio_path, emotion, ref_text
        )

        # NOTE(review): infer() is a synchronous call inside this coroutine, so
        # it blocks the event loop while it runs. Consider asyncio.to_thread if
        # concurrent requests must stay responsive, but only if the model is
        # safe to call from a worker thread.
        wav, _sr, _spec = self._f5.infer(
            ref_file=processed_audio,
            ref_text=resolved_ref_text,
            gen_text=text,
            nfe_step=self.nfe_step,
            cfg_strength=self.cfg_strength,
            sway_sampling_coef=self.sway_sampling_coef,
            speed=speed if speed is not None else self.speed,
            cross_fade_duration=self.cross_fade_duration,
            target_rms=self.target_rms,
            remove_silence=self.remove_silence,
        )

        if wav is None:
            raise RuntimeError("F5-TTS produced no audio")

        # Normalize to a float32 numpy array with singleton dimensions removed.
        if isinstance(wav, torch.Tensor):
            wav = wav.squeeze().cpu().numpy()
        wav = np.asarray(wav, dtype=np.float32)
        if wav.ndim == 0:
            wav = np.zeros(1, dtype=np.float32)
        elif wav.ndim > 1:
            wav = wav.squeeze()

        return wav
