Newer
Older
navi-1 / navi / auth / deps.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 4 May 8 KB Add graceful auth-not-configured guards
"""FastAPI dependencies for auth resolution."""

import asyncio
from datetime import datetime, timezone
from typing import Annotated

import structlog
from fastapi import Depends, HTTPException, Request

from navi.config import settings

from . import User
from .client import get_gauth_client
from .encrypt import get_encryptor

# Module-level structlog logger; used below for token-refresh and user-fetch auth events.
log = structlog.get_logger()


async def _resolve_user(conn) -> User | None:
    """Resolve the authenticated user behind a Request or WebSocket.

    Shared by the REST and WebSocket dependencies. Reads the session id from
    the auth cookie, loads the stored (encrypted) tokens, refreshes the access
    token when it has expired, fetches the user from gnexus-auth, mirrors it
    into ``navi_users`` and caches the result on ``conn.state.user`` so later
    calls on the same connection return it directly.

    Args:
        conn: An object exposing ``.state`` and ``.cookies`` (a Starlette
            ``Request`` or ``WebSocket``), or ``None``.

    Returns:
        The resolved ``User``, or ``None`` for anonymous access: no
        connection, auth not configured, missing or unknown session cookie,
        failed token refresh (the session row is then deleted), or a failed
        user fetch. Errors from the session store or from decrypting the
        stored access token are not caught here and propagate.
    """
    if conn is None:
        return None

    # Return the user already resolved earlier on this connection.
    cached_user = getattr(conn.state, "user", None)
    if cached_user is not None:
        return cached_user

    # Auth not configured — treat as anonymous
    if not settings.gnexus_auth_client_id or not settings.gnexus_auth_client_secret:
        return None

    session_id = conn.cookies.get(settings.navi_auth_cookie_name)
    if not session_id:
        return None

    # Imported lazily to avoid a circular import during early bootstrap.
    # Only ImportError (what a circular import raises) means "not ready yet";
    # any other error raised while importing that module should surface.
    try:
        from navi.api.deps import get_session_store
    except ImportError:
        return None

    store = get_session_store()
    row = await _get_auth_session(store, session_id)
    if row is None:
        return None

    encryptor = get_encryptor()
    access_token = encryptor.decrypt(row["access_token_enc"])
    expires_at = row["expires_at"]
    if expires_at.tzinfo is None:
        # Comparing a naive datetime with the aware "now" below raises TypeError.
        # NOTE(review): a naive value is assumed to be UTC, matching the
        # UTC-aware timestamps this module writes — confirm against the column type.
        expires_at = expires_at.replace(tzinfo=timezone.utc)

    client = get_gauth_client()

    # Access token expired: refresh it; on any failure drop the session.
    if datetime.now(timezone.utc) > expires_at:
        try:
            refresh_token = encryptor.decrypt(row["refresh_token_enc"])
            # Run the client call in a worker thread so it does not block the event loop.
            token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
            access_token = token_set.access_token
            # Persist the new tokens; keep the previous refresh token if none was issued.
            await _update_auth_session(
                store,
                session_id,
                encryptor.encrypt(access_token),
                encryptor.encrypt(token_set.refresh_token or refresh_token),
                token_set.expires_at or datetime.now(timezone.utc),
            )
            log.info("auth.token_refreshed", user_id=row["user_id"])
        except Exception:
            # Deliberate best-effort boundary: any refresh/persist failure logs the
            # user out, but keep the traceback so the cause is diagnosable.
            log.warning("auth.refresh_failed", session_id=session_id[:8], exc_info=True)
            await _delete_auth_session(store, session_id)
            return None

    # Fetch user from gnexus-auth
    try:
        auth_user = await asyncio.to_thread(client.fetch_user, access_token)
    except Exception:
        log.warning("auth.fetch_user_failed", session_id=session_id[:8], exc_info=True)
        return None

    # Role and permissions come from the access entry for this client id;
    # the admin role is detected by its slug appearing in role_ids.
    role = "user"
    permissions: list[str] = []
    for access in auth_user.client_access_list:
        if access.client_id == settings.gnexus_auth_client_id:
            if settings.gnexus_auth_admin_role_slug in (access.role_ids or []):
                role = "admin"
            permissions = list(access.permission_ids or [])
            break

    profile_name = auth_user.profile.get("display_name")

    # Mirror the user into navi_users
    await _upsert_navi_user(store, auth_user.user_id, auth_user.email, profile_name, role, permissions)

    user = User(
        id=auth_user.user_id,
        email=auth_user.email,
        display_name=profile_name or auth_user.email,
        role=role,
        permissions=permissions,
    )

    # Update last_used_at, then cache for the rest of this connection.
    await _touch_auth_session(store, session_id)
    conn.state.user = user
    return user


async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency: the user behind a REST request's auth cookie, or ``None`` if anonymous."""
    resolved = await _resolve_user(request)
    return resolved


async def get_current_user_ws(websocket) -> User | None:
    """Return the user behind a WebSocket's auth cookie, or ``None`` if anonymous.

    Uses the same cookie-based resolution as the REST dependency.
    """
    resolved = await _resolve_user(websocket)
    return resolved


# ── Helpers that talk to the session store ──────────────────────────────────

async def _get_auth_session(store, session_id: str) -> dict | None:
    """Fetch a row from user_auth_sessions."""
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
            "FROM user_auth_sessions WHERE id = $1",
            session_id,
        )
    if row is None:
        return None
    return {
        "user_id": row["user_id"],
        "access_token_enc": row["access_token_enc"],
        "refresh_token_enc": row["refresh_token_enc"],
        "expires_at": row["expires_at"],
    }


async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Store newly refreshed (already encrypted) tokens and expiry for a session.

    Also stamps ``last_used_at`` with the current UTC time.
    """
    statement = (
        "UPDATE user_auth_sessions "
        "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
        "WHERE id = $5"
    )
    pool = await store._get_pool()
    async with pool.acquire() as db:
        used_at = datetime.now(timezone.utc)
        await db.execute(statement, access_token_enc, refresh_token_enc, expires_at, used_at, session_id)


async def _delete_auth_session(store, session_id: str) -> None:
    """Delete the ``user_auth_sessions`` row whose id is ``session_id`` (no-op if absent)."""
    statement = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        await db.execute(statement, session_id)


async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of the given session row to the current UTC time."""
    statement = "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        used_at = datetime.now(timezone.utc)
        await db.execute(statement, used_at, session_id)


async def _upsert_navi_user(
    store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str]
) -> None:
    pool = await store._get_pool()
    import json
    async with pool.acquire() as conn:
        await conn.execute(
            """INSERT INTO navi_users (id, email, display_name, role, permissions, created_at, updated_at)
               VALUES ($1, $2, $3, $4, $5, $6, $6)
               ON CONFLICT (id) DO UPDATE
               SET email = EXCLUDED.email,
                   display_name = EXCLUDED.display_name,
                   role = EXCLUDED.role,
                   permissions = EXCLUDED.permissions,
                   updated_at = EXCLUDED.updated_at""",
            user_id, email, display_name or email, role, json.dumps(permissions), datetime.now(timezone.utc),
        )


# ── FastAPI Depends helpers ─────────────────────────────────────────────────

async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that returns the signed-in user or rejects anonymous callers with HTTP 401."""
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")


async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that admits only admins.

    Raises HTTP 401 for anonymous callers and HTTP 403 when the user's role
    is not ``"admin"``; otherwise returns the user.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")


def require_permission(permission: str):
    """Build a FastAPI dependency that requires ``permission``.

    The returned coroutine function raises HTTP 401 for anonymous callers and
    HTTP 403 when a non-admin user's ``permissions`` does not contain
    ``permission``. Admins always pass. On success it returns the user.
    """
    async def _check(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        granted = user.role == "admin" or permission in user.permissions
        if not granted:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _check


def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise HTTP 403 unless ``user`` may access ``session``.

    Access is granted when any of the following holds:

    * ``user.role == "admin"`` (admins can access every session);
    * the session is owned by the user (``session.user_id == user.id``);
    * ``permission`` is non-empty and ``user.has_permission(permission)`` is true.

    Legacy sessions (``session.user_id is None``) are accessible only to
    admins; the ``permission`` grant does not apply to them.

    Raises:
        HTTPException: status 403 with detail ``"Access denied"`` when none of
            the rules above grants access.
    """
    # Admins bypass every other rule, including the legacy-session restriction.
    if user.role == "admin":
        return
    # Legacy sessions have no owner, so only admins (handled above) may use them.
    if session.user_id is None:
        raise HTTPException(status_code=403, detail="Access denied")
    if session.user_id == user.id:
        return
    if permission and user.has_permission(permission):
        return
    raise HTTPException(status_code=403, detail="Access denied")