"""Webhook components for gnexus-gauth."""

import hashlib
import hmac
import json
from datetime import datetime, timezone

from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.contracts import (
    ClockInterface,
    WebhookParserInterface,
    WebhookVerifierInterface,
)
from gnexus_gauth.dto import VerifiedWebhook, WebhookEvent
from gnexus_gauth.exceptions import (
    WebhookPayloadException,
    WebhookVerificationException,
)
from gnexus_gauth.support import SystemClock


class HmacWebhookVerifier(WebhookVerifierInterface):
    """HMAC-SHA256 webhook signature verifier.

    The ``x-gnexus-signature`` header is parsed as comma-separated
    ``key=value`` pairs and must contain ``t`` (an integer timestamp that is
    compared against the clock in seconds) and ``v1`` (the hex digest of
    HMAC-SHA256 over ``"{t}.{raw_body}"`` keyed with the shared secret).
    """

    # Headers that must be present and non-blank (after normalization).
    _REQUIRED_HEADERS = (
        "x-gnexus-event-id",
        "x-gnexus-event-type",
        "x-gnexus-event-timestamp",
        "x-gnexus-signature",
    )

    def __init__(
        self,
        config: GAuthConfig,
        clock: ClockInterface | None = None,
    ) -> None:
        """Create a verifier.

        Args:
            config: Supplies ``webhook_tolerance_seconds``; a value <= 0
                disables the timestamp freshness check.
            clock: Time source; a ``SystemClock`` is used when omitted.
        """
        self._config = config
        self._clock = clock or SystemClock()

    def verify(self, raw_body: str, headers: dict, secret: str) -> VerifiedWebhook:
        """Authenticate a webhook delivery and return its verified envelope.

        Args:
            raw_body: The request body exactly as received; it is part of the
                signed message, so it must not be re-serialized beforehand.
            headers: Request headers. Names are matched case-insensitively with
                ``_`` treated as ``-``; list/tuple values use their first item.
            secret: Shared HMAC secret for this webhook endpoint.

        Returns:
            A ``VerifiedWebhook`` carrying the body, the normalized headers,
            the ``x-gnexus-event-id`` value and the time of verification.

        Raises:
            WebhookVerificationException: If the secret is empty, a required
                header is missing or blank, the signature header is malformed,
                the signature does not match, or the signed timestamp is
                outside the configured tolerance window.
        """
        if not secret:
            # An empty HMAC key is effectively public: anyone could produce a
            # matching signature, so refuse instead of silently accepting.
            raise WebhookVerificationException("Webhook secret must not be empty.")

        normalized = self._normalize_headers(headers)

        for required in self._REQUIRED_HEADERS:
            if not normalized.get(required):
                raise WebhookVerificationException(f"Missing webhook header: {required}.")

        signature = self._parse_signature_header(normalized["x-gnexus-signature"])
        timestamp = signature["timestamp"]
        expected = hmac.new(
            secret.encode(),
            f"{timestamp}.{raw_body}".encode(),
            hashlib.sha256,
        ).hexdigest()

        received = signature["hash"]
        # hmac.compare_digest() raises TypeError for str arguments containing
        # non-ASCII characters; the received hash is untrusted, so treat such
        # input as a plain mismatch instead of letting TypeError escape.
        if not received.isascii() or not hmac.compare_digest(expected, received):
            raise WebhookVerificationException("Invalid webhook signature.")

        # Read the clock once so the freshness check and verified_at agree.
        now = self._clock.now()
        tolerance = self._config.webhook_tolerance_seconds
        if tolerance > 0 and abs(int(now.timestamp()) - timestamp) > tolerance:
            raise WebhookVerificationException("Webhook timestamp is outside the allowed tolerance window.")

        return VerifiedWebhook(
            raw_body=raw_body,
            normalized_headers=normalized,
            signature_id=normalized["x-gnexus-event-id"],
            verified_at=now,
        )

    @staticmethod
    def _normalize_headers(headers: dict) -> dict[str, str]:
        """Return ``headers`` keyed by lower-case, dash-separated names.

        List/tuple values collapse to their first element (``""`` when empty),
        ``None`` becomes ``""``, byte strings are decoded as ISO-8859-1, and
        every value is whitespace-stripped.
        """

        def as_text(item: object) -> str:
            # str(None) would yield the truthy "None" and defeat the
            # missing-header check; str(b"x") would yield "b'x'".
            if item is None:
                return ""
            if isinstance(item, (bytes, bytearray)):
                return item.decode("latin-1")
            return str(item)

        normalized: dict[str, str] = {}
        for name, value in headers.items():
            if isinstance(value, (list, tuple)):
                value = value[0] if value else ""
            normalized[as_text(name).lower().replace("_", "-")] = as_text(value).strip()
        return normalized

    @staticmethod
    def _parse_signature_header(header: str) -> dict:
        """Parse ``t=<int>,v1=<hex>`` into ``{"timestamp": int, "hash": str}``.

        Chunks without ``=`` are ignored; for repeated keys the last one wins.
        The returned hash is lower-cased.

        Raises:
            WebhookVerificationException: If ``t`` or ``v1`` is missing, or
                ``t`` is not made of ASCII digits.
        """
        parts: dict[str, str] = {}
        for chunk in header.split(","):
            key, sep, value = chunk.partition("=")
            if sep:
                parts[key.strip()] = value.strip()

        if "t" not in parts or "v1" not in parts:
            raise WebhookVerificationException("Malformed webhook signature header.")

        raw_timestamp = parts["t"]
        # str.isdigit() alone also accepts characters such as "²" that int()
        # rejects with ValueError; requiring ASCII keeps bad input on the
        # documented exception path.
        if not (raw_timestamp.isascii() and raw_timestamp.isdigit()):
            raise WebhookVerificationException("Webhook timestamp must be numeric.")

        return {
            "timestamp": int(raw_timestamp),
            "hash": parts["v1"].lower(),
        }


class JsonWebhookParser(WebhookParserInterface):
    """JSON webhook payload parser.

    The payload must be a JSON object with a non-empty string ``type``.
    ``id``, ``occurred_at`` (ISO 8601 string), ``target``, ``actor`` and
    ``data`` are optional; non-object ``target``/``actor``/``data`` values
    are replaced with empty dicts.
    """

    def parse(self, raw_body: str) -> WebhookEvent:
        """Decode ``raw_body`` into a ``WebhookEvent``.

        Args:
            raw_body: The webhook request body.

        Returns:
            A ``WebhookEvent`` whose ``raw_payload`` is the decoded object.

        Raises:
            WebhookPayloadException: If the body is not valid JSON, is not a
                JSON object, lacks a non-empty string ``type``, or carries an
                ``occurred_at`` string that cannot be parsed.
        """
        try:
            payload = json.loads(raw_body)
        except (ValueError, TypeError, RecursionError) as exc:
            # RecursionError: pathologically nested untrusted input such as
            # "[" * 100000 exhausts the decoder's stack; report it as
            # malformed JSON instead of crashing the caller.
            raise WebhookPayloadException("Webhook payload is not valid JSON.") from exc

        if not isinstance(payload, dict):
            raise WebhookPayloadException("Webhook payload is not valid JSON.")

        event_type = payload.get("type")
        if not isinstance(event_type, str) or not event_type:
            raise WebhookPayloadException("Webhook payload is missing event type.")

        # A JSON null id means "no id"; str(None) would yield the string "None".
        raw_id = payload.get("id")

        return WebhookEvent(
            event_id=None if raw_id is None else str(raw_id),
            event_type=event_type,
            occurred_at=self._parse_occurred_at(payload.get("occurred_at")),
            target_identifiers=self._dict_field(payload, "target"),
            actor_identifiers=self._dict_field(payload, "actor"),
            metadata=self._dict_field(payload, "data"),
            raw_payload=payload,
        )

    @staticmethod
    def _parse_occurred_at(raw: object) -> datetime | None:
        """Parse an ISO 8601 ``occurred_at`` value into an aware datetime.

        Non-string or empty values yield ``None``; naive results are assumed
        to be UTC.
        """
        if not isinstance(raw, str) or not raw:
            return None

        # datetime.fromisoformat() accepts a trailing "Z" only from Python
        # 3.11 onward; normalize it so RFC 3339 UTC timestamps parse on 3.10.
        text = raw[:-1] + "+00:00" if raw[-1] in "Zz" else raw
        try:
            occurred_at = datetime.fromisoformat(text)
        except ValueError as exc:
            raise WebhookPayloadException("Webhook payload contains invalid occurred_at.") from exc

        if occurred_at.tzinfo is None:
            occurred_at = occurred_at.replace(tzinfo=timezone.utc)
        return occurred_at

    @staticmethod
    def _dict_field(payload: dict, key: str) -> dict:
        """Return ``payload[key]`` when it is a dict, otherwise ``{}``."""
        value = payload.get(key)
        return value if isinstance(value, dict) else {}
