"""OAuth components for gnexus-gauth."""
import base64
import hashlib
import secrets
from urllib.parse import quote, urlencode
import httpx
from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.contracts import TokenEndpointInterface
from gnexus_gauth.dto import TokenSet
from gnexus_gauth.exceptions import (
TokenExchangeException,
TokenRefreshException,
TokenRevokeException,
TransportException,
)
class PkceGenerator:
    """PKCE (RFC 7636) and OAuth ``state`` parameter generator.

    All randomness comes from :mod:`secrets`.
    """

    # RFC 7636 section 4.1: a code_verifier must be 43-128 characters long.
    _MIN_VERIFIER_LENGTH = 43
    _MAX_VERIFIER_LENGTH = 128

    @staticmethod
    def generate_verifier(length: int = 64) -> str:
        """Generate a PKCE code verifier.

        ``length`` is the number of random bytes handed to
        :func:`secrets.token_urlsafe`. The resulting token is base64url-encoded
        a second time, which is kept so the verifier length is unchanged from
        earlier releases (115 characters for the default of 64).

        Args:
            length: Number of random bytes. Values from 24 to 72 inclusive
                produce a verifier inside the RFC 7636 length bounds.

        Returns:
            An unpadded base64url string (characters ``A-Z a-z 0-9 - _``).

        Raises:
            ValueError: If the verifier would be shorter than 43 or longer
                than 128 characters.
        """
        token = secrets.token_urlsafe(length)
        verifier = base64.urlsafe_b64encode(token.encode()).decode().rstrip("=")
        # Out-of-range lengths would otherwise yield a verifier that an
        # RFC 7636 compliant authorization server rejects at token exchange.
        if not (
            PkceGenerator._MIN_VERIFIER_LENGTH
            <= len(verifier)
            <= PkceGenerator._MAX_VERIFIER_LENGTH
        ):
            raise ValueError(
                f"length={length} yields a {len(verifier)}-character code verifier; "
                "RFC 7636 requires 43-128 characters (use a length between 24 and 72)."
            )
        return verifier

    @staticmethod
    def generate_challenge(verifier: str) -> str:
        """Return the S256 code challenge for ``verifier``.

        The challenge is the unpadded base64url encoding of the SHA-256 digest
        of the verifier's UTF-8 bytes.
        """
        digest = hashlib.sha256(verifier.encode()).digest()
        return base64.urlsafe_b64encode(digest).decode().rstrip("=")

    @staticmethod
    def generate_state(length: int = 32) -> str:
        """Return a random URL-safe ``state`` value built from ``length`` random bytes."""
        return secrets.token_urlsafe(length)
class AuthorizationUrlBuilder:
    """Builds authorization-endpoint URLs for the authorization-code flow with PKCE."""

    def __init__(self, config: GAuthConfig) -> None:
        """Store the configuration.

        Args:
            config: Supplies ``client_id``, ``redirect_uri`` and ``authorize_url``.
        """
        self._config = config

    def build(
        self,
        state: str,
        pkce_challenge: str,
        return_to: str | None = None,
        scopes: list[str] | None = None,
    ) -> str:
        """Return the authorization URL with all request parameters attached.

        Args:
            state: Sent as the ``state`` parameter.
            pkce_challenge: Sent as ``code_challenge`` with method ``S256``.
            return_to: Sent as ``return_to``; omitted when empty or ``None``.
            scopes: Joined with spaces and sent as ``scope``; omitted when empty
                or ``None``.

        Returns:
            ``authorize_url`` followed by the percent-encoded parameters. Spaces
            are encoded as ``%20``. Any query component already present on
            ``authorize_url`` is preserved, and the new parameters are appended to it.
        """
        query = {
            "response_type": "code",
            "client_id": self._config.client_id,
            "redirect_uri": self._config.redirect_uri,
            "state": state,
            "code_challenge": pkce_challenge,
            "code_challenge_method": "S256",
        }
        if scopes:
            query["scope"] = " ".join(scopes)
        if return_to:
            query["return_to"] = return_to
        encoded = urlencode(query, safe="", quote_via=quote)

        base_url = str(self._config.authorize_url)
        # RFC 6749 section 3.1: an existing query component on the endpoint
        # must be retained, so start a new query only when there is none.
        if "?" not in base_url:
            separator = "?"
        elif base_url.endswith(("?", "&")):
            separator = ""
        else:
            separator = "&"
        return f"{base_url}{separator}{encoded}"
class HttpTokenEndpoint(TokenEndpointInterface):
    """HTTP client for the gnexus-auth token, refresh and revoke endpoints.

    Every request is an ``application/x-www-form-urlencoded`` POST that carries
    ``client_id`` and ``client_secret`` in the body.

    Errors are reported as follows:

    * Non-2xx responses and unusable bodies raise the operation's exception
      (``TokenExchangeException``, ``TokenRefreshException`` or
      ``TokenRevokeException``).
    * Requests that could not be completed raise ``TransportException``.

    When no ``http_client`` is passed, this instance creates and owns an
    :class:`httpx.Client`. Release it with :meth:`close` or by using the
    instance as a context manager. A caller-supplied client is never closed here.
    """

    def __init__(
        self,
        config: GAuthConfig,
        http_client: httpx.Client | None = None,
    ) -> None:
        """Create the endpoint client.

        Args:
            config: Supplies credentials, endpoint URLs and an optional user agent.
            http_client: Optional client to reuse. It remains owned by the caller.
        """
        self._config = config
        self._http = httpx.Client() if http_client is None else http_client
        # Only a client created here is ours to close.
        self._own_client = http_client is None

    def close(self) -> None:
        """Close the internally created HTTP client. Safe to call repeatedly."""
        if getattr(self, "_own_client", False) and hasattr(self, "_http"):
            self._http.close()
            self._own_client = False

    def __enter__(self) -> "HttpTokenEndpoint":
        return self

    def __exit__(self, exc_type: object, exc: object, tb: object) -> None:
        self.close()

    def __del__(self) -> None:
        # Best-effort fallback for callers that never call close(). A finalizer
        # can run during interpreter shutdown, so errors are deliberately ignored.
        try:
            self.close()
        except Exception:
            pass

    def exchange_authorization_code(self, code: str, pkce_verifier: str) -> TokenSet:
        """Exchange an authorization code and its PKCE verifier for tokens.

        Raises:
            TokenExchangeException: On an error response or an unusable token payload.
            TransportException: If the request could not be completed.
        """
        payload = {
            "grant_type": "authorization_code",
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "redirect_uri": self._config.redirect_uri,
            "code": code,
            "code_verifier": pkce_verifier,
        }
        data = self._send_form_request(self._config.token_url, payload, TokenExchangeException)
        return self._map_token_set(data, TokenExchangeException)

    def refresh_token(self, refresh_token: str) -> TokenSet:
        """Obtain a new token set using a refresh token.

        Raises:
            TokenRefreshException: On an error response or an unusable token payload.
            TransportException: If the request could not be completed.
        """
        payload = {
            "grant_type": "refresh_token",
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "refresh_token": refresh_token,
        }
        data = self._send_form_request(self._config.refresh_url, payload, TokenRefreshException)
        return self._map_token_set(data, TokenRefreshException)

    def revoke_token(self, token: str, token_type_hint: str | None = None) -> None:
        """Revoke ``token``. The response body is not inspected on success.

        Raises:
            TokenRevokeException: If the endpoint returns a non-2xx status.
            TransportException: If the request could not be completed.
        """
        payload = {
            "client_id": self._config.client_id,
            "client_secret": self._config.client_secret,
            "token": token,
        }
        if token_type_hint:
            payload["token_type_hint"] = token_type_hint
        self._send_form_request(self._config.revoke_url, payload, TokenRevokeException, expect_json=False)

    def _send_form_request(
        self,
        url: str,
        payload: dict,
        exception_class: type[Exception],
        *,
        expect_json: bool = True,
    ) -> dict:
        """POST ``payload`` as a form and return the decoded JSON object.

        Args:
            url: Endpoint to call.
            payload: Form fields.
            exception_class: Raised for non-2xx responses and malformed bodies.
            expect_json: When False, return ``{}`` without reading the body.

        Raises:
            TransportException: If httpx fails to complete the request.
            exception_class: On a non-2xx status, or a body that is not a JSON object.
        """
        headers = {
            "Content-Type": "application/x-www-form-urlencoded",
            "Accept": "application/json",
        }
        if self._config.user_agent:
            headers["User-Agent"] = self._config.user_agent
        try:
            response = self._http.post(url, data=payload, headers=headers)
        except httpx.RequestError as exc:
            raise TransportException("Request to gnexus-auth failed.") from exc
        # Every non-2xx status is a failure. This includes an unfollowed 3xx,
        # which would otherwise let revoke_token report success.
        if not 200 <= response.status_code < 300:
            message = self._extract_error_message(response.text) or "gnexus-auth returned an error response."
            raise exception_class(message)
        if not expect_json:
            return {}
        try:
            data = response.json()
        except ValueError as exc:
            raise exception_class("gnexus-auth returned malformed JSON.") from exc
        if not isinstance(data, dict):
            raise exception_class("gnexus-auth returned malformed JSON.")
        return data

    @staticmethod
    def _map_token_set(data: dict, exception_class: type[Exception] = ValueError) -> TokenSet:
        """Convert a decoded token-endpoint response into a :class:`TokenSet`.

        How each field is mapped:

        * ``expires_in``: missing or null becomes 0. ``expires_at`` is set only
          when the value is positive.
        * ``refresh_token``: null or empty becomes ``None``.
        * ``token_type``: null or empty defaults to ``"Bearer"``.

        Raises:
            exception_class: If ``access_token`` is missing or empty, or a
                lifetime field cannot be converted to an integer.
        """
        from datetime import datetime, timedelta, timezone

        def to_int(key: str) -> int | None:
            # Null/missing means "not provided"; anything else must be int-like.
            value = data.get(key)
            if value is None:
                return None
            try:
                return int(value)
            except (TypeError, ValueError) as exc:
                raise exception_class(f"gnexus-auth returned an invalid {key} value.") from exc

        access_token = data.get("access_token")
        if access_token is None or access_token == "":
            raise exception_class("gnexus-auth response is missing access_token.")

        expires_in = to_int("expires_in") or 0
        refresh_expires_in = to_int("refresh_expires_in")

        scope = data.get("scope")
        scopes = scope.split() if isinstance(scope, str) else []

        expires_at = None
        if expires_in > 0:
            expires_at = datetime.now(timezone.utc) + timedelta(seconds=expires_in)

        raw_refresh = data.get("refresh_token")
        # A JSON null must not become the string "None".
        refresh_token = None if raw_refresh is None or raw_refresh == "" else str(raw_refresh)

        return TokenSet(
            access_token=str(access_token),
            refresh_token=refresh_token,
            token_type=str(data.get("token_type") or "Bearer"),
            expires_in=expires_in,
            expires_at=expires_at,
            refresh_expires_in=refresh_expires_in,
            scopes=scopes,
            raw_payload=data,
        )

    @staticmethod
    def _extract_error_message(text: str) -> str | None:
        """Return a message from an error body's ``error_description`` or ``error``.

        The first non-empty string among those two fields is used. Returns
        ``None`` when the body is not a JSON object or neither field qualifies.
        """
        import json

        try:
            decoded = json.loads(text)
        except ValueError:
            return None
        if not isinstance(decoded, dict):
            return None
        for key in ("error_description", "error"):
            value = decoded.get(key)
            if isinstance(value, str) and value:
                return value
        return None