# Newer
# Older
# navi-1 / navi / auth / deps.py
# @Eugene Sukhodolskiy Eugene Sukhodolskiy 2 days ago 19 KB Add NAVI_AUTH_ENABLED switch for optional auth
"""FastAPI dependencies for auth resolution."""

import asyncio
import hashlib
import time
from datetime import datetime, timezone
from typing import Annotated

import structlog
from fastapi import Depends, HTTPException, Request
from gnexus_gauth.exceptions import TokenRefreshException

from navi.config import settings

from . import ApiToken, User
from .client import get_gauth_client
from .encrypt import get_encryptor

# Module-level structured logger used by every helper in this file.
log = structlog.get_logger()

# Anonymous admin identity used only when NAVI_AUTH_ENABLED=false.
# The role is "admin" with an empty permission list; the role checks in this
# module (require_admin, require_permission, check_session_access) let an
# admin through without consulting `permissions`.
# NOTE(review): this single instance is returned to every caller while auth is
# disabled, so it should be treated as read-only.
_ANONYMOUS_USER = User(
    id="anonymous",
    email="anonymous@navi.local",
    display_name="Anonymous",
    username=None,
    first_name=None,
    last_name=None,
    phone=None,
    birth_date=None,
    country=None,
    city=None,
    locale=None,
    avatar_url=None,
    role="admin",
    permissions=[],
)

# Per-session-id lock to prevent concurrent token refresh races.
# Entries are created on demand by _resolve_user and pruned by
# _cleanup_refresh_locks.
_refresh_locks: dict[str, asyncio.Lock] = {}

# In-memory cache for fetch_user results: session_id -> (User, timestamp).
# The timestamp is the time.monotonic() reading taken when the entry was stored.
# Reduces hammering gnexus-auth and protects against eventual-consistency
# flakes right after a token refresh.
_USER_CACHE_TTL: float = 30.0  # seconds a cache entry stays valid
_user_cache: dict[str, tuple[User, float]] = {}


def _cleanup_refresh_locks() -> None:
    """Remove stale refresh locks to prevent unbounded memory growth.

    A lock is stale when its session has no entry in ``_user_cache`` or that
    entry is older than ``_USER_CACHE_TTL`` seconds.  Locks that are currently
    held are always kept: dropping one would let a concurrent request create a
    fresh lock for the same session and run a second token refresh in
    parallel.
    """
    # Cache timestamps are time.monotonic() readings, so the age check must use
    # the same clock.  Comparing against time.time() made every entry look
    # expired and wiped every lock on each call.
    now = time.monotonic()
    stale = []
    for sid, lock in _refresh_locks.items():
        if lock.locked():
            # A refresh is in flight for this session; keep its lock.
            continue
        entry = _user_cache.get(sid)
        if entry is None or now - entry[1] > _USER_CACHE_TTL:
            stale.append(sid)
    # Remove after the scan so the dict is not mutated while iterating.
    for sid in stale:
        _refresh_locks.pop(sid, None)


def _get_cached_user(session_id: str) -> User | None:
    """Look up *session_id* in the short-lived user cache.

    Returns the cached ``User`` while its entry is no older than
    ``_USER_CACHE_TTL`` seconds on the monotonic clock.  On a miss, or when
    the entry has expired (in which case it is evicted), stale refresh locks
    are pruned and ``None`` is returned.
    """
    cached = _user_cache.get(session_id)
    if cached is not None:
        user, stored_at = cached
        age = time.monotonic() - stored_at
        if age <= _USER_CACHE_TTL:
            return user
        # Entry outlived its TTL: evict it before pruning the locks.
        _user_cache.pop(session_id, None)
    _cleanup_refresh_locks()
    return None


def _set_cached_user(session_id: str, user: User) -> None:
    """Cache *user* under *session_id*, stamped with the current monotonic time."""
    stored_at = time.monotonic()
    _user_cache[session_id] = (user, stored_at)


async def _fetch_user_with_retry(client, access_token: str, session_id: str) -> "User | None":
    """Call ``client.fetch_user`` in a worker thread, retrying once on failure.

    If the first attempt raises, the coroutine waits 1.5 s and tries once
    more.  If the second attempt also raises, ``None`` is returned.  Both
    failures are logged together with the exception type and message so the
    cause can be diagnosed from the logs.

    Args:
        client: Object exposing a blocking ``fetch_user(access_token)`` method.
        access_token: Token handed to ``fetch_user``.
        session_id: Auth session id; only its first 8 characters are logged.

    Returns:
        Whatever ``fetch_user`` returns, or ``None`` after two failed attempts.
    """
    for attempt in range(2):
        try:
            return await asyncio.to_thread(client.fetch_user, access_token)
        except Exception as exc:
            # Same diagnostic fields as the auth.refresh_transient_fail event.
            failure = {"exc_type": type(exc).__name__, "error": str(exc)}
            if attempt == 0:
                log.debug(
                    "auth.fetch_user_retry",
                    session_id=session_id[:8],
                    attempt=attempt + 1,
                    **failure,
                )
                await asyncio.sleep(1.5)
            else:
                log.warning(
                    "auth.fetch_user_failed",
                    session_id=session_id[:8],
                    **failure,
                )
    return None


async def _api_token_fallback(conn) -> User | None:
    """Resolve the caller from an API token and remember a hit on ``conn.state``.

    Returns whatever ``_resolve_user_from_api_token`` returns (possibly
    ``None``); ``conn.state.user`` is only set when a user was found.
    """
    api_user = await _resolve_user_from_api_token(conn)
    if api_user is not None:
        conn.state.user = api_user
    return api_user


def _role_and_permissions(auth_user) -> tuple[str, list[str]]:
    """Derive the Navi role and permission list from ``auth_user.client_access_list``.

    Only the first entry whose ``client_id`` equals
    ``settings.gnauth_client_id`` is considered.  Without such an entry the
    result is ``("user", [])``.
    """
    for access in auth_user.client_access_list:
        if access.client_id != settings.gnauth_client_id:
            continue
        is_admin = settings.gnauth_admin_role_slug in (access.role_ids or [])
        return ("admin" if is_admin else "user"), list(access.permission_ids or [])
    return "user", []


async def _resolve_user(conn) -> User | None:
    """Resolve the user for a connection object (Request or WebSocket).

    Resolution order:

    1. Auth disabled (``settings.navi_auth_enabled`` is false): the anonymous
       admin is returned.
    2. ``conn`` is ``None``: no user.
    3. A user already stored on ``conn.state.user`` is returned as-is.
    4. gnexus-auth client id/secret not configured: no user.
    5. No session cookie: fall back to API-token resolution.
    6. Session cookie present: load the auth session row, refresh the access
       token under a per-session lock if it has expired, then serve the user
       from the short-lived cache or fetch it from gnexus-auth, upsert it into
       ``navi_users`` and cache it.

    Returns:
        The resolved ``User``, or ``None`` when the caller is anonymous.
        Any unexpected exception is logged as ``auth.resolve_failed`` and
        mapped to ``None`` so WebSocket upgrades and other dependency-injected
        paths do not crash.
    """
    try:
        log.debug("auth.resolve_start", conn_type=type(conn).__name__)

        # Auth explicitly disabled: every request is the anonymous admin.
        # This short-circuits before any cookie/API-token/OAuth lookup.
        if not settings.navi_auth_enabled:
            log.debug("auth.resolve_disabled")
            return _ANONYMOUS_USER

        if conn is None:
            log.debug("auth.resolve_no_conn")
            return None

        # Return the user already resolved earlier for this connection.
        if hasattr(conn.state, "user") and conn.state.user is not None:
            log.debug("auth.resolve_cached", user_id=conn.state.user.id)
            return conn.state.user

        # Auth not configured — treat as anonymous
        if not settings.gnauth_client_id or not settings.gnauth_client_secret:
            log.debug("auth.resolve_not_configured")
            return None

        cookie_name = settings.navi_auth_cookie_name
        session_id = conn.cookies.get(cookie_name)
        if not session_id:
            log.debug("auth.resolve_no_cookie", cookie_name=cookie_name)
            # Try API token fallback before giving up
            return await _api_token_fallback(conn)

        log.debug("auth.resolve_cookie_found", session_id=session_id[:8])

        # Look up the auth session in DB
        try:
            from navi.api.deps import get_session_store
        except Exception:
            # Avoid circular import during early bootstrap
            log.debug("auth.resolve_store_import_failed")
            return None

        store = get_session_store()
        row = await _get_auth_session(store, session_id)
        if row is None:
            # NOTE(review): unlike the refresh-failure paths below, an unknown
            # session cookie does not fall back to the API token — confirm
            # this is intended.
            log.debug("auth.resolve_session_not_found", session_id=session_id[:8])
            return None

        log.debug("auth.resolve_session_found", user_id=row["user_id"])

        encryptor = get_encryptor()
        access_token = encryptor.decrypt(row["access_token_enc"])
        expires_at = row["expires_at"]

        client = get_gauth_client()

        # Refresh if expired
        if datetime.now(timezone.utc) > expires_at:
            lock = _refresh_locks.setdefault(session_id, asyncio.Lock())
            async with lock:
                # Re-read session inside lock — another request may have refreshed it already
                row = await _get_auth_session(store, session_id)
                if row is None:
                    log.debug("auth.resolve_session_gone_during_refresh", session_id=session_id[:8])
                    return None
                expires_at = row["expires_at"]
                if datetime.now(timezone.utc) <= expires_at:
                    access_token = encryptor.decrypt(row["access_token_enc"])
                else:
                    try:
                        refresh_token = encryptor.decrypt(row["refresh_token_enc"])
                        token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
                        access_token = token_set.access_token
                        # Keep the old refresh token when the server did not
                        # issue a new one.
                        await _update_auth_session(
                            store,
                            session_id,
                            encryptor.encrypt(access_token),
                            encryptor.encrypt(token_set.refresh_token or refresh_token),
                            token_set.expires_at or datetime.now(timezone.utc),
                        )
                        log.info("auth.token_refreshed", user_id=row["user_id"])
                    except TokenRefreshException as exc:
                        # Refresh token is definitively invalid (expired, revoked,
                        # rotated by another device).  Force re-login.
                        log.warning(
                            "auth.refresh_token_invalid",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            reason=str(exc),
                        )
                        # Try API token before giving up.
                        return await _api_token_fallback(conn)
                    except Exception as exc:
                        # Transient errors (network, timeout, 5xx from auth server).
                        # Do NOT delete the session — force-logout is too harsh.
                        # Best-effort: serve from cache if available, else try API token.
                        log.warning(
                            "auth.refresh_transient_fail",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            exc_type=type(exc).__name__,
                            error=str(exc),
                        )
                        cached = _get_cached_user(session_id)
                        if cached is not None:
                            log.debug("auth.fallback_to_cached_user", user_id=cached.id)
                            conn.state.user = cached
                            return cached
                        return await _api_token_fallback(conn)

        # Check short-lived fetch_user cache before hitting gnexus-auth
        cached = _get_cached_user(session_id)
        if cached is not None:
            log.debug("auth.resolve_user_cache_hit", user_id=cached.id)
            conn.state.user = cached
            return cached

        # Fetch user from gnexus-auth (with one retry on transient failure)
        auth_user = await _fetch_user_with_retry(client, access_token, session_id)
        if auth_user is None:
            return None

        log.debug("auth.resolve_user_fetched", user_id=auth_user.user_id)

        # Determine role from client-level role_ids
        role, permissions = _role_and_permissions(auth_user)

        # Upsert into navi_users
        profile = auth_user.profile
        await _upsert_navi_user(
            store, auth_user.user_id, auth_user.email,
            profile.get("display_name"), role, permissions,
            username=profile.get("username"),
            first_name=profile.get("first_name"),
            last_name=profile.get("last_name"),
            phone=profile.get("phone"),
            birth_date=profile.get("birth_date"),
            country=profile.get("country"),
            city=profile.get("city"),
            locale=profile.get("locale"),
        )

        user = User(
            id=auth_user.user_id,
            email=auth_user.email,
            display_name=profile.get("display_name") or auth_user.email,
            username=profile.get("username"),
            first_name=profile.get("first_name"),
            last_name=profile.get("last_name"),
            phone=profile.get("phone"),
            birth_date=profile.get("birth_date"),
            country=profile.get("country"),
            city=profile.get("city"),
            locale=profile.get("locale"),
            avatar_url=auth_user.avatar_url,
            role=role,
            permissions=permissions,
        )

        # Update last_used_at and cache
        await _touch_auth_session(store, session_id)
        _set_cached_user(session_id, user)
        conn.state.user = user
        log.debug("auth.resolve_success", user_id=user.id)
        return user
    except Exception:
        # Any unexpected failure during auth resolution should not crash the
        # request — treat as anonymous so WebSocket upgrades and REST calls
        # degrade gracefully.
        log.warning("auth.resolve_failed", exc_info=True)
        return None


# ── API token resolution helpers ────────────────────────────────────────────

async def _resolve_user_from_api_token(conn) -> User | None:
    """Resolve user from X-Api-Token header or ?api_token query parameter.

    The presented token is SHA-256 hashed and looked up in ``api_tokens``
    joined with ``navi_users``.  On success the token's ``last_used_at`` is
    updated on a best-effort basis: a failure there is logged at debug level
    and does not affect the result.

    Returns:
        The ``User`` owning the token, or ``None`` when no token was supplied,
        the session store cannot be imported, the token is unknown, or the
        token has been revoked.
    """
    import json

    # Try header first, then query param (for WebSocket or REST).  Either
    # lookup is allowed to fail so that the other source is still tried.
    try:
        token = conn.headers.get("X-Api-Token")
    except Exception:
        token = None
    if not token:
        try:
            token = conn.query_params.get("api_token")
        except Exception:
            token = None
    if not token:
        return None

    try:
        from navi.api.deps import get_session_store
    except Exception:
        return None

    store = get_session_store()
    pool = await store._get_pool()

    # Only the hash is compared against the database.
    token_hash = hashlib.sha256(token.encode()).hexdigest()
    async with pool.acquire() as db_conn:
        user_row = await db_conn.fetchrow(
            "SELECT u.id, u.email, u.display_name, u.username, u.first_name, u.last_name, "
            "u.phone, u.birth_date, u.country, u.city, u.locale, u.role, u.permissions, "
            "t.id AS token_id, t.revoked_at "
            "FROM api_tokens t JOIN navi_users u ON u.id = t.user_id "
            "WHERE t.token_hash = $1",
            token_hash,
        )
    if user_row is None:
        log.debug("auth.api_token_not_found")
        return None
    if user_row["revoked_at"] is not None:
        log.debug("auth.api_token_revoked", token_id=user_row["token_id"])
        return None

    # permissions is read as JSON text; a NULL/empty value means no permissions.
    permissions = json.loads(user_row["permissions"] or "[]")
    user = User(
        id=user_row["id"],
        email=user_row["email"],
        display_name=user_row["display_name"] or user_row["email"],
        username=user_row["username"],
        first_name=user_row["first_name"],
        last_name=user_row["last_name"],
        phone=user_row["phone"],
        birth_date=user_row["birth_date"],
        country=user_row["country"],
        city=user_row["city"],
        locale=user_row["locale"],
        # The query above selects no avatar column, so state that explicitly,
        # as the other User constructions in this module do.
        avatar_url=None,
        role=user_row["role"],
        permissions=permissions,
    )
    # Update last_used_at (best-effort): the write is awaited, but a failure
    # must not reject an otherwise valid token.
    try:
        async with pool.acquire() as db_conn:
            await db_conn.execute(
                "UPDATE api_tokens SET last_used_at = $1 WHERE id = $2",
                datetime.now(timezone.utc), user_row["token_id"],
            )
    except Exception as exc:
        log.debug(
            "auth.api_token_touch_failed",
            token_id=user_row["token_id"],
            exc_type=type(exc).__name__,
            error=str(exc),
        )
    log.debug("auth.api_token_resolved", user_id=user.id)
    return user


async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency: resolve the caller of a REST request.

    Delegates to ``_resolve_user`` and returns its result, which is ``None``
    for anonymous callers.
    """
    resolved = await _resolve_user(request)
    return resolved


async def get_current_user_ws(websocket) -> User | None:
    """Resolve the caller of a WebSocket connection.

    Delegates to ``_resolve_user`` and returns its result, which is ``None``
    for anonymous callers.
    """
    resolved = await _resolve_user(websocket)
    return resolved


# ── Helpers that talk to the session store ──────────────────────────────────

async def _get_auth_session(store, session_id: str) -> dict | None:
    """Load one ``user_auth_sessions`` row by id.

    Returns a plain dict with the keys ``user_id``, ``access_token_enc``,
    ``refresh_token_enc`` and ``expires_at``, or ``None`` when no row matches.
    """
    columns = ("user_id", "access_token_enc", "refresh_token_enc", "expires_at")
    pool = await store._get_pool()
    async with pool.acquire() as db:
        record = await db.fetchrow(
            "SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
            "FROM user_auth_sessions WHERE id = $1",
            session_id,
        )
    # Copy into a dict so callers do not hold on to the driver's record object.
    return None if record is None else {column: record[column] for column in columns}


async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Store new encrypted tokens and expiry for a session and bump ``last_used_at``."""
    statement = (
        "UPDATE user_auth_sessions "
        "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
        "WHERE id = $5"
    )
    pool = await store._get_pool()
    async with pool.acquire() as db:
        used_at = datetime.now(timezone.utc)
        await db.execute(
            statement, access_token_enc, refresh_token_enc, expires_at, used_at, session_id
        )


async def _delete_auth_session(store, session_id: str) -> None:
    """Delete the ``user_auth_sessions`` row with the given id."""
    statement = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        await db.execute(statement, session_id)


async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of the given auth session to the current UTC time."""
    statement = "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        used_at = datetime.now(timezone.utc)
        await db.execute(statement, used_at, session_id)


async def _upsert_navi_user(
    store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str],
    username: str | None = None, first_name: str | None = None, last_name: str | None = None,
    phone: str | None = None, birth_date: str | None = None, country: str | None = None,
    city: str | None = None, locale: str | None = None,
) -> None:
    """Insert or update the ``navi_users`` row for *user_id*.

    ``display_name`` falls back to ``email`` when empty, ``permissions`` is
    stored as a JSON string, and a single UTC timestamp is used for both
    ``created_at`` and ``updated_at``.  On an id conflict every column except
    ``id`` and ``created_at`` is overwritten.
    """
    import json

    pool = await store._get_pool()
    async with pool.acquire() as db:
        # Positional values in the order of the $1..$14 placeholders.
        row_values = (
            user_id, email, display_name or email, username, first_name, last_name,
            phone, birth_date, country, city, locale,
            role, json.dumps(permissions), datetime.now(timezone.utc),
        )
        await db.execute(
            """INSERT INTO navi_users (
                   id, email, display_name, username, first_name, last_name,
                   phone, birth_date, country, city, locale,
                   role, permissions, created_at, updated_at
               )
               VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
               ON CONFLICT (id) DO UPDATE
               SET email = EXCLUDED.email,
                   display_name = EXCLUDED.display_name,
                   username = EXCLUDED.username,
                   first_name = EXCLUDED.first_name,
                   last_name = EXCLUDED.last_name,
                   phone = EXCLUDED.phone,
                   birth_date = EXCLUDED.birth_date,
                   country = EXCLUDED.country,
                   city = EXCLUDED.city,
                   locale = EXCLUDED.locale,
                   role = EXCLUDED.role,
                   permissions = EXCLUDED.permissions,
                   updated_at = EXCLUDED.updated_at""",
            *row_values,
        )


# ── FastAPI Depends helpers ─────────────────────────────────────────────────

async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that yields the resolved user or rejects the request.

    With auth disabled the anonymous admin is returned unconditionally.

    Raises:
        HTTPException: 401 when auth is enabled and no user was resolved.
    """
    if settings.navi_auth_enabled:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        return user
    return _ANONYMOUS_USER


async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that only lets users with the ``admin`` role through.

    With auth disabled the anonymous admin is returned unconditionally.

    Raises:
        HTTPException: 401 when no user was resolved, 403 when the user's
            role is not ``admin``.
    """
    if not settings.navi_auth_enabled:
        return _ANONYMOUS_USER
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")


def require_permission(permission: str):
    """Build a dependency that requires *permission* (admins always pass).

    The returned coroutine yields the user when auth is disabled (anonymous
    admin), when the user's role is ``admin``, or when *permission* is listed
    in ``user.permissions``.  Otherwise it raises ``HTTPException``: 401 when
    no user was resolved, 403 when the permission is missing.
    """
    async def _check(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if not settings.navi_auth_enabled:
            return _ANONYMOUS_USER
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        allowed = user.role == "admin" or permission in user.permissions
        if not allowed:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _check


def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise 403 if user does not own the session and lacks the required permission.

    Legacy sessions (user_id=None) are accessible only to admins.
    When auth is disabled, all sessions are shared by the single anonymous user.

    Args:
        session: Object exposing a ``user_id`` attribute (``None`` for legacy
            sessions).
        user: The caller whose access is being checked.
        permission: Optional permission that grants access to sessions owned
            by someone else, checked through ``user.has_permission``.

    Raises:
        HTTPException: 403 when access is denied.
    """
    # HTTPException comes from the module-level fastapi import; the former
    # function-scope re-import was redundant.
    if not settings.navi_auth_enabled:
        return

    if session.user_id is None:
        # Legacy session without an owner: admins only.
        if user.role != "admin":
            raise HTTPException(status_code=403, detail="Access denied")
        return
    # Owners and admins always have access.
    if session.user_id == user.id or user.role == "admin":
        return
    if permission and user.has_permission(permission):
        return
    raise HTTPException(status_code=403, detail="Access denied")