import hashlib
import logging
import time
from dataclasses import dataclass
from io import BytesIO
from pathlib import Path

import httpx
import structlog
from PIL import Image
from tenacity import (
    before_sleep_log,
    retry,
    retry_if_exception_type,
    stop_after_attempt,
    wait_exponential,
)

from vmk_data_collector.core.exceptions import ImageDownloadError
from vmk_data_collector.core.metrics import image_download_duration_seconds

_MAX_IMAGE_BYTES = 50 * 1024 * 1024

# Module-wide structlog logger; also passed to the tenacity retry policy below.
logger = structlog.get_logger()


@dataclass
class PropertyImageDownloadResult:
    """Metadata describing an image that ``ImageDownloader.download`` stored.

    Attributes:
        local_path: Path (as a string) of the file the image bytes were written to.
        image_hash: Hex-encoded SHA-256 digest of the image bytes.
        width: Image width in pixels.
        height: Image height in pixels.
        file_size: Number of bytes written.
    """

    local_path: str  # stringified filesystem path of the stored file
    image_hash: str  # sha256 hexdigest; also used as the file stem
    width: int  # pixels
    height: int  # pixels
    file_size: int  # bytes


# Retry decorator for ImageDownloader.download: up to 3 attempts with
# exponential back-off (1s..10s), retrying only connection errors and
# timeouts, logging each retry at WARNING, and re-raising the last error once
# attempts are exhausted.
#
# This must be a callable decorator (it is applied as ``@_IMAGE_RETRY``); a
# plain dict of keyword arguments is not callable and fails at import time.
# ``before_sleep_log`` takes an integer logging level, not a level name.
_IMAGE_RETRY = retry(
    stop=stop_after_attempt(3),
    wait=wait_exponential(min=1, max=10),
    retry=retry_if_exception_type((httpx.ConnectError, httpx.TimeoutException)),
    before_sleep=before_sleep_log(logger, logging.WARNING),
    reraise=True,
)


class ImageDownloader:
    """Download property images and store them on local disk.

    Files are written to ``<storage_path>/<property_id>/<sha256>.<ext>``,
    i.e. named after the SHA-256 digest of their bytes.
    """

    def __init__(self, storage_path: Path) -> None:
        """Create a downloader rooted at *storage_path*.

        Args:
            storage_path: Base directory; per-property subdirectories are
                created below it on demand.
        """
        self._storage_path = storage_path

    @_IMAGE_RETRY
    async def download(
        self,
        property_id: int,
        image_url: str,
        order_index: int,
    ) -> PropertyImageDownloadResult:
        """Fetch one image, read its dimensions and write it to disk.

        The call is wrapped by ``_IMAGE_RETRY``; the duration of every attempt,
        successful or not, is observed in ``image_download_duration_seconds``.

        Args:
            property_id: Property the image belongs to; selects the
                subdirectory the file is written to.
            image_url: URL to fetch.
            order_index: Position of the image for this property; only logged
                here.

        Returns:
            The stored file's path, SHA-256 hex digest, pixel width/height and
            size in bytes.

        Raises:
            ImageDownloadError: If the declared or actual body size exceeds
                ``_MAX_IMAGE_BYTES``.
            httpx.HTTPStatusError: If the final response is not successful.
            httpx.ConnectError, httpx.TimeoutException: Network failures that
                persist after the retry policy gives up.
            PIL errors (e.g. ``PIL.UnidentifiedImageError``): If the body is
                not an image Pillow can identify.
        """
        start = time.perf_counter()
        logger.info(
            "image_download_start",
            property_id=property_id,
            url=image_url,
            order=order_index,
        )

        try:
            # httpx does not follow redirects by default and raise_for_status()
            # rejects 3xx responses, so follow them explicitly: image URLs
            # frequently redirect to a CDN.
            async with httpx.AsyncClient(
                timeout=30, follow_redirects=True
            ) as client, client.stream("GET", image_url) as response:
                response.raise_for_status()
                content_type = response.headers.get("content-type", "")

                # Reject early when the server already announces an oversized
                # body. A malformed header is ignored; the streaming check
                # below enforces the limit regardless.
                content_length = response.headers.get("content-length")
                try:
                    declared_size = int(content_length) if content_length else None
                except ValueError:
                    declared_size = None
                if declared_size is not None and declared_size > _MAX_IMAGE_BYTES:
                    raise ImageDownloadError(
                        f"Image too large: {content_length} bytes"
                    )

                # An async stream must be consumed with aiter_bytes(); the
                # sync iter_bytes() cannot be driven by `async for`.
                buffer = bytearray()
                async for chunk in response.aiter_bytes():
                    buffer.extend(chunk)
                    if len(buffer) > _MAX_IMAGE_BYTES:
                        raise ImageDownloadError(
                            f"Image exceeds max size of {_MAX_IMAGE_BYTES} bytes"
                        )
            content = bytes(buffer)

            image_hash = hashlib.sha256(content).hexdigest()
            ext = self._detect_extension(content_type, image_url)

            # Only the dimensions are needed; opening the image also rejects
            # payloads Pillow cannot identify as an image.
            with Image.open(BytesIO(content)) as img:
                width, height = img.size

            property_dir = self._storage_path / str(property_id)
            property_dir.mkdir(parents=True, exist_ok=True)

            # Content-addressed name: identical bytes (with the same extension)
            # for a property map to the same file, which is overwritten.
            local_path = property_dir / f"{image_hash}.{ext}"
            local_path.write_bytes(content)

            file_size = len(content)

            logger.info(
                "image_download_complete",
                property_id=property_id,
                hash=image_hash,
                width=width,
                height=height,
                size=file_size,
            )

            return PropertyImageDownloadResult(
                local_path=str(local_path),
                image_hash=image_hash,
                width=width,
                height=height,
                file_size=file_size,
            )
        finally:
            image_download_duration_seconds.observe(time.perf_counter() - start)

    @staticmethod
    def _detect_extension(content_type: str, url: str) -> str:
        """Pick a file extension for an image.

        The ``Content-Type`` header is consulted first, then the URL path's
        suffix; ``"jpg"`` is the fallback. JPEG is always reported as
        ``"jpg"`` so the same image gets the same file name regardless of
        which source determined the type.

        Args:
            content_type: Response ``Content-Type`` header value (may be "").
            url: URL the image was requested from.

        Returns:
            One of ``"jpg"``, ``"png"``, ``"webp"``, ``"gif"``.
        """
        ct = content_type.lower()
        if "jpeg" in ct or "jpg" in ct:
            return "jpg"
        if "png" in ct:
            return "png"
        if "webp" in ct:
            return "webp"
        if "gif" in ct:
            return "gif"

        from urllib.parse import urlparse

        path = urlparse(url).path.lower()
        for suffix, ext in (
            (".jpg", "jpg"),
            (".jpeg", "jpg"),  # normalised to match the content-type branch
            (".png", "png"),
            (".webp", "webp"),
            (".gif", "gif"),
        ):
            if path.endswith(suffix):
                return ext
        return "jpg"
