Newer
Older
gnexus-creds / gnexus_creds / oauth.py
"""gnexus-auth OAuth/session routes."""

from datetime import UTC, datetime, timedelta

from fastapi import APIRouter, Depends, Request, Response
from fastapi.concurrency import run_in_threadpool
from fastapi.responses import RedirectResponse
from gnexus_gauth.config import GAuthConfig
from gnexus_gauth.oauth import AuthorizationUrlBuilder, HttpTokenEndpoint, PkceGenerator
from gnexus_gauth.runtime import HttpRuntimeUserProvider
from gnexus_gauth.webhook import HmacWebhookVerifier, JsonWebhookParser
from sqlalchemy.orm import Session

from gnexus_creds.config import get_settings
from gnexus_creds.db import get_db
from gnexus_creds.errors import AppError
from gnexus_creds.models import OAuthState, User, utcnow
from gnexus_creds.services import (
    check_rate_limit,
    create_session,
    delete_session,
    is_expired,
    upsert_user_from_auth,
)

# OAuth login/callback/logout endpoints, all served under the /auth path prefix.
router = APIRouter(tags=["auth"], prefix="/auth")
# Incoming webhook receivers (gnexus-auth events), served under /webhooks.
webhook_router = APIRouter(tags=["webhooks"], prefix="/webhooks")


def _config() -> GAuthConfig:
    """Build the gnexus-auth client configuration from the application settings.

    Reads the auth base URL, client credentials and redirect URI from
    ``get_settings()`` on every call; nothing is cached here.
    """
    app_settings = get_settings()
    base_url = app_settings.auth_base_url
    client_id = app_settings.auth_client_id
    client_secret = app_settings.auth_client_secret
    redirect_uri = app_settings.auth_redirect_uri
    return GAuthConfig(
        base_url=base_url,
        client_id=client_id,
        client_secret=client_secret,
        redirect_uri=redirect_uri,
    )


def _validate_return_to(value: str) -> str:
    """Return *value* if it is a safe same-site relative path, otherwise ``"/"``.

    Used to stop open redirects through the user-supplied ``return_to``.

    A value is accepted only when all of the following hold:

    * It starts with a single ``/``. ``//host`` is a scheme-relative URL that
      points at another origin.
    * It contains no backslash. Browsers treat ``\\`` like ``/`` in http(s)
      URLs, so ``/\\host`` behaves like ``//host``.
    * It contains no ASCII control characters (``U+0000``-``U+001F``,
      ``U+007F``). The WHATWG URL parser silently removes tab and newline, so
      ``"/\\t/evil.example"`` would otherwise turn into ``//evil.example``.
    """
    if not value.startswith("/") or value.startswith("//") or "\\" in value:
        return "/"
    if any(ord(ch) < 0x20 or ord(ch) == 0x7F for ch in value):
        return "/"
    return value


def _client_ip(request: Request) -> str:
    """Return the peer host of *request*, or ``"anonymous"`` if it is unavailable.

    The result is used as the per-client component of rate-limit keys.
    """
    peer = request.client
    if not peer:
        return "anonymous"
    return peer.host


@router.get("/login")
async def login(
    request: Request, return_to: str = "/", db: Session = Depends(get_db)
) -> RedirectResponse:
    """Start the gnexus-auth authorization-code flow with PKCE.

    Steps:

    1. Rate-limit the caller (20 requests / 60 s per client IP).
    2. Sanitise ``return_to``.
    3. Persist a new ``OAuthState`` holding the state token, PKCE verifier,
       ``return_to`` and requested scopes, valid for 10 minutes.
    4. Redirect the browser to the gnexus-auth authorization URL.
    """
    check_rate_limit(
        db, key=f"oauth_login:{_client_ip(request)}", max_requests=20, window_seconds=60
    )
    safe_return_to = _validate_return_to(return_to)
    requested_scopes = ("openid", "email", "profile")
    auth_config = _config()

    state_token = PkceGenerator.generate_state()
    pkce_verifier = PkceGenerator.generate_verifier()
    pkce_challenge = PkceGenerator.generate_challenge(pkce_verifier)

    # The verifier stays server-side; only its challenge goes to the auth server.
    pending_state = OAuthState(
        state=state_token,
        pkce_verifier=pkce_verifier,
        return_to=safe_return_to,
        scopes=list(requested_scopes),
        expires_at=utcnow() + timedelta(minutes=10),
    )
    db.add(pending_state)
    db.commit()

    authorize_url = AuthorizationUrlBuilder(auth_config).build(
        state=state_token,
        pkce_challenge=pkce_challenge,
        return_to=safe_return_to,
        scopes=list(requested_scopes),
    )
    return RedirectResponse(authorize_url)


@router.get("/callback")
async def callback(
    request: Request, code: str, state: str, db: Session = Depends(get_db)
) -> RedirectResponse:
    """Finish the gnexus-auth authorization-code flow and open a local session.

    Steps:

    1. Look up the ``OAuthState`` that ``login`` stored for ``state`` and
       consume it, so it can only be used once.
    2. Exchange ``code`` for tokens using the stored PKCE verifier.
    3. Fetch the user from gnexus-auth and mirror it into the local ``User``
       table.
    4. Set the session cookie and redirect to the stored ``return_to``.

    Raises:
        AppError: ``invalid_oauth_state`` (400) if the state is unknown or
            expired; ``user_disabled`` (403) if the upstream account status is
            disabled/blocked/deleted.
    """
    check_rate_limit(
        db, key=f"oauth_callback:{_client_ip(request)}", max_requests=20, window_seconds=60
    )
    settings = get_settings()
    saved = db.get(OAuthState, state)
    if saved is None or is_expired(saved.expires_at, now=datetime.now(UTC)):
        if saved is not None:
            # Expired rows would otherwise only ever be removed by a successful callback.
            db.delete(saved)
            db.commit()
        raise AppError("invalid_oauth_state", "Invalid or expired OAuth state.", status_code=400)

    # Copy what we need, then consume the state before any network call.
    # This makes it strictly single-use even if the exchange or the user
    # checks below fail (the user then has to restart the login).
    # Values are read before commit() because the deleted row cannot be
    # reloaded afterwards.
    pkce_verifier = saved.pkce_verifier
    return_to = _validate_return_to(saved.return_to or "/")
    db.delete(saved)
    db.commit()

    config = _config()
    # These gnexus-auth client calls are synchronous (they are not awaited).
    # Run them in the threadpool so this async endpoint does not block the
    # event loop while waiting on the auth server.
    token_set = await run_in_threadpool(
        HttpTokenEndpoint(config).exchange_authorization_code, code, pkce_verifier
    )
    auth_user = await run_in_threadpool(
        HttpRuntimeUserProvider(config).fetch_user, token_set.access_token
    )

    profile = auth_user.profile or {}
    user = upsert_user_from_auth(
        db,
        auth_subject=auth_user.user_id,
        email=auth_user.email,
        display_name=profile.get("display_name") or profile.get("name") or auth_user.email,
        profile=profile,
        status="disabled" if auth_user.status in {"disabled", "blocked", "deleted"} else "enabled",
        system_role="admin" if auth_user.system_role == "admin" else "user",
        locale=profile.get("locale"),
    )
    if user.status == "disabled":
        # Persist the refreshed (disabled) account before rejecting the login,
        # in case upsert_user_from_auth has not committed it itself.
        db.commit()
        raise AppError("user_disabled", "User is disabled.", status_code=403)

    session = create_session(db, user, settings.session_ttl_seconds)
    db.commit()
    response = RedirectResponse(return_to)
    response.set_cookie(
        settings.session_cookie_name,
        session.id,
        httponly=True,
        samesite="lax",
        max_age=settings.session_ttl_seconds,
        secure=settings.is_production,
    )
    return response


@router.post("/logout")
async def logout(request: Request, db: Session = Depends(get_db)) -> Response:
    """End the caller's local session and clear the session cookie.

    Passes the session cookie value to ``delete_session``. The value is
    ``None`` when the cookie is missing. The call is followed by a commit.
    The response body is ``{"status":"ok"}``, with headers that delete the
    cookie using the same flags it was set with.
    """
    settings = get_settings()
    cookie_name = settings.session_cookie_name
    session_id = request.cookies.get(cookie_name)
    delete_session(db, session_id)
    db.commit()

    body = '{"status":"ok"}'
    response = Response(content=body, media_type="application/json")
    response.delete_cookie(
        cookie_name,
        httponly=True,
        samesite="lax",
        secure=settings.is_production,
    )
    return response


def _apply_webhook_metadata(user: User, metadata: dict) -> None:
    """Copy profile and status changes from a webhook event's metadata onto *user*.

    * A ``profile`` dict replaces ``user.profile``. It also refreshes
      ``display_name`` (from ``display_name``, then ``name``) and ``locale``,
      keeping the current values when the profile lacks them.
    * A ``status`` of ``disabled``/``blocked``/``deleted`` disables the user;
      ``enabled`` re-enables it. Any other status is ignored.

    Does not commit; the caller owns the transaction.
    """
    profile = metadata.get("profile")
    if isinstance(profile, dict):
        user.profile = profile
        user.display_name = profile.get("display_name") or profile.get("name") or user.display_name
        user.locale = profile.get("locale") or user.locale
    status = metadata.get("status")
    if status in {"disabled", "blocked", "deleted"}:
        # NOTE(review): existing sessions are not revoked here; confirm that
        # session lookup rejects users whose status is "disabled".
        user.status = "disabled"
    elif status == "enabled":
        user.status = "enabled"


@webhook_router.post("/gnexus-auth")
async def gnexus_auth_webhook(request: Request, db: Session = Depends(get_db)) -> dict[str, str]:
    """Apply a signed gnexus-auth user event to the matching local user.

    Processing order:

    1. Verify the raw body with ``HmacWebhookVerifier`` using
       ``settings.auth_webhook_secret``, then parse it.
    2. Resolve the subject from ``sub`` / ``user_id`` in the event's target
       identifiers, falling back to ``sub`` in its metadata.
    3. Events without a subject or for an unknown user change nothing.

    Returns:
        ``{"status": "ok"}`` whether or not a user was updated.

    Raises:
        AppError: ``invalid_webhook_payload`` (400) if the body is not valid
            UTF-8.
    """
    settings = get_settings()
    body = await request.body()
    try:
        raw = body.decode("utf-8")
    except UnicodeDecodeError as exc:
        # Previously an unhandled 500 reachable without a valid signature;
        # a malformed body is a client error.
        raise AppError(
            "invalid_webhook_payload", "Webhook body is not valid UTF-8.", status_code=400
        ) from exc
    headers = dict(request.headers)
    HmacWebhookVerifier(_config()).verify(raw, headers, settings.auth_webhook_secret)
    event = JsonWebhookParser().parse(raw)

    # Tolerate events that omit either mapping entirely.
    targets = event.target_identifiers or {}
    metadata = event.metadata or {}
    subject = targets.get("sub") or targets.get("user_id") or metadata.get("sub")
    if not subject:
        return {"status": "ok"}

    user = db.query(User).filter(User.auth_subject == str(subject)).one_or_none()
    if user is None:
        return {"status": "ok"}

    _apply_webhook_metadata(user, metadata)
    db.commit()
    return {"status": "ok"}