"""Webhook components for gnexus-gauth."""
import hashlib
import hmac
import json
from datetime import datetime, timezone

from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.contracts import (
    ClockInterface,
    WebhookParserInterface,
    WebhookVerifierInterface,
)
from gnexus_gauth.dto import VerifiedWebhook, WebhookEvent
from gnexus_gauth.exceptions import (
    WebhookPayloadException,
    WebhookVerificationException,
)
from gnexus_gauth.support import SystemClock
class HmacWebhookVerifier(WebhookVerifierInterface):
    """HMAC-SHA256 webhook signature verifier.

    The ``x-gnexus-signature`` header is expected to look like
    ``t=<unix-seconds>,v1=<hex-digest>``, where the digest is
    ``HMAC-SHA256(secret, "<t>.<raw_body>")`` rendered as lowercase hex.
    """

    # Headers that must be present and non-empty (after normalization).
    _REQUIRED_HEADERS = (
        "x-gnexus-event-id",
        "x-gnexus-event-type",
        "x-gnexus-event-timestamp",
        "x-gnexus-signature",
    )

    def __init__(
        self,
        config: GAuthConfig,
        clock: ClockInterface | None = None,
    ) -> None:
        """Create a verifier.

        Args:
            config: Library configuration. ``webhook_tolerance_seconds`` is read
                on every ``verify`` call; a value ``<= 0`` disables the
                timestamp-freshness check.
            clock: Time source used for the freshness check and for
                ``verified_at``. Defaults to ``SystemClock()``.
        """
        self._config = config
        self._clock = clock or SystemClock()

    def verify(self, raw_body: str, headers: dict, secret: str) -> VerifiedWebhook:
        """Verify a webhook delivery and return its verified envelope.

        Args:
            raw_body: The request body exactly as received; it is signed verbatim.
            headers: Request headers. Names are lower-cased with ``_`` mapped to
                ``-``; list/tuple values contribute their first element.
            secret: Shared HMAC secret for this endpoint.

        Returns:
            A ``VerifiedWebhook`` holding the raw body, the normalized headers,
            the ``x-gnexus-event-id`` value and the time of verification.

        Raises:
            WebhookVerificationException: a required header is missing, the
                signature header is malformed, the signature does not match, or
                the timestamp is outside the configured tolerance window.
        """
        normalized = self._normalize_headers(headers)
        for required in self._REQUIRED_HEADERS:
            if not normalized.get(required):
                raise WebhookVerificationException(f"Missing webhook header: {required}.")

        signature = self._parse_signature_header(normalized["x-gnexus-signature"])
        timestamp = signature["timestamp"]
        provided = signature["hash"]
        expected = hmac.new(
            secret.encode(),
            f"{timestamp}.{raw_body}".encode(),
            hashlib.sha256,
        ).hexdigest()
        # compare_digest() raises TypeError for non-ASCII str arguments, and
        # `provided` comes from an untrusted header. A non-ASCII value can never
        # equal a hex digest, so reject it up front instead of crashing.
        if not (provided.isascii() and hmac.compare_digest(expected, provided)):
            raise WebhookVerificationException("Invalid webhook signature.")

        # Read the clock once so the freshness check and verified_at agree.
        now = self._clock.now()
        tolerance = self._config.webhook_tolerance_seconds
        if tolerance > 0 and abs(int(now.timestamp()) - timestamp) > tolerance:
            raise WebhookVerificationException("Webhook timestamp is outside the allowed tolerance window.")

        return VerifiedWebhook(
            raw_body=raw_body,
            normalized_headers=normalized,
            signature_id=normalized["x-gnexus-event-id"],
            verified_at=now,
        )

    @staticmethod
    def _normalize_headers(headers: dict) -> dict[str, str]:
        """Return headers keyed by lower-case, dash-separated names.

        Multi-valued entries (list or tuple) keep only their first element; an
        empty one becomes ``""``. All values are converted to stripped strings.
        """
        normalized: dict[str, str] = {}
        for name, value in headers.items():
            key = str(name).lower().replace("_", "-")
            if isinstance(value, (list, tuple)):
                value = value[0] if value else ""
            normalized[key] = str(value).strip()
        return normalized

    @staticmethod
    def _parse_signature_header(header: str) -> dict:
        """Parse ``t=<digits>,v1=<hex>`` into ``{"timestamp": int, "hash": str}``.

        Comma-separated chunks without ``=`` are ignored; for repeated keys the
        last occurrence wins. The hash is lower-cased.

        Raises:
            WebhookVerificationException: ``t`` or ``v1`` is missing, or ``t``
                is not made of ASCII decimal digits.
        """
        parts: dict[str, str] = {}
        for chunk in header.split(","):
            key, sep, value = chunk.strip().partition("=")
            if sep:
                parts[key] = value
        if "t" not in parts or "v1" not in parts:
            raise WebhookVerificationException("Malformed webhook signature header.")
        raw_ts = parts["t"]
        # str.isdigit() alone accepts characters like "²" that int() rejects;
        # restrict to plain ASCII 0-9 so parsing can never raise ValueError.
        if not (raw_ts.isascii() and raw_ts.isdigit()):
            raise WebhookVerificationException("Webhook timestamp must be numeric.")
        return {
            "timestamp": int(raw_ts),
            "hash": parts["v1"].lower(),
        }
class JsonWebhookParser(WebhookParserInterface):
    """JSON webhook payload parser.

    Converts a raw body into a ``WebhookEvent``. Recognized top-level keys:
    ``type`` (required non-empty string), ``id``, ``occurred_at`` (ISO-8601
    string), and ``target`` / ``actor`` / ``data`` (used only when they are
    JSON objects). The full decoded object is kept as ``raw_payload``.
    """

    def parse(self, raw_body: str) -> WebhookEvent:
        """Decode *raw_body* and build a ``WebhookEvent``.

        Args:
            raw_body: The webhook request body.

        Returns:
            The parsed event. ``event_id`` is ``None`` when ``id`` is absent or
            null, otherwise ``str(id)``. ``occurred_at`` is timezone-aware, with
            naive timestamps assumed UTC, or ``None`` when absent.

        Raises:
            WebhookPayloadException: the body is not valid JSON, is not a JSON
                object, lacks a non-empty string ``type``, or has an
                unparseable ``occurred_at``.
        """
        try:
            payload = json.loads(raw_body)
        except (ValueError, TypeError, RecursionError) as exc:
            # RecursionError: pathologically deep nesting in untrusted input.
            raise WebhookPayloadException("Webhook payload is not valid JSON.") from exc
        if not isinstance(payload, dict):
            raise WebhookPayloadException("Webhook payload must be a JSON object.")

        event_type = payload.get("type")
        if not isinstance(event_type, str) or not event_type:
            raise WebhookPayloadException("Webhook payload is missing event type.")

        # Treat an explicit JSON null like a missing id rather than "None".
        raw_id = payload.get("id")
        return WebhookEvent(
            event_id=None if raw_id is None else str(raw_id),
            event_type=event_type,
            occurred_at=self._parse_occurred_at(payload.get("occurred_at")),
            target_identifiers=self._object_or_empty(payload.get("target")),
            actor_identifiers=self._object_or_empty(payload.get("actor")),
            metadata=self._object_or_empty(payload.get("data")),
            raw_payload=payload,
        )

    @staticmethod
    def _parse_occurred_at(raw: object) -> datetime | None:
        """Parse an ISO-8601 ``occurred_at`` string into an aware datetime.

        Non-string or empty values yield ``None``. Naive results are assumed
        to be UTC.

        Raises:
            WebhookPayloadException: the string is not a valid ISO-8601 value.
        """
        if not isinstance(raw, str) or not raw:
            return None
        # datetime.fromisoformat() only accepts a trailing "Z" from Python 3.11
        # onward; normalize it so UTC timestamps also parse on 3.10.
        text = raw[:-1] + "+00:00" if raw[-1] in "Zz" else raw
        try:
            parsed = datetime.fromisoformat(text)
        except ValueError as exc:
            raise WebhookPayloadException("Webhook payload contains invalid occurred_at.") from exc
        if parsed.tzinfo is None:
            parsed = parsed.replace(tzinfo=timezone.utc)
        return parsed

    @staticmethod
    def _object_or_empty(value: object) -> dict:
        """Return *value* if it is a dict (JSON object), otherwise ``{}``."""
        return value if isinstance(value, dict) else {}