"""FastAPI dependencies for auth resolution."""
import asyncio
import hashlib
import time
from datetime import datetime, timezone
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Request
from gnexus_gauth.exceptions import TokenRefreshException
from navi.config import settings
from . import ApiToken, User
from .client import get_gauth_client
from .encrypt import get_encryptor
# Module-level structlog logger used by every helper below.
log = structlog.get_logger()
# Per-session-id lock to prevent concurrent token refresh races.
# NOTE: these are asyncio locks held in this process's memory, so they only
# serialize refreshes within a single worker process, not across workers.
_refresh_locks: dict[str, asyncio.Lock] = {}
# In-memory cache for fetch_user results: session_id -> (User, timestamp), where
# the timestamp is time.monotonic() taken when the entry was stored.
# Reduces hammering gnexus-auth and protects against eventual-consistency
# flakes right after a token refresh.
# Seconds (on the monotonic clock) that a cached User is considered fresh.
_USER_CACHE_TTL: float = 30.0
_user_cache: dict[str, tuple[User, float]] = {}
def _cleanup_refresh_locks() -> None:
    """Remove stale refresh locks to prevent unbounded memory growth.

    A lock is stale when its session has no fresh entry in ``_user_cache``.
    Cache entries are stamped with ``time.monotonic()``, so age must be
    measured on that same clock; comparing against wall-clock ``time.time()``
    would make every entry look expired.

    Locks that are currently held are never removed. Dropping a held lock
    would let the next request for that session create a brand-new lock and
    run a second token refresh concurrently with the first.
    """
    now = time.monotonic()
    stale = [
        sid
        for sid, lock in _refresh_locks.items()
        if not lock.locked()
        and (sid not in _user_cache or now - _user_cache[sid][1] > _USER_CACHE_TTL)
    ]
    for sid in stale:
        _refresh_locks.pop(sid, None)
def _get_cached_user(session_id: str) -> User | None:
    """Return the cached User for *session_id* while it is younger than the TTL.

    A stale entry is evicted on the way out. Every miss (absent or stale
    entry) also triggers ``_cleanup_refresh_locks()``.
    """
    cached = _user_cache.get(session_id)
    if cached is not None:
        user, stored_at = cached
        if time.monotonic() - stored_at <= _USER_CACHE_TTL:
            return user
        _user_cache.pop(session_id, None)
    _cleanup_refresh_locks()
    return None
def _set_cached_user(session_id: str, user: "User") -> None:
    """Store a resolved User in the short-lived cache.

    Entries are stamped with ``time.monotonic()``. Before inserting, entries
    older than ``_USER_CACHE_TTL`` are evicted. They would be discarded on
    their next lookup anyway, and entries for sessions that are never looked
    up again would otherwise stay in memory for the life of the process.
    """
    now = time.monotonic()
    expired = [
        sid for sid, (_, stamped_at) in _user_cache.items() if now - stamped_at > _USER_CACHE_TTL
    ]
    for sid in expired:
        _user_cache.pop(sid, None)
    _user_cache[session_id] = (user, now)
async def _fetch_user_with_retry(
    client,
    access_token: str,
    session_id: str,
    *,
    attempts: int = 2,
    retry_delay: float = 1.5,
) -> "User | None":
    """Call ``client.fetch_user``, retrying after a short pause on failure.

    ``fetch_user`` runs via ``asyncio.to_thread`` so that the gauth client
    call (presumably blocking I/O) does not stall the event loop. Any
    exception counts as a failed attempt.

    Args:
        client: gauth client exposing ``fetch_user(access_token)``.
        access_token: Token passed straight to ``fetch_user``.
        session_id: Auth session id; only its first 8 characters are logged.
        attempts: Total number of calls to make (at least 1). The default of
            2 means one retry, matching the original behaviour.
        retry_delay: Seconds to sleep between attempts.

    Returns:
        Whatever ``fetch_user`` returned, or ``None`` if every attempt raised.
    """
    attempts = max(1, attempts)
    for attempt in range(1, attempts + 1):
        try:
            return await asyncio.to_thread(client.fetch_user, access_token)
        except Exception as exc:
            if attempt == attempts:
                # Final failure: keep the traceback so the cause is diagnosable.
                log.warning(
                    "auth.fetch_user_failed",
                    session_id=session_id[:8],
                    attempts=attempts,
                    exc_info=True,
                )
                return None
            log.debug(
                "auth.fetch_user_retry",
                session_id=session_id[:8],
                attempt=attempt,
                exc_type=type(exc).__name__,
            )
        await asyncio.sleep(retry_delay)
    return None
async def _resolve_user(conn) -> User | None:
    """Shared logic to resolve user from a connection object (Request or WebSocket).

    Resolution order:
      1. ``conn.state.user`` if it was already resolved for this connection.
      2. The auth-session cookie (``settings.navi_auth_cookie_name``). This
         loads the ``user_auth_sessions`` row and refreshes the access token
         under a per-session lock once ``expires_at`` has passed. It then
         fetches the profile from gnexus-auth, using the short-lived
         in-memory cache when it can.
      3. An API token, used when there is no cookie or the token refresh
         fails.

    Any failure during resolution is treated as anonymous (and logged at
    warning level) to avoid crashing WebSocket upgrades or other
    dependency-injected paths.

    Returns:
        The resolved ``User`` (also stored on ``conn.state.user``), or
        ``None`` for anonymous/unauthenticated connections.
    """
    try:
        log.debug("auth.resolve_start", conn_type=type(conn).__name__)
        if conn is None:
            log.debug("auth.resolve_no_conn")
            return None
        # Return cached user if already resolved this request
        if hasattr(conn.state, "user") and conn.state.user is not None:
            log.debug("auth.resolve_cached", user_id=conn.state.user.id)
            return conn.state.user
        # Auth not configured — treat as anonymous
        if not settings.gnauth_client_id or not settings.gnauth_client_secret:
            log.debug("auth.resolve_not_configured")
            return None
        cookie_name = settings.navi_auth_cookie_name
        session_id = conn.cookies.get(cookie_name)
        if not session_id:
            log.debug("auth.resolve_no_cookie", cookie_name=cookie_name)
            # Without a session cookie the API token is the only credential
            # left. Stop here either way: falling through would slice/look up a
            # missing session id and surface as a spurious "auth.resolve_failed".
            api_user = await _resolve_user_from_api_token(conn)
            if api_user is not None:
                conn.state.user = api_user
            return api_user
        log.debug("auth.resolve_cookie_found", session_id=session_id[:8])
        # Look up the auth session in DB
        try:
            from navi.api.deps import get_session_store
        except Exception:
            # Avoid circular import during early bootstrap
            log.debug("auth.resolve_store_import_failed")
            return None
        store = get_session_store()
        row = await _get_auth_session(store, session_id)
        if row is None:
            log.debug("auth.resolve_session_not_found", session_id=session_id[:8])
            return None
        log.debug("auth.resolve_session_found", user_id=row["user_id"])
        encryptor = get_encryptor()
        access_token = encryptor.decrypt(row["access_token_enc"])
        expires_at = row["expires_at"]
        client = get_gauth_client()
        # Refresh if expired
        if datetime.now(timezone.utc) > expires_at:
            # One lock per session id, so concurrent requests for the same
            # session do not all spend the same refresh token.
            lock = _refresh_locks.setdefault(session_id, asyncio.Lock())
            async with lock:
                # Re-read session inside lock — another request may have refreshed it already
                row = await _get_auth_session(store, session_id)
                if row is None:
                    log.debug("auth.resolve_session_gone_during_refresh", session_id=session_id[:8])
                    return None
                expires_at = row["expires_at"]
                if datetime.now(timezone.utc) <= expires_at:
                    access_token = encryptor.decrypt(row["access_token_enc"])
                else:
                    try:
                        refresh_token = encryptor.decrypt(row["refresh_token_enc"])
                        token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
                        access_token = token_set.access_token
                        await _update_auth_session(
                            store,
                            session_id,
                            encryptor.encrypt(access_token),
                            # Keep the old refresh token if the server did not rotate it.
                            encryptor.encrypt(token_set.refresh_token or refresh_token),
                            # NOTE(review): when the server omits an expiry, the session
                            # is stored as expiring "now", so the next request refreshes
                            # again.
                            token_set.expires_at or datetime.now(timezone.utc),
                        )
                        log.info("auth.token_refreshed", user_id=row["user_id"])
                    except TokenRefreshException as exc:
                        # Refresh token is definitively invalid (expired, revoked,
                        # rotated by another device). Force re-login.
                        log.warning(
                            "auth.refresh_token_invalid",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            reason=str(exc),
                        )
                        # The dead access token must not be used and no cached
                        # User may be served. An API token is the only way left
                        # to authenticate this request; otherwise it is anonymous.
                        api_user = await _resolve_user_from_api_token(conn)
                        if api_user is not None:
                            conn.state.user = api_user
                        return api_user
                    except Exception as exc:
                        # Transient errors (network, timeout, 5xx from auth server).
                        # Do NOT delete the session — force-logout is too harsh.
                        # Best-effort: serve from cache if available, else try API token.
                        log.warning(
                            "auth.refresh_transient_fail",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            exc_type=type(exc).__name__,
                            error=str(exc),
                        )
                        cached = _get_cached_user(session_id)
                        if cached is not None:
                            log.debug("auth.fallback_to_cached_user", user_id=cached.id)
                            conn.state.user = cached
                            return cached
                        api_user = await _resolve_user_from_api_token(conn)
                        if api_user is not None:
                            conn.state.user = api_user
                            return api_user
                        # Otherwise fall through and try the existing access token.
                        # The stored expiry may be conservative (see the fallback above).
        # Check short-lived fetch_user cache before hitting gnexus-auth
        cached = _get_cached_user(session_id)
        if cached is not None:
            log.debug("auth.resolve_user_cache_hit", user_id=cached.id)
            conn.state.user = cached
            return cached
        # Fetch user from gnexus-auth (with one retry on transient failure)
        auth_user = await _fetch_user_with_retry(client, access_token, session_id)
        if auth_user is None:
            return None
        log.debug("auth.resolve_user_fetched", user_id=auth_user.user_id)
        # Determine role from client-level role_ids; only this app's client entry counts.
        role = "user"
        permissions: list[str] = []
        for access in auth_user.client_access_list:
            if access.client_id == settings.gnauth_client_id:
                if settings.gnauth_admin_role_slug in (access.role_ids or []):
                    role = "admin"
                permissions = list(access.permission_ids or [])
                break
        # Upsert into navi_users
        profile = auth_user.profile
        await _upsert_navi_user(
            store, auth_user.user_id, auth_user.email,
            profile.get("display_name"), role, permissions,
            username=profile.get("username"),
            first_name=profile.get("first_name"),
            last_name=profile.get("last_name"),
            phone=profile.get("phone"),
            birth_date=profile.get("birth_date"),
            country=profile.get("country"),
            city=profile.get("city"),
            locale=profile.get("locale"),
        )
        user = User(
            id=auth_user.user_id,
            email=auth_user.email,
            display_name=profile.get("display_name") or auth_user.email,
            username=profile.get("username"),
            first_name=profile.get("first_name"),
            last_name=profile.get("last_name"),
            phone=profile.get("phone"),
            birth_date=profile.get("birth_date"),
            country=profile.get("country"),
            city=profile.get("city"),
            locale=profile.get("locale"),
            avatar_url=auth_user.avatar_url,
            role=role,
            permissions=permissions,
        )
        # Update last_used_at and cache
        await _touch_auth_session(store, session_id)
        _set_cached_user(session_id, user)
        conn.state.user = user
        log.debug("auth.resolve_success", user_id=user.id)
        return user
    except Exception:
        # Any unexpected failure during auth resolution should not crash the
        # request — treat as anonymous so WebSocket upgrades and REST calls
        # degrade gracefully.
        log.warning("auth.resolve_failed", exc_info=True)
        return None
# ── API token resolution helpers ────────────────────────────────────────────
async def _resolve_user_from_api_token(conn) -> User | None:
    """Resolve user from X-Api-Token header or ?api_token query parameter.

    The token is matched by its SHA-256 hex digest against
    ``api_tokens.token_hash``, joined to ``navi_users``. Revoked tokens are
    rejected. On success, ``api_tokens.last_used_at`` is bumped on a
    best-effort basis.

    Returns:
        The token owner's ``User``, or ``None`` in any of these cases: no
        token was supplied, the session store cannot be imported, or the
        token is unknown or revoked.

    Raises:
        Database errors from the token lookup itself are not caught here.
    """
    import json

    # Try header first, then query param (for WebSocket or REST)
    token = None
    try:
        token = conn.headers.get("X-Api-Token")
    except Exception:
        # Defensive: a connection object without headers just has no header token.
        pass
    if not token:
        try:
            token = conn.query_params.get("api_token")
        except Exception:
            pass
    if not token:
        return None
    try:
        from navi.api.deps import get_session_store
    except Exception:
        return None
    store = get_session_store()
    pool = await store._get_pool()
    token_hash = hashlib.sha256(token.encode()).hexdigest()
    async with pool.acquire() as db_conn:
        user_row = await db_conn.fetchrow(
            "SELECT u.id, u.email, u.display_name, u.username, u.first_name, u.last_name, "
            "u.phone, u.birth_date, u.country, u.city, u.locale, u.role, u.permissions, "
            "t.id AS token_id, t.revoked_at "
            "FROM api_tokens t JOIN navi_users u ON u.id = t.user_id "
            "WHERE t.token_hash = $1",
            token_hash,
        )
    if user_row is None:
        log.debug("auth.api_token_not_found")
        return None
    if user_row["revoked_at"] is not None:
        log.debug("auth.api_token_revoked", token_id=user_row["token_id"])
        return None
    # navi_users.permissions is written as a JSON string by _upsert_navi_user.
    permissions = json.loads(user_row["permissions"] or "[]")
    user = User(
        id=user_row["id"],
        email=user_row["email"],
        display_name=user_row["display_name"] or user_row["email"],
        username=user_row["username"],
        first_name=user_row["first_name"],
        last_name=user_row["last_name"],
        phone=user_row["phone"],
        birth_date=user_row["birth_date"],
        country=user_row["country"],
        city=user_row["city"],
        locale=user_row["locale"],
        role=user_row["role"],
        permissions=permissions,
    )
    # Update last_used_at (awaited inline). It is best-effort: a failure here
    # must not reject an otherwise valid token, but it should stay visible.
    try:
        async with pool.acquire() as db_conn:
            await db_conn.execute(
                "UPDATE api_tokens SET last_used_at = $1 WHERE id = $2",
                datetime.now(timezone.utc), user_row["token_id"],
            )
    except Exception:
        log.debug("auth.api_token_touch_failed", token_id=user_row["token_id"], exc_info=True)
    log.debug("auth.api_token_resolved", user_id=user.id)
    return user
async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency: the user behind a REST request, or ``None`` if anonymous.

    Delegates to the shared resolver, which tries the session cookie and then
    the API token.
    """
    resolved_user = await _resolve_user(request)
    return resolved_user
async def get_current_user_ws(websocket) -> User | None:
    """Resolve the user behind a WebSocket connection, or ``None`` if anonymous.

    Uses the same resolver as REST requests (session cookie, then API token).
    """
    resolved_user = await _resolve_user(websocket)
    return resolved_user
# ── Helpers that talk to the session store ──────────────────────────────────
async def _get_auth_session(store, session_id: str) -> dict | None:
"""Fetch a row from user_auth_sessions."""
pool = await store._get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
"FROM user_auth_sessions WHERE id = $1",
session_id,
)
if row is None:
return None
return {
"user_id": row["user_id"],
"access_token_enc": row["access_token_enc"],
"refresh_token_enc": row["refresh_token_enc"],
"expires_at": row["expires_at"],
}
async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Persist a refreshed (encrypted) token pair and expiry for *session_id*.

    ``last_used_at`` is set to the current UTC time in the same statement.
    """
    statement = (
        "UPDATE user_auth_sessions "
        "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
        "WHERE id = $5"
    )
    pool = await store._get_pool()
    async with pool.acquire() as db:
        touched_at = datetime.now(timezone.utc)
        await db.execute(
            statement,
            access_token_enc,
            refresh_token_enc,
            expires_at,
            touched_at,
            session_id,
        )
async def _delete_auth_session(store, session_id: str) -> None:
    """Remove the ``user_auth_sessions`` row identified by *session_id*."""
    statement = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        await db.execute(statement, session_id)
async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of the *session_id* row to the current UTC time."""
    pool = await store._get_pool()
    async with pool.acquire() as db:
        touched_at = datetime.now(timezone.utc)
        await db.execute(
            "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2",
            touched_at,
            session_id,
        )
async def _upsert_navi_user(
store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str],
username: str | None = None, first_name: str | None = None, last_name: str | None = None,
phone: str | None = None, birth_date: str | None = None, country: str | None = None,
city: str | None = None, locale: str | None = None,
) -> None:
pool = await store._get_pool()
import json
async with pool.acquire() as conn:
await conn.execute(
"""INSERT INTO navi_users (
id, email, display_name, username, first_name, last_name,
phone, birth_date, country, city, locale,
role, permissions, created_at, updated_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
ON CONFLICT (id) DO UPDATE
SET email = EXCLUDED.email,
display_name = EXCLUDED.display_name,
username = EXCLUDED.username,
first_name = EXCLUDED.first_name,
last_name = EXCLUDED.last_name,
phone = EXCLUDED.phone,
birth_date = EXCLUDED.birth_date,
country = EXCLUDED.country,
city = EXCLUDED.city,
locale = EXCLUDED.locale,
role = EXCLUDED.role,
permissions = EXCLUDED.permissions,
updated_at = EXCLUDED.updated_at""",
user_id, email, display_name or email, username, first_name, last_name,
phone, birth_date, country, city, locale,
role, json.dumps(permissions), datetime.now(timezone.utc),
)
# ── FastAPI Depends helpers ─────────────────────────────────────────────────
async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """FastAPI dependency that only lets authenticated users through.

    Raises:
        HTTPException: 401 when the request is anonymous.
    """
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")
async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """FastAPI dependency that only lets users with ``role == "admin"`` through.

    Raises:
        HTTPException: 401 when anonymous, 403 when authenticated but not admin.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")
def require_permission(permission: str):
    """Build a FastAPI dependency that requires *permission* (admins always pass).

    The returned coroutine function raises ``HTTPException`` 401 for anonymous
    requests. It raises 403 when a non-admin user's ``permissions`` list does
    not contain *permission*. Otherwise it returns the user.
    """
    async def _check(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        allowed = user.role == "admin" or permission in user.permissions
        if not allowed:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _check
def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise 403 if user does not own the session and lacks the required permission.

    Access is granted in any of these cases:
      - the user owns the session;
      - the user is an admin;
      - *permission* is given and ``user.has_permission(permission)`` is true.
    Legacy sessions (``session.user_id is None``) are accessible only to admins.

    Raises:
        HTTPException: 403 "Access denied" when none of the above applies.
    """
    owner_id = session.user_id
    if owner_id is None:
        if user.role == "admin":
            return
        raise HTTPException(status_code=403, detail="Access denied")
    if owner_id == user.id or user.role == "admin":
        return
    if permission and user.has_permission(permission):
        return
    raise HTTPException(status_code=403, detail="Access denied")