# navi-1 / navi / auth / deps.py
"""FastAPI dependencies for auth resolution."""

import asyncio
from datetime import datetime, timezone
from typing import Annotated

import structlog
from fastapi import Depends, HTTPException, Request, WebSocket

from navi.config import settings

from . import User
from .client import get_gauth_client
from .encrypt import get_encryptor

log = structlog.get_logger()


def _role_and_permissions(auth_user) -> tuple[str, list[str]]:
    """Derive navi's ``(role, permissions)`` pair for a gnexus-auth user.

    Only the first ``client_access_list`` entry whose ``client_id`` equals
    ``settings.gnauth_client_id`` is consulted. The role is ``"admin"`` when
    ``settings.gnauth_admin_role_slug`` is among that entry's ``role_ids``,
    otherwise ``"user"``; the permissions are that entry's ``permission_ids``.
    A user with no matching entry gets ``("user", [])``.
    """
    for access in auth_user.client_access_list:
        if access.client_id != settings.gnauth_client_id:
            continue
        is_admin = settings.gnauth_admin_role_slug in (access.role_ids or [])
        return ("admin" if is_admin else "user"), list(access.permission_ids or [])
    return "user", []


async def _resolve_user(conn) -> User | None:
    """Resolve the authenticated user for a connection object (Request or WebSocket).

    Reads the session id from the ``settings.navi_auth_cookie_name`` cookie,
    loads the matching ``user_auth_sessions`` row, and refreshes the access
    token if it has expired. It then fetches the user from gnexus-auth,
    upserts the user into ``navi_users``, bumps the session's
    ``last_used_at``, and caches the result on ``conn.state.user``.

    Returns ``None`` (anonymous) in any of these cases:
      * ``conn`` is ``None``.
      * Auth is not configured.
      * The cookie is missing.
      * The session store cannot be imported yet.
      * No session row exists.
      * The token refresh fails. The session row is deleted in this case.
      * Fetching the user fails.
    """
    if conn is None:
        return None

    # Return cached user if already resolved for this connection.
    cached = getattr(conn.state, "user", None)
    if cached is not None:
        return cached

    # Auth not configured — treat as anonymous.
    if not settings.gnauth_client_id or not settings.gnauth_client_secret:
        return None

    session_id = conn.cookies.get(settings.navi_auth_cookie_name)
    if not session_id:
        return None

    # Imported lazily: a top-level import would be circular during early bootstrap.
    try:
        from navi.api.deps import get_session_store
    except Exception:
        return None

    store = get_session_store()
    row = await _get_auth_session(store, session_id)
    if row is None:
        return None

    encryptor = get_encryptor()
    access_token = encryptor.decrypt(row["access_token_enc"])
    expires_at = row["expires_at"]
    # Comparing an aware "now" against a naive datetime raises TypeError.
    # NOTE(review): a naive stored value is assumed to be UTC — confirm the column type.
    if expires_at is not None and expires_at.tzinfo is None:
        expires_at = expires_at.replace(tzinfo=timezone.utc)

    client = get_gauth_client()

    # Refresh if expired.
    if datetime.now(timezone.utc) > expires_at:
        try:
            refresh_token = encryptor.decrypt(row["refresh_token_enc"])
            # Client calls run in a worker thread so they don't block the event loop.
            token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
            access_token = token_set.access_token
            await _update_auth_session(
                store,
                session_id,
                encryptor.encrypt(access_token),
                # Keep the previous refresh token when the server did not issue a new one.
                encryptor.encrypt(token_set.refresh_token or refresh_token),
                # With no reported expiry, storing "now" makes the next request refresh again.
                token_set.expires_at or datetime.now(timezone.utc),
            )
            log.info("auth.token_refreshed", user_id=row["user_id"])
        except Exception as exc:
            # Log only the exception type: messages/tracebacks may carry token material.
            log.warning(
                "auth.refresh_failed",
                session_id=session_id[:8],
                error=type(exc).__name__,
            )
            # Refresh failed — drop the session and treat as unauthenticated.
            await _delete_auth_session(store, session_id)
            return None

    # Fetch user from gnexus-auth.
    try:
        auth_user = await asyncio.to_thread(client.fetch_user, access_token)
    except Exception as exc:
        log.warning(
            "auth.fetch_user_failed",
            session_id=session_id[:8],
            error=type(exc).__name__,
        )
        return None

    role, permissions = _role_and_permissions(auth_user)
    display_name = auth_user.profile.get("display_name")

    # Mirror the user into navi_users (runs on every uncached resolution).
    await _upsert_navi_user(
        store, auth_user.user_id, auth_user.email, display_name, role, permissions
    )

    user = User(
        id=auth_user.user_id,
        email=auth_user.email,
        display_name=display_name or auth_user.email,
        avatar_url=auth_user.avatar_url,
        role=role,
        permissions=permissions,
    )

    # Update last_used_at, then cache on the connection for later lookups.
    await _touch_auth_session(store, session_id)
    conn.state.user = user
    return user


async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency that yields the signed-in user for a REST request, or ``None``.

    Anonymous requests get ``None`` rather than an error. That covers a
    missing or unknown session cookie and an unconfigured auth setup.
    Use ``require_user`` to enforce login.
    """
    user = await _resolve_user(request)
    return user


async def get_current_user_ws(websocket: WebSocket) -> User | None:
    """Resolve the current user from the auth session cookie for WebSocket connections.

    The ``WebSocket`` annotation lets FastAPI inject the connection when this
    function is used via ``Depends``. Without it, FastAPI would treat
    ``websocket`` as a required query parameter. Direct calls such as
    ``await get_current_user_ws(ws)`` behave exactly as before.

    Returns ``None`` for anonymous connections.
    """
    return await _resolve_user(websocket)


# ── Helpers that talk to the session store ──────────────────────────────────

async def _get_auth_session(store, session_id: str) -> dict | None:
    """Fetch a row from user_auth_sessions."""
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
            "FROM user_auth_sessions WHERE id = $1",
            session_id,
        )
    if row is None:
        return None
    return {
        "user_id": row["user_id"],
        "access_token_enc": row["access_token_enc"],
        "refresh_token_enc": row["refresh_token_enc"],
        "expires_at": row["expires_at"],
    }


async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Store new encrypted tokens and expiry for a session, stamping ``last_used_at``.

    ``last_used_at`` is set to the current UTC time.
    """
    statement = (
        "UPDATE user_auth_sessions "
        "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
        "WHERE id = $5"
    )
    pool = await store._get_pool()
    async with pool.acquire() as db:
        stamped_at = datetime.now(timezone.utc)
        await db.execute(
            statement,
            access_token_enc,
            refresh_token_enc,
            expires_at,
            stamped_at,
            session_id,
        )


async def _delete_auth_session(store, session_id: str) -> None:
    """Delete the ``user_auth_sessions`` row with the given id (no error if absent)."""
    sql = "DELETE FROM user_auth_sessions WHERE id = $1"
    async with (await store._get_pool()).acquire() as db:
        await db.execute(sql, session_id)


async def _touch_auth_session(store, session_id: str) -> None:
    """Stamp a session's ``last_used_at`` with the current UTC time."""
    sql = "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2"
    async with (await store._get_pool()).acquire() as db:
        await db.execute(sql, datetime.now(timezone.utc), session_id)


async def _upsert_navi_user(
    store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str]
) -> None:
    pool = await store._get_pool()
    import json
    async with pool.acquire() as conn:
        await conn.execute(
            """INSERT INTO navi_users (id, email, display_name, role, permissions, created_at, updated_at)
               VALUES ($1, $2, $3, $4, $5, $6, $6)
               ON CONFLICT (id) DO UPDATE
               SET email = EXCLUDED.email,
                   display_name = EXCLUDED.display_name,
                   role = EXCLUDED.role,
                   permissions = EXCLUDED.permissions,
                   updated_at = EXCLUDED.updated_at""",
            user_id, email, display_name or email, role, json.dumps(permissions), datetime.now(timezone.utc),
        )


# ── FastAPI Depends helpers ─────────────────────────────────────────────────

async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that only lets authenticated users through.

    Raises:
        HTTPException: 401 when no user could be resolved for the request.
    """
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")


async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that only lets users with the ``"admin"`` role through.

    Raises:
        HTTPException: 401 when unauthenticated, 403 when the user is not an admin.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")


def require_permission(permission: str):
    """Build a dependency that requires ``permission`` on the current user.

    Admins pass without further checks. Other users must have ``permission``
    in ``user.permissions``.

    The returned dependency raises ``HTTPException`` with status 401 when
    unauthenticated, or 403 when the permission is missing.
    """

    async def _dependency(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        granted = user.role == "admin" or permission in user.permissions
        if not granted:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _dependency


def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise 403 unless ``user`` may access ``session``.

    Rules:
      * Legacy sessions (``session.user_id is None``) are accessible only to admins.
      * Otherwise access is granted to the owner (``session.user_id == user.id``),
        to any admin, or to a user for whom ``permission`` is given and
        ``user.has_permission(permission)`` is truthy.

    Raises:
        HTTPException: 403 "Access denied" when no rule grants access.
    """
    # HTTPException comes from the module-level fastapi import; no local re-import needed.
    if session.user_id is None:
        allowed = user.role == "admin"
    else:
        allowed = (
            session.user_id == user.id
            or user.role == "admin"
            or (permission and user.has_permission(permission))
        )
    if not allowed:
        raise HTTPException(status_code=403, detail="Access denied")