"""FastAPI dependencies for auth resolution."""
import asyncio
from datetime import datetime, timezone
from typing import Annotated
import structlog
from fastapi import Depends, HTTPException, Request
from navi.config import settings
from . import User
from .client import get_gauth_client
from .encrypt import get_encryptor
log = structlog.get_logger()
def _as_utc(value):
    """Return *value* tagged as UTC when it is a naive ``datetime``.

    Aware datetimes and non-datetime values are returned unchanged. Comparing
    an aware ``datetime.now(timezone.utc)`` with a naive value raises
    ``TypeError``, so this guards the expiry comparison in ``_resolve_user``.

    NOTE(review): assumes naive ``expires_at`` values read from the DB are UTC —
    confirm against the ``user_auth_sessions.expires_at`` column type.
    """
    if isinstance(value, datetime) and value.tzinfo is None:
        return value.replace(tzinfo=timezone.utc)
    return value


def _client_role_and_permissions(auth_user) -> tuple[str, list[str]]:
    """Derive navi's role and permission ids for this client from a gnexus-auth user.

    Uses the first ``client_access_list`` entry whose ``client_id`` equals
    ``settings.gnauth_client_id``. The role is ``"admin"`` when that entry's
    ``role_ids`` contains ``settings.gnauth_admin_role_slug``, otherwise
    ``"user"``. Returns ``("user", [])`` when no entry matches.
    """
    for access in auth_user.client_access_list:
        if access.client_id == settings.gnauth_client_id:
            is_admin = settings.gnauth_admin_role_slug in (access.role_ids or [])
            return ("admin" if is_admin else "user"), list(access.permission_ids or [])
    return "user", []


async def _resolve_user(conn) -> User | None:
    """Resolve the authenticated user for a connection (``Request`` or ``WebSocket``).

    Steps:

    1. Return ``conn.state.user`` if an earlier call already resolved it.
    2. Return ``None`` (anonymous) in any of these cases: gnexus-auth is not
       configured, the auth cookie is missing, ``navi.api.deps`` cannot be
       imported yet, or no auth-session row matches.
    3. If the stored access token has expired, refresh it and persist the new
       tokens. On any failure in that step, delete the session row and return
       ``None``.
    4. Fetch the user from gnexus-auth and derive the role and permissions for
       this client. Then upsert ``navi_users``, bump ``last_used_at``, and cache
       the result on ``conn.state.user``.

    Args:
        conn: Object exposing ``.state`` and ``.cookies`` (a Starlette
            ``Request``/``WebSocket``), or ``None``.

    Returns:
        The resolved :class:`User`, or ``None`` when the caller is anonymous or
        authentication failed. Only successful resolutions are cached.
    """
    if conn is None:
        return None

    # Return the cached user if it was already resolved for this connection.
    cached_user = getattr(conn.state, "user", None)
    if cached_user is not None:
        return cached_user

    # Auth not configured — treat as anonymous.
    if not settings.gnauth_client_id or not settings.gnauth_client_secret:
        return None

    session_id = conn.cookies.get(settings.navi_auth_cookie_name)
    if not session_id:
        return None

    # Deferred import: navi.api.deps may still be initialising during early
    # bootstrap (circular import). In that case, treat the caller as anonymous.
    try:
        from navi.api.deps import get_session_store
    except ImportError:
        return None

    store = get_session_store()
    row = await _get_auth_session(store, session_id)
    if row is None:
        return None

    encryptor = get_encryptor()
    access_token = encryptor.decrypt(row["access_token_enc"])
    client = get_gauth_client()

    # Refresh the access token once it has expired.
    if datetime.now(timezone.utc) > _as_utc(row["expires_at"]):
        try:
            refresh_token = encryptor.decrypt(row["refresh_token_enc"])
            # The gauth client call is run in a worker thread via asyncio.to_thread.
            token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
            access_token = token_set.access_token
            await _update_auth_session(
                store,
                session_id,
                encryptor.encrypt(access_token),
                # Keep the old refresh token when the provider returns none.
                encryptor.encrypt(token_set.refresh_token or refresh_token),
                # NOTE(review): a missing expiry is stored as "now", so the next
                # request refreshes again — confirm this is intended.
                token_set.expires_at or datetime.now(timezone.utc),
            )
            log.info("auth.token_refreshed", user_id=row["user_id"])
        except Exception:
            # Any failure here (decrypt, refresh call, DB update) ends the session.
            # Only an 8-char prefix of the session id is logged (presumably to
            # keep the full id out of logs).
            log.warning("auth.refresh_failed", session_id=session_id[:8], exc_info=True)
            await _delete_auth_session(store, session_id)
            return None

    # Fetch the user from gnexus-auth.
    try:
        auth_user = await asyncio.to_thread(client.fetch_user, access_token)
    except Exception:
        log.warning("auth.fetch_user_failed", session_id=session_id[:8], exc_info=True)
        return None

    role, permissions = _client_role_and_permissions(auth_user)
    display_name = auth_user.profile.get("display_name")

    # Mirror the identity into navi_users on every successful resolution.
    await _upsert_navi_user(
        store, auth_user.user_id, auth_user.email, display_name, role, permissions
    )

    user = User(
        id=auth_user.user_id,
        email=auth_user.email,
        display_name=display_name or auth_user.email,
        avatar_url=auth_user.avatar_url,
        role=role,
        permissions=permissions,
    )

    await _touch_auth_session(store, session_id)
    conn.state.user = user
    return user
async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency returning the user behind the auth-session cookie.

    Thin REST-side wrapper around the shared resolver. Anonymous callers and
    failed authentications both yield ``None``.
    """
    resolved = await _resolve_user(request)
    return resolved
async def get_current_user_ws(websocket) -> User | None:
    """WebSocket counterpart of ``get_current_user``.

    Delegates to the same cookie-based resolver used for REST requests and
    returns ``None`` for anonymous or unauthenticated connections.
    """
    resolved = await _resolve_user(websocket)
    return resolved
# ── Helpers that talk to the session store ──────────────────────────────────
async def _get_auth_session(store, session_id: str) -> dict | None:
"""Fetch a row from user_auth_sessions."""
pool = await store._get_pool()
async with pool.acquire() as conn:
row = await conn.fetchrow(
"SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
"FROM user_auth_sessions WHERE id = $1",
session_id,
)
if row is None:
return None
return {
"user_id": row["user_id"],
"access_token_enc": row["access_token_enc"],
"refresh_token_enc": row["refresh_token_enc"],
"expires_at": row["expires_at"],
}
async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Store refreshed (already encrypted) tokens and a new expiry for a session.

    Also stamps ``last_used_at`` with the current UTC time.
    """
    statement = (
        "UPDATE user_auth_sessions "
        "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
        "WHERE id = $5"
    )
    pool = await store._get_pool()
    async with pool.acquire() as db:
        stamped_at = datetime.now(timezone.utc)
        await db.execute(
            statement,
            access_token_enc,
            refresh_token_enc,
            expires_at,
            stamped_at,
            session_id,
        )
async def _delete_auth_session(store, session_id: str) -> None:
    """Remove the ``user_auth_sessions`` row with the given id."""
    statement = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        await db.execute(statement, session_id)
async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of the given session to the current UTC time."""
    statement = "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        stamped_at = datetime.now(timezone.utc)
        await db.execute(statement, stamped_at, session_id)
async def _upsert_navi_user(
store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str]
) -> None:
pool = await store._get_pool()
import json
async with pool.acquire() as conn:
await conn.execute(
"""INSERT INTO navi_users (id, email, display_name, role, permissions, created_at, updated_at)
VALUES ($1, $2, $3, $4, $5, $6, $6)
ON CONFLICT (id) DO UPDATE
SET email = EXCLUDED.email,
display_name = EXCLUDED.display_name,
role = EXCLUDED.role,
permissions = EXCLUDED.permissions,
updated_at = EXCLUDED.updated_at""",
user_id, email, display_name or email, role, json.dumps(permissions), datetime.now(timezone.utc),
)
# ── FastAPI Depends helpers ─────────────────────────────────────────────────
async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that rejects anonymous callers.

    Returns:
        The resolved user.

    Raises:
        HTTPException: 401 when no user could be resolved.
    """
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")
async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency that only admits users whose role is ``"admin"``.

    Raises:
        HTTPException: 401 when anonymous, 403 when authenticated but not admin.
    """
    rejection = None
    if user is None:
        rejection = (401, "Authentication required")
    elif user.role != "admin":
        rejection = (403, "Admin access required")

    if rejection is not None:
        status, message = rejection
        raise HTTPException(status_code=status, detail=message)
    return user
def require_permission(permission: str):
    """Build a dependency that requires *permission* (admins always pass).

    The returned coroutine function raises ``HTTPException`` 401 for anonymous
    callers. It raises 403 when the user is not an admin and *permission* is
    not in ``user.permissions``.

    NOTE(review): ``check_session_access`` uses ``user.has_permission`` instead
    of a plain membership test — confirm the two are meant to agree.
    """

    async def _check(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        granted = user.role == "admin" or permission in user.permissions
        if not granted:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _check
def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise 403 unless *user* may access *session*.

    Access is granted when any of the following holds:

    * ``user.role == "admin"``, including for legacy sessions;
    * the session belongs to the user (``session.user_id == user.id``);
    * *permission* is truthy and ``user.has_permission(permission)`` is true.
      This path applies only to sessions that have an owner.

    Legacy sessions (``session.user_id is None``) are therefore accessible
    only to admins.

    Raises:
        HTTPException: status 403 with detail ``"Access denied"`` otherwise.
    """
    # HTTPException comes from the module-level fastapi import.
    if user.role == "admin":
        return
    if session.user_id is not None:
        if session.user_id == user.id:
            return
        if permission and user.has_permission(permission):
            return
    raise HTTPException(status_code=403, detail="Access denied")