"""Unit tests for the Fish Speech 1.5 backend, using a mocked inference engine."""

import sys
from pathlib import Path

import numpy as np
import pytest
import torch

from voice_tts.config import settings
from voice_tts.tts.fish_speech_backend import FishSpeechEngine


# Relative path -> source text for a minimal on-disk stand-in of the
# ``fish_speech`` package, covering the modules this file stubs out for
# ``FishSpeechEngine.load()``.
_FAKE_FISH_SPEECH_FILES = {
    "fish_speech/__init__.py": "",
    # NOTE: at import time the ``inference_engine/`` package directory below
    # takes precedence over this same-named module file, so this stub is never
    # imported. It is kept in case something inspects the source tree for it —
    # TODO confirm whether it is still needed.
    "fish_speech/inference_engine.py": "class TTSInferenceEngine:\n    pass\n",
    "fish_speech/inference_engine/__init__.py": "class TTSInferenceEngine:\n    pass\n",
    "fish_speech/models/__init__.py": "",
    "fish_speech/models/text2semantic/__init__.py": "",
    "fish_speech/models/text2semantic/inference.py": (
        "def launch_thread_safe_queue(*args, **kwargs):\n"
        "    return object()\n"
    ),
    "fish_speech/models/vqgan/__init__.py": "",
    "fish_speech/models/vqgan/inference.py": (
        "def load_model(*args, **kwargs):\n"
        "    class FakeDecoder:\n"
        "        spec_transform = type('T', (), {'sample_rate': 44100})()\n"
        "        def encode(self, audios, lengths):\n"
        "            return [None]\n"
        "    return FakeDecoder()\n"
    ),
    "fish_speech/utils/__init__.py": "",
    "fish_speech/utils/schema.py": (
        "from pydantic import BaseModel\n"
        "from typing import List\n"
        "class ServeReferenceAudio(BaseModel):\n"
        "    audio: bytes\n"
        "    text: str\n"
        "class ServeTTSRequest(BaseModel):\n"
        "    text: str\n"
        "    references: List[ServeReferenceAudio] = []\n"
        "    seed: int | None = None\n"
        "    top_p: float = 0.7\n"
        "    temperature: float = 0.7\n"
        "    repetition_penalty: float = 1.2\n"
        "    max_new_tokens: int = 1024\n"
        "    chunk_length: int = 200\n"
        "    use_memory_cache: str = 'on'\n"
    ),
}


def _pop_fish_speech_modules():
    """Remove every cached ``fish_speech`` module from ``sys.modules``.

    Returns:
        A ``{name: module}`` mapping of the removed entries, so the caller can
        put them back later.
    """
    names = [
        name
        for name in sys.modules
        if name == "fish_speech" or name.startswith("fish_speech.")
    ]
    return {name: sys.modules.pop(name) for name in names}


@pytest.fixture
def mocked_fish(monkeypatch, tmp_path):
    """Provide a loaded ``FishSpeechEngine`` backed by stub ``fish_speech`` modules.

    A fake ``fish_speech`` source tree is written under ``tmp_path`` and
    prepended to ``sys.path``. ``TTSInferenceEngine`` is then replaced with an
    in-process fake whose ``inference()`` yields one second of a linear ramp.

    Any ``fish_speech`` modules already cached in ``sys.modules`` are set aside
    for the duration of the test and restored afterwards. Each test therefore
    imports the stubs from its own ``tmp_path``, not from an earlier test's tree
    or any previously imported copy.

    Yields:
        ``(engine, ref_wav, ref_lab)``: the loaded engine, a placeholder
        reference ``.wav`` path, and the sibling ``.lab`` transcript path.
    """
    fake_root = tmp_path / "fish-speech"
    for rel_path, source in _FAKE_FISH_SPEECH_FILES.items():
        target = fake_root / rel_path
        target.parent.mkdir(parents=True, exist_ok=True)
        target.write_text(source, encoding="utf-8")

    class FakeEngine:
        """Stand-in for ``TTSInferenceEngine`` that yields a single final result."""

        def __init__(self, *args, **kwargs):
            pass

        def inference(self, req):
            sample_rate = 44_100
            # One second of a -0.9..0.9 ramp: deterministic and inside [-1, 1].
            audio = np.linspace(-0.9, 0.9, sample_rate, dtype=np.float32)
            yield type(
                "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)}
            )()

    # Point sys.path at the fake root (monkeypatch undoes this at teardown).
    monkeypatch.syspath_prepend(str(fake_root))
    saved_modules = _pop_fish_speech_modules()
    try:
        from fish_speech import inference_engine as _ie

        monkeypatch.setattr(_ie, "TTSInferenceEngine", FakeEngine)

        ref_wav = tmp_path / "ref.wav"
        # Only a RIFF magic followed by zero padding; not a decodable WAV file.
        ref_wav.write_bytes(b"RIFF" + b"\x00" * 40)
        ref_lab = tmp_path / "ref.lab"
        ref_lab.write_text("Reference transcript from lab file.", encoding="utf-8")

        engine = FishSpeechEngine(
            checkpoint_path=tmp_path / "checkpoint",
            source_root=fake_root,
            device="cpu",
            precision=torch.float32,
            chunk_length=200,
        )
        engine.load()
        yield engine, ref_wav, ref_lab
    finally:
        # Drop the stubs imported from this test's tmp tree and restore whatever
        # was cached before the fixture ran.
        _pop_fish_speech_modules()
        sys.modules.update(saved_modules)


@pytest.mark.asyncio
async def test_synthesize_returns_audio(mocked_fish):
    """An English request returns a 1-D float32 clip of ``engine.sample_rate`` samples."""
    engine, reference_wav, _reference_lab = mocked_fish
    request = dict(
        text="Hello world.",
        ref_audio_path=reference_wav,
        language="en",
        speed=1.0,
        emotion="neutral",
    )

    waveform = await engine.synthesize(**request)

    assert waveform.ndim == 1
    assert len(waveform) == engine.sample_rate
    assert waveform.dtype == np.float32
    # Loose ceiling on the absolute peak of the returned samples.
    peak = np.max(np.abs(waveform))
    assert peak <= 1.1


@pytest.mark.asyncio
async def test_unsupported_language_raises(mocked_fish):
    """Requesting French synthesis raises a ValueError mentioning "not supported"."""
    engine, reference_wav, _reference_lab = mocked_fish
    request = dict(
        text="Bonjour.",
        ref_audio_path=reference_wav,
        language="fr",
        speed=1.0,
        emotion="neutral",
    )
    with pytest.raises(ValueError, match="not supported"):
        await engine.synthesize(**request)


@pytest.mark.asyncio
async def test_missing_reference_raises(mocked_fish):
    """A reference clip path that does not exist raises FileNotFoundError."""
    engine, _reference_wav, _reference_lab = mocked_fish
    missing_clip = Path("/nonexistent/ref.wav")
    request = dict(
        text="Hello.",
        ref_audio_path=missing_clip,
        language="en",
        speed=1.0,
        emotion="neutral",
    )
    with pytest.raises(FileNotFoundError):
        await engine.synthesize(**request)


@pytest.mark.asyncio
async def test_ref_text_fallback_order(monkeypatch, mocked_fish):
    """Check the reference transcript sent to the engine, step by step.

    Expected precedence: explicit ``ref_text``, then the ``.lab`` file beside the
    reference audio, then ``settings.default_ref_text``, then a fixed placeholder.
    """
    engine, ref_wav, ref_lab = mocked_fish

    captured = {}

    def _spy_inference(req):
        # Record the transcript of the first reference, then yield a silent
        # one-second "final" result so synthesize() can complete.
        captured["text"] = req.references[0].text
        sample_rate = 44_100
        audio = np.zeros(sample_rate, dtype=np.float32)
        yield type(
            "R", (), {"code": "final", "error": None, "audio": (sample_rate, audio)}
        )()

    # Replace the already-created engine's inference method with the spy.
    monkeypatch.setattr(engine._engine, "inference", _spy_inference)

    async def _reference_text_sent(**extra):
        """Run one synthesis and return the reference transcript the engine received."""
        # Clear first so a step can never pass on a value left by an earlier step.
        captured.clear()
        await engine.synthesize(
            text="Hi.",
            ref_audio_path=ref_wav,
            language="en",
            speed=1.0,
            emotion="neutral",
            **extra,
        )
        assert "text" in captured, "synthesize() never invoked the inference engine"
        return captured["text"]

    # 1. Explicit ref_text wins.
    assert await _reference_text_sent(ref_text="explicit") == "explicit"

    # 2. .lab file next to reference is used when no explicit text.
    assert await _reference_text_sent() == "Reference transcript from lab file."

    # 3. settings.default_ref_text wins over placeholder.
    monkeypatch.setattr(settings, "default_ref_text", "settings default")
    ref_lab.unlink(missing_ok=True)
    assert await _reference_text_sent() == "settings default"

    # 4. Placeholder when nothing else is available.
    monkeypatch.setattr(settings, "default_ref_text", None)
    assert (
        await _reference_text_sent()
        == "Hello, this is a reference voice recording."
    )


@pytest.mark.asyncio
async def test_speed_resampling_changes_length(mocked_fish):
    """Check that ``speed`` changes the output length relative to speed=1.0.

    Speed 1.5 plays faster and should yield fewer samples than speed=1.0.
    Speed 0.8 plays slower and should yield more.
    """
    engine, ref_wav, _ = mocked_fish

    async def _length_at(speed):
        """Synthesize the fixed prompt at ``speed`` and return its sample count."""
        audio = await engine.synthesize(
            text="Hello.",
            ref_audio_path=ref_wav,
            language="en",
            speed=speed,
            emotion="neutral",
        )
        return len(audio)

    fast_len = await _length_at(1.5)
    slow_len = await _length_at(0.8)
    normal_len = await _length_at(1.0)

    assert fast_len < normal_len
    assert slow_len > normal_len
