"""OAuth components for gnexus-gauth."""

import base64
import hashlib
import json
import secrets
from datetime import datetime, timedelta, timezone
from urllib.parse import quote, urlencode

import httpx

from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.contracts import TokenEndpointInterface
from gnexus_gauth.dto import TokenSet
from gnexus_gauth.exceptions import (
    TokenExchangeException,
    TokenRefreshException,
    TokenRevokeException,
    TransportException,
)


class PkceGenerator:
    """Random-value generators for the PKCE flow (verifier, S256 challenge, state)."""

    # RFC 7636 section 4.1: a code_verifier must be 43-128 characters long.
    _MIN_VERIFIER_CHARS = 43
    _MAX_VERIFIER_CHARS = 128

    @staticmethod
    def generate_verifier(length: int = 64) -> str:
        """Generate a random PKCE code verifier.

        ``length`` is passed to :func:`secrets.token_urlsafe`. The resulting
        token is base64url-encoded once more with ``=`` padding stripped, so
        the verifier only contains ``[A-Za-z0-9_-]``. The default of 64 yields
        a 115-character verifier. Lengths 24 through 72 (inclusive) produce
        verifiers inside the RFC 7636 range of 43-128 characters.

        Args:
            length: Number of random bytes requested from ``secrets.token_urlsafe``.

        Returns:
            The code verifier string.

        Raises:
            ValueError: If ``length`` produces a verifier shorter than 43 or
                longer than 128 characters.
        """
        token = secrets.token_urlsafe(length)
        # The second encoding adds no randomness. It is preserved so that the
        # mapping from ``length`` to verifier size stays stable for callers.
        verifier = base64.urlsafe_b64encode(token.encode()).decode().rstrip("=")
        if not (
            PkceGenerator._MIN_VERIFIER_CHARS
            <= len(verifier)
            <= PkceGenerator._MAX_VERIFIER_CHARS
        ):
            raise ValueError(
                f"length={length} yields a {len(verifier)}-character code verifier; "
                "RFC 7636 requires 43-128 characters (use a length from 24 to 72)."
            )
        return verifier

    @staticmethod
    def generate_challenge(verifier: str) -> str:
        """Return the S256 code challenge for ``verifier``.

        Computes ``BASE64URL(SHA256(verifier))`` with ``=`` padding removed.
        """
        digest = hashlib.sha256(verifier.encode()).digest()
        return base64.urlsafe_b64encode(digest).decode().rstrip("=")

    @staticmethod
    def generate_state(length: int = 32) -> str:
        """Return a URL-safe random ``state`` value built from ``length`` random bytes."""
        return secrets.token_urlsafe(length)


class AuthorizationUrlBuilder:
    """Builds authorization-endpoint URLs that carry an S256 PKCE challenge."""

    def __init__(self, config: GAuthConfig) -> None:
        """Store the configuration.

        Args:
            config: Source of ``client_id``, ``redirect_uri`` and ``authorize_url``.
        """
        self._config = config

    def build(
        self,
        state: str,
        pkce_challenge: str,
        return_to: str | None = None,
        scopes: list[str] | None = None,
    ) -> str:
        """Return the authorization URL for a new login request.

        Args:
            state: Value sent as the ``state`` parameter.
            pkce_challenge: Value sent as ``code_challenge`` (method ``S256``).
            return_to: Sent as ``return_to`` when non-empty; omitted otherwise.
            scopes: Joined with spaces and sent as ``scope`` when non-empty;
                omitted otherwise.

        Returns:
            ``authorize_url`` followed by the percent-encoded parameters. If
            ``authorize_url`` already contains a query string, the parameters
            are appended with ``&`` rather than a second ``?``.
        """
        query = {
            "response_type": "code",
            "client_id": self._config.client_id,
            "redirect_uri": self._config.redirect_uri,
            "state": state,
            "code_challenge": pkce_challenge,
            "code_challenge_method": "S256",
        }
        if scopes:
            query["scope"] = " ".join(scopes)
        if return_to:
            query["return_to"] = return_to

        # quote_via=quote with safe='' percent-encodes everything, including
        # '/' in redirect_uri and spaces in scope (as %20, not '+').
        encoded = urlencode(query, safe="", quote_via=quote)

        base_url = self._config.authorize_url
        if "?" not in base_url:
            separator = "?"
        elif base_url.endswith(("?", "&")):
            # Base URL already ends with a query delimiter; append directly.
            separator = ""
        else:
            # Base URL carries its own parameters; extend that query string.
            separator = "&"
        return f"{base_url}{separator}{encoded}"


class HttpTokenEndpoint(TokenEndpointInterface):
    """HTTP client for the gnexus-auth token, refresh and revocation endpoints.

    Every request is a form-encoded POST. If no ``httpx.Client`` is injected,
    this object creates one and owns it. Call :meth:`close`, or use the
    instance as a context manager, to release it deterministically instead of
    relying on garbage collection.
    """

    def __init__(
        self,
        config: GAuthConfig,
        http_client: httpx.Client | None = None,
    ) -> None:
        """Create the endpoint client.

        Args:
            config: Source of client credentials, ``redirect_uri``, the token,
                refresh and revoke URLs, and the optional ``user_agent``.
            http_client: Optional pre-configured client. An injected client is
                never closed by this object; a client created here is.
        """
        self._config = config
        # Test against None, matching the ownership flag below.
        self._http = http_client if http_client is not None else httpx.Client()
        self._own_client = http_client is None

    def close(self) -> None:
        """Close the HTTP client if this instance created it.

        Repeated calls are no-ops, and an injected client is left open.
        """
        # getattr/hasattr guard against a partially constructed instance,
        # e.g. when httpx.Client() raised inside __init__.
        if getattr(self, "_own_client", False) and hasattr(self, "_http"):
            self._http.close()
            self._own_client = False

    def __enter__(self) -> "HttpTokenEndpoint":
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        self.close()

    def __del__(self) -> None:
        # Best-effort fallback only. Finalizers may run during interpreter
        # shutdown, when closing can fail, and nobody can handle an exception
        # raised here.
        try:
            self.close()
        except Exception:
            pass

    def exchange_authorization_code(self, code: str, pkce_verifier: str) -> TokenSet:
        """Exchange an authorization code plus PKCE verifier for tokens.

        Raises:
            TransportException: If the HTTP request could not be sent.
            TokenExchangeException: On a non-2xx response, malformed JSON, or
                an unusable token payload.
        """
        payload = {
            "grant_type": "authorization_code",
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "redirect_uri": self._config.redirect_uri,
            "code": code,
            "code_verifier": pkce_verifier,
        }
        data = self._send_form_request(self._config.token_url, payload, TokenExchangeException)
        return self._map_token_set(data, TokenExchangeException)

    def refresh_token(self, refresh_token: str) -> TokenSet:
        """Obtain a fresh token set using ``refresh_token``.

        Raises:
            TransportException: If the HTTP request could not be sent.
            TokenRefreshException: On a non-2xx response, malformed JSON, or
                an unusable token payload.
        """
        payload = {
            "grant_type": "refresh_token",
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "refresh_token": refresh_token,
        }
        data = self._send_form_request(self._config.refresh_url, payload, TokenRefreshException)
        return self._map_token_set(data, TokenRefreshException)

    def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
        """Revoke ``token``. The response body is not parsed.

        Args:
            token: The token to revoke.
            token_type_hint: Sent as ``token_type_hint`` when non-empty.

        Raises:
            TransportException: If the HTTP request could not be sent.
            TokenRevokeException: On a non-2xx response.
        """
        payload = {
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "token": token,
        }
        if token_type_hint:
            payload["token_type_hint"] = token_type_hint
        self._send_form_request(self._config.revoke_url, payload, TokenRevokeException, expect_json=False)

    def _send_form_request(
        self,
        url: str,
        payload: dict,
        exception_class: type[Exception],
        *,
        expect_json: bool = True,
    ) -> dict:
        """POST ``payload`` as a form and return the decoded JSON object.

        Args:
            url: Endpoint to POST to.
            payload: Form fields.
            exception_class: Raised for error responses and malformed bodies.
            expect_json: When False, the body is not parsed and ``{}`` is returned.

        Raises:
            TransportException: If httpx reports a transport error.
            exception_class: On a non-2xx status (message taken from the
                ``error_description``/``error`` fields when present), or when
                the body is not a JSON object.
        """
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        if self._config.user_agent:
            headers["User-Agent"] = self._config.user_agent

        try:
            response = self._http.post(url, data=payload, headers=headers)
        except httpx.TransportError as exc:
            raise TransportException("Request to gnexus-auth failed.") from exc

        # Anything outside 2xx is a failure. In particular, an unfollowed
        # redirect must not be mistaken for a successful revocation.
        if not response.is_success:
            message = self._extract_error_message(response.text) or "gnexus-auth returned an error response."
            raise exception_class(message)

        if not expect_json:
            return {}

        try:
            data = response.json()
        except ValueError as exc:
            raise exception_class("gnexus-auth returned malformed JSON.") from exc

        if not isinstance(data, dict):
            raise exception_class("gnexus-auth returned malformed JSON.")

        return data

    @staticmethod
    def _coerce_optional_int(value: object, field: str, exception_class: type[Exception]) -> int | None:
        """Convert a numeric payload field to ``int``; ``None`` stays ``None``."""
        if value is None:
            return None
        try:
            return int(value)
        except (TypeError, ValueError, OverflowError) as exc:
            raise exception_class(f"gnexus-auth returned an invalid {field}.") from exc

    @staticmethod
    def _map_token_set(data: dict, exception_class: type[Exception] = ValueError) -> TokenSet:
        """Build a :class:`TokenSet` from a decoded token response.

        Absent or ``null`` optional fields fall back to defaults instead of
        being stringified to ``"None"``. ``expires_at`` is computed in UTC only
        when ``expires_in`` is positive.

        Args:
            data: Decoded JSON object from the token endpoint.
            exception_class: Raised when the payload is unusable.

        Raises:
            exception_class: If ``access_token`` is missing or empty, or if
                ``expires_in``/``refresh_expires_in`` is not an integer value.
        """
        access_token = data.get("access_token")
        if access_token is None or access_token == "":
            raise exception_class("gnexus-auth response did not include an access_token.")

        expires_in = HttpTokenEndpoint._coerce_optional_int(
            data.get("expires_in"), "expires_in", exception_class
        ) or 0
        refresh_expires_in = HttpTokenEndpoint._coerce_optional_int(
            data.get("refresh_expires_in"), "refresh_expires_in", exception_class
        )

        scope = data.get("scope")
        if isinstance(scope, str):
            scopes = scope.split()
        elif isinstance(scope, list):
            # Tolerate servers that send scopes as a JSON array.
            scopes = [s for s in scope if isinstance(s, str) and s]
        else:
            scopes = []

        expires_at = None
        if expires_in > 0:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)

        refresh_token = data.get("refresh_token")
        return TokenSet(
            access_token=str(access_token),
            refresh_token=str(refresh_token) if refresh_token is not None else None,
            token_type=str(data.get("token_type") or "Bearer"),
            expires_in=expires_in,
            expires_at=expires_at,
            refresh_expires_in=refresh_expires_in,
            scopes=scopes,
            raw_payload=data,
        )

    @staticmethod
    def _extract_error_message(text: str) -> str | None:
        """Return ``error_description`` (or else ``error``) from a JSON error body.

        Returns ``None`` when the body is not a JSON object or has neither
        field as a string.
        """
        try:
            decoded = json.loads(text)
        except ValueError:
            return None
        if not isinstance(decoded, dict):
            return None
        if isinstance(decoded.get("error_description"), str):
            return decoded["error_description"]
        if isinstance(decoded.get("error"), str):
            return decoded["error"]
        return None
