# gnexus-auth-client-py / src / gnexus_gauth / client.py
"""High-level client for gnexus-gauth integrations."""

from datetime import datetime, timedelta, timezone

from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.contracts import (
    ClockInterface,
    PkceStoreInterface,
    RuntimeUserProviderInterface,
    StateStoreInterface,
    TokenEndpointInterface,
    WebhookParserInterface,
    WebhookVerifierInterface,
)
from gnexus_gauth.dto import (
    AuthenticatedUser,
    AuthorizationRequest,
    TokenSet,
    VerifiedWebhook,
    WebhookEvent,
)
from gnexus_gauth.exceptions import PkceException, StateValidationException
from gnexus_gauth.oauth import AuthorizationUrlBuilder, PkceGenerator
from gnexus_gauth.support import SystemClock


class GAuthClient:
    """Main integration surface for consuming applications.

    The client sequences calls to the collaborators injected at construction
    time: the token endpoint, the runtime user provider, the webhook
    verifier/parser, and the state and PKCE stores.
    """

    def __init__(
        self,
        config: GAuthConfig,
        token_endpoint: TokenEndpointInterface,
        runtime_user_provider: RuntimeUserProviderInterface,
        webhook_verifier: WebhookVerifierInterface,
        webhook_parser: WebhookParserInterface,
        state_store: StateStoreInterface,
        pkce_store: PkceStoreInterface,
        clock: ClockInterface | None = None,
        authorization_url_builder: AuthorizationUrlBuilder | None = None,
    ) -> None:
        """Wire the client to its collaborators.

        Args:
            config: Client configuration; ``state_ttl_seconds`` is read when
                building authorization requests.
            token_endpoint: Performs code exchange, refresh and revocation.
            runtime_user_provider: Resolves an access token to a user.
            webhook_verifier: Verifies incoming webhook requests.
            webhook_parser: Parses webhook bodies into events.
            state_store: Stores per-flow state context until the callback.
            pkce_store: Stores the PKCE verifier for each pending state.
            clock: Time source for expiry calculation; a ``SystemClock`` is
                created when omitted.
            authorization_url_builder: Builds authorization URLs; an
                ``AuthorizationUrlBuilder(config)`` is created when omitted.
        """
        self._config = config
        self._token_endpoint = token_endpoint
        self._runtime_user_provider = runtime_user_provider
        self._webhook_verifier = webhook_verifier
        self._webhook_parser = webhook_parser
        self._state_store = state_store
        self._pkce_store = pkce_store
        self._clock = clock or SystemClock()
        self._authorization_url_builder = authorization_url_builder or AuthorizationUrlBuilder(config)

    def build_authorization_request(
        self,
        return_to: str | None = None,
        scopes: list[str] | None = None,
    ) -> AuthorizationRequest:
        """Start an authorization flow and describe the redirect to perform.

        Generates a fresh state value and a PKCE verifier/challenge pair,
        stores the state context and the verifier (both keyed by the state and
        sharing one expiry), and builds the authorization URL.

        Args:
            return_to: Optional value stored with the state and forwarded to
                the URL builder.
            scopes: Optional scopes to request; ``None`` or an empty list is
                recorded as ``[]``.

        Returns:
            An ``AuthorizationRequest`` carrying the URL, the state, the PKCE
            verifier and challenge, a copy of the scopes, and ``return_to``.
        """
        state = PkceGenerator.generate_state()
        verifier = PkceGenerator.generate_verifier()
        challenge = PkceGenerator.generate_challenge(verifier)
        expires_at = self._clock.now() + timedelta(seconds=self._config.state_ttl_seconds)

        # The scopes are copied so that later mutation of the caller's list
        # cannot alter what was stored for this state.
        self._state_store.put(
            state,
            expires_at,
            {"return_to": return_to, "scopes": list(scopes) if scopes else []},
        )
        self._pkce_store.put(state, verifier, expires_at)

        url = self._authorization_url_builder.build(
            state=state,
            pkce_challenge=challenge,
            return_to=return_to,
            scopes=scopes,
        )

        return AuthorizationRequest(
            authorization_url=url,
            state=state,
            pkce_verifier=verifier,
            pkce_challenge=challenge,
            scopes=list(scopes) if scopes else [],
            return_to=return_to,
        )

    def exchange_authorization_code(self, code: str, state: str) -> TokenSet:
        """Exchange an authorization callback's code for a token set.

        Once the state has been recognised it is treated as a one-time token:
        the state entry and its PKCE verifier are forgotten whether the
        exchange succeeds or fails, so a failed callback cannot be replayed
        with the same state.

        Args:
            code: Authorization code received on the callback.
            state: State value received on the callback.

        Returns:
            The token set returned by the token endpoint.

        Raises:
            StateValidationException: The state store does not have ``state``.
            PkceException: No PKCE verifier is stored for ``state``.

        Any exception raised by the token endpoint propagates unchanged.
        """
        if not self._state_store.has(state):
            raise StateValidationException("Unknown or expired authorization state.")

        try:
            verifier = self._pkce_store.get(state)
            if not verifier:
                raise PkceException("Missing PKCE verifier for authorization callback.")

            return self._token_endpoint.exchange_authorization_code(code, verifier)
        finally:
            # Consume the state on every outcome, not only on success;
            # otherwise a failed attempt leaves a replayable state behind.
            self._state_store.forget(state)
            self._pkce_store.forget(state)

    def refresh_token(self, refresh_token: str) -> TokenSet:
        """Delegate a refresh-token exchange to the token endpoint."""
        return self._token_endpoint.refresh_token(refresh_token)

    def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
        """Delegate token revocation to the token endpoint.

        Args:
            token: Token to revoke.
            token_type_hint: Optional hint passed through to the endpoint.
        """
        self._token_endpoint.revoke_token(token, token_type_hint)

    def fetch_user(self, access_token: str) -> AuthenticatedUser:
        """Return the user the runtime user provider resolves for the token."""
        return self._runtime_user_provider.fetch_user(access_token)

    def verify_webhook(self, raw_body: str, headers: dict, secret: str) -> VerifiedWebhook:
        """Return the webhook verifier's result for the given request data."""
        return self._webhook_verifier.verify(raw_body, headers, secret)

    def parse_webhook(self, raw_body: str) -> WebhookEvent:
        """Return the event the webhook parser produces from ``raw_body``."""
        return self._webhook_parser.parse(raw_body)

    def verify_and_parse_webhook(self, raw_body: str, headers: dict, secret: str) -> WebhookEvent:
        """Verify a webhook request, then parse its body into an event.

        Verification runs first; parsing only happens if it returns normally.
        """
        # NOTE(review): the verification result is discarded, so this relies
        # on the verifier raising on failure - confirm against the
        # WebhookVerifierInterface contract.
        self.verify_webhook(raw_body, headers, secret)
        return self.parse_webhook(raw_body)