import base64
import json
import logging
from collections.abc import Awaitable, Callable
from io import BytesIO
from pathlib import Path
from typing import Any

import httpx
import structlog
from PIL import Image
from tenacity import (
    AsyncRetrying,
    before_sleep_log,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from vmk_data_collector.core.circuit_breaker import CircuitBreaker
from vmk_data_collector.core.exceptions import OllamaFatalError, OllamaRetryableError
from vmk_data_collector.core.metrics import ai_requests_total

logger = structlog.get_logger()


def _classify_httpx_error(
    exc: httpx.HTTPStatusError,
) -> OllamaRetryableError | OllamaFatalError:
    """Translate an HTTP error status from Ollama into a project exception.

    A 429 (Too Many Requests) or any 5xx status is treated as transient and
    becomes ``OllamaRetryableError``. Every other status becomes
    ``OllamaFatalError``. The exception is returned, not raised, so callers
    can ``raise ... from exc``.
    """
    code = exc.response.status_code
    transient = code == 429 or code >= 500
    error_cls = OllamaRetryableError if transient else OllamaFatalError
    return error_cls(f"Ollama returned {code}: {exc}")


# Shared tenacity settings for every public OllamaClient request:
# - up to three attempts;
# - exponential back-off clamped to 1-10 seconds;
# - retry only on OllamaRetryableError;
# - re-raise the last error once attempts are exhausted.
_RETRY_CONFIG = {
    "stop": stop_after_attempt(3),
    "wait": wait_exponential(min=1, max=10),
    "retry": retry_if_exception_type(OllamaRetryableError),
    # tenacity's before_sleep_log takes an *integer* level and passes it
    # straight to ``logger.log(level, ...)``. A level name such as "warning"
    # breaks integer level handling; stdlib ``Logger.log`` raises TypeError,
    # for example. That would abort the retry loop on the first retryable
    # error, so use the numeric constant.
    "before_sleep": before_sleep_log(logger, logging.WARNING),
    "reraise": True,
}


class OllamaClient:
    """Async client for an Ollama server's ``/api/chat`` endpoint.

    Each public request runs inside a tenacity retry loop (``_RETRY_CONFIG``)
    wrapped around a circuit breaker. Each request records one
    ``ai_requests_total`` sample labelled ``success`` or ``error``.

    The underlying ``httpx.AsyncClient`` must be released with :meth:`close`,
    or by using the client as an ``async with`` context manager.
    """

    def __init__(self, base_url: str, timeout: float = 120) -> None:
        """Create the HTTP client and the circuit breaker.

        Args:
            base_url: Root URL of the Ollama server; request paths such as
                ``/api/chat`` are resolved against it.
            timeout: httpx timeout, in seconds, applied to every request.
        """
        self._client = httpx.AsyncClient(
            base_url=base_url,
            timeout=timeout,
        )
        # NOTE(review): only OllamaRetryableError is declared as the expected
        # exception, so presumably fatal errors (4xx, bad JSON) do not count
        # toward tripping the breaker. Confirm against CircuitBreaker.
        self._circuit_breaker = CircuitBreaker(
            failure_threshold=5,
            recovery_timeout=30.0,
            expected_exception=(OllamaRetryableError,),
        )

    async def __aenter__(self) -> "OllamaClient":
        """Return ``self`` so the client can be used with ``async with``."""
        return self

    async def __aexit__(self, *exc_info: object) -> None:
        """Close the HTTP client when leaving an ``async with`` block."""
        await self.close()

    async def chat(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Send a text-only chat request and return Ollama's decoded reply.

        Args:
            model: Ollama model name.
            messages: Chat messages, sent unchanged as the ``messages`` field.
            json_mode: When true, request JSON output (``"format": "json"``).

        Returns:
            The JSON-decoded body of the ``POST /api/chat`` response.

        Raises:
            OllamaRetryableError: A transient failure persisted through all
                retry attempts.
            OllamaFatalError: A non-retryable HTTP status, or a body that is
                not valid JSON.
            (Whatever CircuitBreaker raises while open may also propagate.)
        """
        return await self._call_with_retry(
            model,
            self._chat_raw,
            model,
            messages,
            json_mode,
        )

    async def chat_with_images(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Send a chat request with images attached to the last message.

        The caller's ``messages`` list and its dicts are not modified.

        Args:
            model: Ollama (vision) model name.
            messages: Chat messages; the images are attached to a copy of the
                last one.
            images_base64: Base64-encoded images (see :meth:`image_to_base64`).
            json_mode: When true, request JSON output (``"format": "json"``).

        Returns:
            The JSON-decoded body of the ``POST /api/chat`` response.

        Raises:
            Same as :meth:`chat`.
        """
        return await self._call_with_retry(
            model,
            self._chat_with_images_raw,
            model,
            messages,
            images_base64,
            json_mode,
        )

    async def _call_with_retry(
        self,
        model: str,
        func: Callable[..., Awaitable[dict[str, Any]]],
        *args: Any,
    ) -> dict[str, Any]:
        """Run ``func(*args)`` through the circuit breaker, with retries.

        Increments ``ai_requests_total`` once per call. The ``status`` label is
        "success" if a result is produced, or "error" if any exception
        escapes; that exception is then re-raised.
        """
        try:
            async for attempt in AsyncRetrying(**_RETRY_CONFIG):
                with attempt:
                    result = await self._circuit_breaker.call(func, *args)
        except Exception:
            ai_requests_total.labels(model=model, status="error").inc()
            raise
        # With reraise=True the loop either raises or has assigned ``result``.
        ai_requests_total.labels(model=model, status="success").inc()
        return result

    async def _chat_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Perform one text chat request, without retries or circuit breaking."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": False,
        }
        if json_mode:
            payload["format"] = "json"

        logger.info(
            "ollama_chat_request",
            model=model,
            json_mode=json_mode,
            message_count=len(messages),
        )
        data = await self._post_chat(payload)
        logger.info("ollama_chat_response", model=model)
        return data

    async def _chat_with_images_raw(
        self,
        model: str,
        messages: list[dict[str, Any]],
        images_base64: list[str],
        json_mode: bool = False,
    ) -> dict[str, Any]:
        """Perform one vision chat request, without retries or circuit breaking."""
        if messages and images_base64:
            # Attach the images to a shallow copy of the last message rather
            # than mutating the caller's dict. Otherwise the images would leak
            # into any later request that reuses the same message list.
            messages = [*messages[:-1], {**messages[-1], "images": images_base64}]

        payload: dict[str, Any] = {
            "model": model,
            "messages": messages,
            "stream": False,
        }
        if json_mode:
            payload["format"] = "json"

        logger.info(
            "ollama_vision_request",
            model=model,
            image_count=len(images_base64),
        )
        data = await self._post_chat(payload)
        logger.info("ollama_vision_response", model=model)
        return data

    async def _post_chat(self, payload: dict[str, Any]) -> dict[str, Any]:
        """POST *payload* to ``/api/chat`` and return the decoded JSON body.

        Raises:
            OllamaRetryableError: On a connection error or timeout, or on an
                HTTP status that ``_classify_httpx_error`` deems transient.
            OllamaFatalError: On any other HTTP error status, or a body that
                is not valid JSON.
        """
        try:
            response = await self._client.post("/api/chat", json=payload)
            response.raise_for_status()
            return response.json()
        except httpx.ConnectError as exc:
            raise OllamaRetryableError(f"Connection error: {exc}") from exc
        except httpx.TimeoutException as exc:
            raise OllamaRetryableError(f"Timeout: {exc}") from exc
        except httpx.HTTPStatusError as exc:
            raise _classify_httpx_error(exc) from exc
        except json.JSONDecodeError as exc:
            raise OllamaFatalError(f"Invalid JSON response: {exc}") from exc

    async def close(self) -> None:
        """Close the underlying ``httpx.AsyncClient``."""
        await self._client.aclose()

    @staticmethod
    def image_to_base64(
        image_path: str | Path,
        resize: bool = True,
        max_size: int = 1024,
        quality: int = 85,
    ) -> str:
        """Return the image at *image_path* as a base64 string.

        If *resize* is true and either side exceeds *max_size*, the image is
        shrunk in place with ``thumbnail`` to fit a ``max_size`` square. It
        is then converted to RGB and re-encoded as JPEG at *quality*.
        Otherwise the original file bytes are encoded unchanged.

        The file is always opened with Pillow first, so unreadable or
        non-image files raise Pillow's error in both paths.
        """
        # Use a context manager so Pillow's file handle is always released.
        with Image.open(image_path) as img:
            if resize and (img.width > max_size or img.height > max_size):
                img.thumbnail((max_size, max_size))
                buffer = BytesIO()
                img.convert("RGB").save(buffer, format="JPEG", quality=quality)
                return base64.b64encode(buffer.getvalue()).decode("utf-8")
        return base64.b64encode(Path(image_path).read_bytes()).decode("utf-8")
