"""FastAPI dependencies for auth resolution."""
import asyncio
import hashlib
import time
from datetime import datetime, timezone
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Request
from navi.config import settings
from . import ApiToken, User
from .client import get_gauth_client
from .encrypt import get_encryptor
log = structlog.get_logger()

# Short-lived in-process cache of resolved users:
#   session_id -> (User, time.monotonic() reading taken when it was stored).
# Cuts down on calls to gnexus-auth and smooths over eventual-consistency
# flakes right after a token refresh.
_USER_CACHE_TTL: float = 30.0  # seconds, compared against time.monotonic() deltas
_user_cache: dict[str, tuple[User, float]] = {}

# One asyncio.Lock per session id, so that within this process only one
# request at a time refreshes a given session's tokens.
_refresh_locks: dict[str, asyncio.Lock] = {}
def _cleanup_refresh_locks() -> None:
"""Remove stale refresh locks to prevent unbounded memory growth."""
now = time.time()
stale = [sid for sid, lock in _refresh_locks.items() if sid not in _user_cache or _user_cache[sid][1] + _USER_CACHE_TTL < now]
for sid in stale:
_refresh_locks.pop(sid, None)
def _get_cached_user(session_id: str) -> User | None:
    """Look up *session_id* in the short-lived user cache.

    Returns the cached ``User`` while its entry is at most ``_USER_CACHE_TTL``
    seconds old, measured with ``time.monotonic()``.  On a miss or an expired
    entry, returns ``None``; an expired entry is evicted first, and in both
    cases ``_cleanup_refresh_locks()`` runs before returning.
    """
    cached = _user_cache.get(session_id)
    if cached is not None:
        user, stored_at = cached
        if time.monotonic() - stored_at <= _USER_CACHE_TTL:
            return user
        # Too old: evict before reporting a miss.
        _user_cache.pop(session_id, None)
    _cleanup_refresh_locks()
    return None
def _set_cached_user(session_id: str, user: User) -> None:
    """Remember *user* for *session_id*, stamped with the current monotonic time."""
    stamped = (user, time.monotonic())
    _user_cache[session_id] = stamped
async def _fetch_user_with_retry(
    client,
    access_token: str,
    session_id: str,
    *,
    attempts: int = 2,
    delay: float = 1.5,
) -> "User | None":
    """Fetch the gnexus-auth user for *access_token*, retrying on failure.

    ``client.fetch_user`` is a synchronous call, so it runs in a worker thread
    via ``asyncio.to_thread``.  Any ``Exception`` counts as a failed attempt;
    after each failed attempt except the last, the coroutine sleeps *delay*
    seconds before trying again.

    Args:
        client: gnexus-auth client exposing ``fetch_user(access_token)``.
        access_token: Access token passed to ``fetch_user``.
        session_id: Session id; only its first 8 characters are logged.
        attempts: Total number of calls to make (default 2, i.e. one retry).
        delay: Seconds to wait between attempts (default 1.5).

    Returns:
        The value returned by ``fetch_user`` on success, or ``None`` once every
        attempt has failed (or immediately if *attempts* < 1).
    """
    for attempt in range(1, attempts + 1):
        try:
            return await asyncio.to_thread(client.fetch_user, access_token)
        except Exception:
            if attempt < attempts:
                log.debug(
                    "auth.fetch_user_retry",
                    session_id=session_id[:8],
                    attempt=attempt,
                )
                await asyncio.sleep(delay)
            else:
                # Keep the traceback: this is the only record of why auth failed.
                log.warning("auth.fetch_user_failed", session_id=session_id[:8], exc_info=True)
    return None
def _role_and_permissions(auth_user) -> tuple[str, list[str]]:
    """Derive ``(role, permissions)`` for this app from a gnexus-auth user.

    Only the first ``client_access_list`` entry whose ``client_id`` equals
    ``settings.gnauth_client_id`` is consulted.  The role is ``"admin"`` when that
    entry's ``role_ids`` contain ``settings.gnauth_admin_role_slug`` and
    ``"user"`` otherwise; permissions are that entry's ``permission_ids``.
    Without a matching entry the result is ``("user", [])``.
    """
    for access in auth_user.client_access_list:
        if access.client_id != settings.gnauth_client_id:
            continue
        is_admin = settings.gnauth_admin_role_slug in (access.role_ids or [])
        return ("admin" if is_admin else "user"), list(access.permission_ids or [])
    return "user", []


async def _resolve_user(conn) -> User | None:
    """Resolve the authenticated user for a connection (Request or WebSocket).

    Resolution order:

    1. ``conn.state.user`` if it was already resolved for this connection.
    2. ``None`` (anonymous) when gnexus-auth client credentials are not configured.
    3. No session cookie: try an API token (header / query param), else ``None``.
    4. Session cookie: load the ``user_auth_sessions`` row, refresh the access
       token under a per-session lock if it has expired, then return the user
       from the short-lived cache or fetch it from gnexus-auth, upsert it into
       ``navi_users`` and cache it.

    Any failure during resolution is logged and treated as anonymous so that
    WebSocket upgrades and other dependency-injected paths do not crash.
    """
    try:
        log.debug("auth.resolve_start", conn_type=type(conn).__name__)
        if conn is None:
            log.debug("auth.resolve_no_conn")
            return None

        # Already resolved earlier for this request / connection.
        if getattr(conn.state, "user", None) is not None:
            log.debug("auth.resolve_cached", user_id=conn.state.user.id)
            return conn.state.user

        # Auth not configured — treat as anonymous.
        if not settings.gnauth_client_id or not settings.gnauth_client_secret:
            log.debug("auth.resolve_not_configured")
            return None

        cookie_name = settings.navi_auth_cookie_name
        session_id = conn.cookies.get(cookie_name)
        if not session_id:
            log.debug("auth.resolve_no_cookie", cookie_name=cookie_name)
            # Without a session cookie an API token is the only other credential.
            # Either way there is no session to look up, so stop here.
            api_user = await _resolve_user_from_api_token(conn)
            if api_user is not None:
                conn.state.user = api_user
            return api_user

        log.debug("auth.resolve_cookie_found", session_id=session_id[:8])

        # Look up the auth session in the DB.
        try:
            from navi.api.deps import get_session_store
        except Exception:
            # Avoid circular import during early bootstrap.
            log.debug("auth.resolve_store_import_failed")
            return None

        store = get_session_store()
        row = await _get_auth_session(store, session_id)
        if row is None:
            log.debug("auth.resolve_session_not_found", session_id=session_id[:8])
            return None
        log.debug("auth.resolve_session_found", user_id=row["user_id"])

        encryptor = get_encryptor()
        access_token = encryptor.decrypt(row["access_token_enc"])
        expires_at = row["expires_at"]
        client = get_gauth_client()

        # Refresh if expired.  The per-session lock serializes refreshes of the
        # same session within this process.
        if datetime.now(timezone.utc) > expires_at:
            lock = _refresh_locks.setdefault(session_id, asyncio.Lock())
            async with lock:
                # Re-read inside the lock — another request may have refreshed it already.
                row = await _get_auth_session(store, session_id)
                if row is None:
                    log.debug("auth.resolve_session_gone_during_refresh", session_id=session_id[:8])
                    return None
                expires_at = row["expires_at"]
                if datetime.now(timezone.utc) <= expires_at:
                    access_token = encryptor.decrypt(row["access_token_enc"])
                else:
                    try:
                        refresh_token = encryptor.decrypt(row["refresh_token_enc"])
                        token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
                        access_token = token_set.access_token
                        await _update_auth_session(
                            store,
                            session_id,
                            encryptor.encrypt(access_token),
                            encryptor.encrypt(token_set.refresh_token or refresh_token),
                            token_set.expires_at or datetime.now(timezone.utc),
                        )
                        log.info("auth.token_refreshed", user_id=row["user_id"])
                    except Exception:
                        log.warning("auth.refresh_failed", session_id=session_id[:8], exc_info=True)
                        # Do NOT delete the session — transient errors (network,
                        # race-condition with parallel refresh) should not force
                        # a full re-login.
                        # Fallback: try an API token first.
                        api_user = await _resolve_user_from_api_token(conn)
                        if api_user is not None:
                            conn.state.user = api_user
                            return api_user
                        # NOTE(review): otherwise we fall through with the old
                        # (expired) access token; the short-lived user cache below
                        # may still serve this session.  Confirm this is intended.

        # Check the short-lived fetch_user cache before hitting gnexus-auth.
        cached = _get_cached_user(session_id)
        if cached is not None:
            log.debug("auth.resolve_user_cache_hit", user_id=cached.id)
            conn.state.user = cached
            return cached

        # Fetch the user from gnexus-auth (with one retry on transient failure).
        auth_user = await _fetch_user_with_retry(client, access_token, session_id)
        if auth_user is None:
            return None
        log.debug("auth.resolve_user_fetched", user_id=auth_user.user_id)

        role, permissions = _role_and_permissions(auth_user)

        # Optional profile fields shared by the DB upsert and the User object.
        profile = auth_user.profile
        profile_fields = {
            field: profile.get(field)
            for field in (
                "username", "first_name", "last_name", "phone",
                "birth_date", "country", "city", "locale",
            )
        }

        await _upsert_navi_user(
            store, auth_user.user_id, auth_user.email,
            profile.get("display_name"), role, permissions,
            **profile_fields,
        )
        user = User(
            id=auth_user.user_id,
            email=auth_user.email,
            display_name=profile.get("display_name") or auth_user.email,
            avatar_url=auth_user.avatar_url,
            role=role,
            permissions=permissions,
            **profile_fields,
        )

        # Update last_used_at and cache.
        await _touch_auth_session(store, session_id)
        _set_cached_user(session_id, user)
        conn.state.user = user
        log.debug("auth.resolve_success", user_id=user.id)
        return user
    except Exception:
        # Any unexpected failure during auth resolution should not crash the
        # request — treat as anonymous so WebSocket upgrades and REST calls
        # degrade gracefully.
        log.warning("auth.resolve_failed", exc_info=True)
        return None
# ── API token resolution helpers ────────────────────────────────────────────
async def _resolve_user_from_api_token(conn) -> User | None:
    """Resolve a user from an API token presented on the connection.

    The token is read from the ``X-Api-Token`` header first, then from the
    ``api_token`` query parameter.  It is looked up by its SHA-256 hex digest in
    ``api_tokens`` joined with ``navi_users``, and the user row becomes the
    returned ``User``.

    Returns ``None`` when no token is presented, the session store cannot be
    imported, the token is unknown, or it has been revoked.  Errors from the
    lookup query itself propagate to the caller.  Updating the token's
    ``last_used_at`` is best-effort: failures there are logged, not raised.
    """
    import json

    # Try header first, then query param (for WebSocket or REST).
    token = None
    try:
        token = conn.headers.get("X-Api-Token")
    except Exception:
        pass  # connection without usable headers: treat as "no header"
    if not token:
        try:
            token = conn.query_params.get("api_token")
        except Exception:
            pass  # connection without usable query params: treat as "no token"
    if not token:
        return None

    try:
        from navi.api.deps import get_session_store
    except Exception:
        return None

    store = get_session_store()
    pool = await store._get_pool()
    token_hash = hashlib.sha256(token.encode()).hexdigest()
    async with pool.acquire() as db_conn:
        user_row = await db_conn.fetchrow(
            "SELECT u.id, u.email, u.display_name, u.username, u.first_name, u.last_name, "
            "u.phone, u.birth_date, u.country, u.city, u.locale, u.role, u.permissions, "
            "t.id AS token_id, t.revoked_at "
            "FROM api_tokens t JOIN navi_users u ON u.id = t.user_id "
            "WHERE t.token_hash = $1",
            token_hash,
        )
    if user_row is None:
        log.debug("auth.api_token_not_found")
        return None
    if user_row["revoked_at"] is not None:
        log.debug("auth.api_token_revoked", token_id=user_row["token_id"])
        return None

    # Permissions are stored JSON-encoded (see _upsert_navi_user's json.dumps).
    permissions = json.loads(user_row["permissions"] or "[]")
    user = User(
        id=user_row["id"],
        email=user_row["email"],
        display_name=user_row["display_name"] or user_row["email"],
        username=user_row["username"],
        first_name=user_row["first_name"],
        last_name=user_row["last_name"],
        phone=user_row["phone"],
        birth_date=user_row["birth_date"],
        country=user_row["country"],
        city=user_row["city"],
        locale=user_row["locale"],
        role=user_row["role"],
        permissions=permissions,
    )

    # Record token usage.  Awaited inline but best-effort: a failure here must
    # not reject an otherwise valid token — log it instead of hiding it.
    try:
        async with pool.acquire() as db_conn:
            await db_conn.execute(
                "UPDATE api_tokens SET last_used_at = $1 WHERE id = $2",
                datetime.now(timezone.utc), user_row["token_id"],
            )
    except Exception:
        log.warning("auth.api_token_touch_failed", token_id=user_row["token_id"], exc_info=True)

    log.debug("auth.api_token_resolved", user_id=user.id)
    return user
async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency: the user behind an HTTP request, or ``None`` if anonymous."""
    resolved = await _resolve_user(request)
    return resolved
async def get_current_user_ws(websocket) -> User | None:
    """Dependency for WebSocket endpoints: the connected user, or ``None`` if anonymous."""
    resolved = await _resolve_user(websocket)
    return resolved
# ── Helpers that talk to the session store ──────────────────────────────────
async def _get_auth_session(store, session_id: str) -> dict | None:
"""Fetch a row from user_auth_sessions."""
pool = await store._get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
"FROM user_auth_sessions WHERE id = $1",
session_id,
)
if row is None:
return None
return {
"user_id": row["user_id"],
"access_token_enc": row["access_token_enc"],
"refresh_token_enc": row["refresh_token_enc"],
"expires_at": row["expires_at"],
}
async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Persist refreshed (already-encrypted) tokens and expiry for a session.

    Also stamps ``last_used_at`` with the current UTC time.
    """
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        touched_at = datetime.now(timezone.utc)
        await conn.execute(
            "UPDATE user_auth_sessions "
            "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
            "WHERE id = $5",
            access_token_enc,
            refresh_token_enc,
            expires_at,
            touched_at,
            session_id,
        )
async def _delete_auth_session(store, session_id: str) -> None:
    """Delete the ``user_auth_sessions`` row whose id is *session_id*."""
    delete_sql = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        await conn.execute(delete_sql, session_id)
async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of a ``user_auth_sessions`` row to the current UTC time."""
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        stamp = datetime.now(timezone.utc)
        await conn.execute(
            "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2",
            stamp,
            session_id,
        )
async def _upsert_navi_user(
store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str],
username: str | None = None, first_name: str | None = None, last_name: str | None = None,
phone: str | None = None, birth_date: str | None = None, country: str | None = None,
city: str | None = None, locale: str | None = None,
) -> None:
pool = await store._get_pool()
import json
async with pool.acquire() as conn:
await conn.execute(
"""INSERT INTO navi_users (
id, email, display_name, username, first_name, last_name,
phone, birth_date, country, city, locale,
role, permissions, created_at, updated_at
)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
ON CONFLICT (id) DO UPDATE
SET email = EXCLUDED.email,
display_name = EXCLUDED.display_name,
username = EXCLUDED.username,
first_name = EXCLUDED.first_name,
last_name = EXCLUDED.last_name,
phone = EXCLUDED.phone,
birth_date = EXCLUDED.birth_date,
country = EXCLUDED.country,
city = EXCLUDED.city,
locale = EXCLUDED.locale,
role = EXCLUDED.role,
permissions = EXCLUDED.permissions,
updated_at = EXCLUDED.updated_at""",
user_id, email, display_name or email, username, first_name, last_name,
phone, birth_date, country, city, locale,
role, json.dumps(permissions), datetime.now(timezone.utc),
)
# ── FastAPI Depends helpers ─────────────────────────────────────────────────
async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """FastAPI dependency: return the current user, or fail with HTTP 401 if anonymous."""
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")
async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """FastAPI dependency: return the current user only if their role is ``"admin"``.

    Raises HTTP 401 for anonymous callers and HTTP 403 for non-admin users.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")
def require_permission(permission: str):
    """Build a FastAPI dependency that requires *permission*.

    The returned dependency raises HTTP 401 when no user is authenticated and
    HTTP 403 unless the user is an admin or has *permission* in
    ``user.permissions``; otherwise it returns the user.
    """
    async def _check(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        allowed = user.role == "admin" or permission in user.permissions
        if not allowed:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _check
def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Authorize *user* to access *session*, raising HTTP 403 otherwise.

    Access is granted when any of these holds:

    - the user owns the session (``session.user_id == user.id``);
    - the user's role is ``"admin"``;
    - *permission* is given and ``user.has_permission(permission)`` is true.

    Legacy sessions without an owner (``session.user_id is None``) are
    accessible only to admins; *permission* does not apply to them.

    Raises:
        HTTPException: status 403, detail "Access denied", when access is refused.
    """
    # HTTPException comes from the module-level fastapi import; no local re-import.
    if session.user_id is None:
        if user.role == "admin":
            return
        raise HTTPException(status_code=403, detail="Access denied")
    if session.user_id == user.id or user.role == "admin":
        return
    if permission and user.has_permission(permission):
        return
    raise HTTPException(status_code=403, detail="Access denied")