"""FastAPI dependencies for auth resolution."""

import asyncio
import hashlib
import time
from datetime import datetime, timezone
from typing import Annotated

import structlog
from fastapi import Depends, HTTPException, Request
from gnexus_gauth.exceptions import TokenRefreshException

from navi.config import settings

from . import ApiToken, User
from .client import get_gauth_client
from .encrypt import get_encryptor

log = structlog.get_logger()

# Short-lived in-memory cache of fetch_user results:
#   session_id -> (User, time.monotonic() timestamp of when it was stored).
# Keeps load off gnexus-auth and smooths over eventual-consistency flakes
# right after a token refresh.
_USER_CACHE_TTL: float = 30.0
_user_cache: dict[str, tuple[User, float]] = {}

# One asyncio.Lock per session id so concurrent requests do not race each
# other when refreshing the same session's tokens.
_refresh_locks: dict[str, asyncio.Lock] = {}


def _cleanup_refresh_locks() -> None:
    """Remove stale refresh locks to prevent unbounded memory growth.

    A lock is discarded when its session has no cached user or the cached
    entry is older than ``_USER_CACHE_TTL`` -- unless the lock is currently
    held. Dropping a held lock would let a concurrent request create a fresh
    lock for the same session and start a second, racing token refresh.
    """
    # Cache timestamps are written with time.monotonic() (see
    # _set_cached_user), so the comparison must use the same clock.
    now = time.monotonic()
    stale = [
        sid
        for sid, lock in _refresh_locks.items()
        if not lock.locked()
        and (sid not in _user_cache or _user_cache[sid][1] + _USER_CACHE_TTL < now)
    ]
    for sid in stale:
        _refresh_locks.pop(sid, None)


def _get_cached_user(session_id: str) -> User | None:
    """Look up *session_id* in the short-lived user cache.

    Returns the cached ``User`` while its entry is at most ``_USER_CACHE_TTL``
    seconds old (measured with ``time.monotonic()``). An expired entry is
    evicted. On any miss, stale refresh locks are swept as well.
    """
    cached = _user_cache.get(session_id)
    if cached is not None:
        user, stored_at = cached
        if time.monotonic() - stored_at <= _USER_CACHE_TTL:
            return user
        # Entry outlived its TTL: evict it before sweeping the locks.
        _user_cache.pop(session_id, None)
    _cleanup_refresh_locks()
    return None


def _set_cached_user(session_id: str, user: User) -> None:
    """Remember *user* for *session_id*, stamped with the current monotonic time."""
    stored_at = time.monotonic()
    _user_cache[session_id] = (user, stored_at)


async def _fetch_user_with_retry(
    client, access_token: str, session_id: str, *, retry_delay: float = 1.5
) -> "User | None":
    """Call ``client.fetch_user`` off the event loop, retrying once on failure.

    ``client.fetch_user`` is synchronous, so it runs in a worker thread via
    ``asyncio.to_thread``. Any exception on the first attempt is treated as
    transient: it is logged at debug level and the call is retried after
    *retry_delay* seconds.

    Args:
        client: Object exposing a synchronous ``fetch_user(access_token)``.
        access_token: Plaintext access token passed to ``fetch_user``.
        session_id: Auth session id; only its first 8 chars are logged.
        retry_delay: Seconds to wait before the second attempt.

    Returns:
        Whatever ``fetch_user`` returned, or ``None`` if both attempts raised.
        The final failure is logged with its traceback, because callers treat
        ``None`` as "anonymous" and would otherwise hide the cause.
    """
    attempts = 2
    for attempt in range(1, attempts + 1):
        try:
            return await asyncio.to_thread(client.fetch_user, access_token)
        except Exception as exc:
            if attempt < attempts:
                log.debug(
                    "auth.fetch_user_retry",
                    session_id=session_id[:8],
                    attempt=attempt,
                    exc_type=type(exc).__name__,
                )
                await asyncio.sleep(retry_delay)
            else:
                log.warning(
                    "auth.fetch_user_failed",
                    session_id=session_id[:8],
                    exc_info=True,
                )
    return None


async def _api_token_fallback(conn) -> User | None:
    """Resolve *conn* via API token, memoizing a hit on ``conn.state.user``.

    Cookie-based resolution uses this whenever it cannot produce a user, so a
    request that also carries an ``X-Api-Token`` header or ``api_token`` query
    parameter is still authenticated.
    """
    api_user = await _resolve_user_from_api_token(conn)
    if api_user is not None:
        conn.state.user = api_user
    return api_user


def _role_and_permissions(auth_user) -> tuple[str, list[str]]:
    """Return ``(role, permissions)`` from this client's access entry.

    Only the first entry in ``auth_user.client_access_list`` whose
    ``client_id`` equals ``settings.gnauth_client_id`` is used. The role is
    ``"admin"`` when that entry's ``role_ids`` contain
    ``settings.gnauth_admin_role_slug``. Without a matching entry the result
    is ``("user", [])``.
    """
    for access in auth_user.client_access_list:
        if access.client_id != settings.gnauth_client_id:
            continue
        is_admin = settings.gnauth_admin_role_slug in (access.role_ids or [])
        return ("admin" if is_admin else "user"), list(access.permission_ids or [])
    return "user", []


def _profile_kwargs(profile) -> dict:
    """Pick the optional profile attributes shared by the DB upsert and ``User``."""
    keys = ("username", "first_name", "last_name", "phone",
            "birth_date", "country", "city", "locale")
    return {key: profile.get(key) for key in keys}


async def _resolve_user(conn) -> User | None:
    """Shared logic to resolve the user for a connection (Request or WebSocket).

    Resolution order:

    1. A user already stored on ``conn.state.user`` for this connection.
    2. The auth-session cookie (``settings.navi_auth_cookie_name``). The
       session row is loaded, its access token is refreshed under a
       per-session lock if expired, and the gnexus-auth user is taken from
       the short-lived cache or fetched, then upserted into ``navi_users``.
    3. An API token (``X-Api-Token`` header / ``api_token`` query param) when
       there is no cookie, the session row is missing, the refresh token is
       rejected, a transient refresh failure has no cached user to fall back
       on, or ``fetch_user`` fails.

    Returns the resolved ``User`` (also stored on ``conn.state.user``) or
    ``None`` for anonymous. Any unexpected failure is logged and treated as
    anonymous so WebSocket upgrades and other dependency-injected paths
    degrade gracefully instead of crashing.
    """
    try:
        log.debug("auth.resolve_start", conn_type=type(conn).__name__)
        if conn is None:
            log.debug("auth.resolve_no_conn")
            return None

        # Already resolved earlier for this request/connection.
        if getattr(conn.state, "user", None) is not None:
            log.debug("auth.resolve_cached", user_id=conn.state.user.id)
            return conn.state.user

        # Auth not configured — treat as anonymous.
        if not settings.gnauth_client_id or not settings.gnauth_client_secret:
            log.debug("auth.resolve_not_configured")
            return None

        cookie_name = settings.navi_auth_cookie_name
        session_id = conn.cookies.get(cookie_name)
        if not session_id:
            log.debug("auth.resolve_no_cookie", cookie_name=cookie_name)
            return await _api_token_fallback(conn)

        log.debug("auth.resolve_cookie_found", session_id=session_id[:8])

        try:
            from navi.api.deps import get_session_store
        except Exception:
            # Avoid circular import during early bootstrap.
            log.debug("auth.resolve_store_import_failed")
            return None

        store = get_session_store()
        row = await _get_auth_session(store, session_id)
        if row is None:
            log.debug("auth.resolve_session_not_found", session_id=session_id[:8])
            # Unknown/stale cookie: still honour an API token sent alongside
            # it, exactly as the no-cookie path does.
            return await _api_token_fallback(conn)

        log.debug("auth.resolve_session_found", user_id=row["user_id"])

        encryptor = get_encryptor()
        access_token = encryptor.decrypt(row["access_token_enc"])
        expires_at = row["expires_at"]
        client = get_gauth_client()

        # NOTE(review): assumes expires_at is timezone-aware; a naive value
        # would raise TypeError here and end in the anonymous handler below —
        # confirm the column type.
        if datetime.now(timezone.utc) > expires_at:
            # Serialize refreshes per session id.
            lock = _refresh_locks.setdefault(session_id, asyncio.Lock())
            async with lock:
                # Re-read inside the lock — another request may have refreshed it already.
                row = await _get_auth_session(store, session_id)
                if row is None:
                    log.debug("auth.resolve_session_gone_during_refresh", session_id=session_id[:8])
                    return await _api_token_fallback(conn)
                expires_at = row["expires_at"]
                if datetime.now(timezone.utc) <= expires_at:
                    access_token = encryptor.decrypt(row["access_token_enc"])
                else:
                    try:
                        refresh_token = encryptor.decrypt(row["refresh_token_enc"])
                        token_set = await asyncio.to_thread(client.refresh_token, refresh_token)
                        access_token = token_set.access_token
                        await _update_auth_session(
                            store,
                            session_id,
                            encryptor.encrypt(access_token),
                            # Keep the old refresh token if the server did not rotate it.
                            encryptor.encrypt(token_set.refresh_token or refresh_token),
                            token_set.expires_at or datetime.now(timezone.utc),
                        )
                        log.info("auth.token_refreshed", user_id=row["user_id"])
                    except TokenRefreshException as exc:
                        # Refresh token is definitively invalid (expired, revoked,
                        # rotated by another device). Force re-login unless an
                        # API token authenticates the request.
                        log.warning(
                            "auth.refresh_token_invalid",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            reason=str(exc),
                        )
                        return await _api_token_fallback(conn)
                    except Exception as exc:
                        # Transient errors (network, timeout, 5xx from auth server).
                        # Do NOT delete the session — force-logout is too harsh.
                        # Best-effort: serve from cache if available, else try API token.
                        log.warning(
                            "auth.refresh_transient_fail",
                            session_id=session_id[:8],
                            user_id=row["user_id"],
                            exc_type=type(exc).__name__,
                            error=str(exc),
                        )
                        cached = _get_cached_user(session_id)
                        if cached is not None:
                            log.debug("auth.fallback_to_cached_user", user_id=cached.id)
                            conn.state.user = cached
                            return cached
                        return await _api_token_fallback(conn)

        # Short-lived fetch_user cache before hitting gnexus-auth.
        cached = _get_cached_user(session_id)
        if cached is not None:
            log.debug("auth.resolve_user_cache_hit", user_id=cached.id)
            conn.state.user = cached
            return cached

        # Fetch user from gnexus-auth (with one retry on transient failure).
        auth_user = await _fetch_user_with_retry(client, access_token, session_id)
        if auth_user is None:
            # fetch_user failed on both attempts: try the API token instead.
            return await _api_token_fallback(conn)

        log.debug("auth.resolve_user_fetched", user_id=auth_user.user_id)

        role, permissions = _role_and_permissions(auth_user)
        display_name = auth_user.profile.get("display_name")
        profile_kwargs = _profile_kwargs(auth_user.profile)

        await _upsert_navi_user(
            store, auth_user.user_id, auth_user.email,
            display_name, role, permissions,
            **profile_kwargs,
        )

        user = User(
            id=auth_user.user_id,
            email=auth_user.email,
            display_name=display_name or auth_user.email,
            avatar_url=auth_user.avatar_url,
            role=role,
            permissions=permissions,
            **profile_kwargs,
        )

        # Update last_used_at and cache.
        await _touch_auth_session(store, session_id)
        _set_cached_user(session_id, user)
        conn.state.user = user
        log.debug("auth.resolve_success", user_id=user.id)
        return user
    except Exception:
        # Any unexpected failure during auth resolution should not crash the
        # request — treat as anonymous so WebSocket upgrades and REST calls
        # degrade gracefully.
        log.warning("auth.resolve_failed", exc_info=True)
        return None


# ── API token resolution helpers ────────────────────────────────────────────

async def _resolve_user_from_api_token(conn) -> User | None:
    """Resolve a user from an API token attached to *conn*.

    The token is read from the ``X-Api-Token`` header, or else from the
    ``api_token`` query parameter (for WebSocket or REST clients). Its SHA-256
    hex digest is matched against ``api_tokens.token_hash`` joined with
    ``navi_users``.

    Returns ``None`` in any of these cases:

    - no token was sent;
    - the session store cannot be imported;
    - the digest is unknown;
    - the token is revoked.

    On success, ``api_tokens.last_used_at`` is updated best-effort (awaited
    inline; failures are logged, not raised). A ``User`` built from the
    ``navi_users`` row is then returned; ``avatar_url`` is not passed on this
    path.
    """
    import json

    # Header first, then query param. Connection objects without these
    # attributes simply yield no token.
    token = None
    try:
        token = conn.headers.get("X-Api-Token")
    except Exception:
        pass
    if not token:
        # NOTE(review): tokens in query strings can end up in access logs and
        # proxies; prefer the header wherever the client allows it.
        try:
            token = conn.query_params.get("api_token")
        except Exception:
            pass
    if not token:
        return None

    try:
        from navi.api.deps import get_session_store
    except Exception:
        return None

    store = get_session_store()
    pool = await store._get_pool()

    # The table is queried by SHA-256 digest, never by the raw token.
    token_hash = hashlib.sha256(token.encode()).hexdigest()
    async with pool.acquire() as db_conn:
        user_row = await db_conn.fetchrow(
            "SELECT u.id, u.email, u.display_name, u.username, u.first_name, u.last_name, "
            "u.phone, u.birth_date, u.country, u.city, u.locale, u.role, u.permissions, "
            "t.id AS token_id, t.revoked_at "
            "FROM api_tokens t JOIN navi_users u ON u.id = t.user_id "
            "WHERE t.token_hash = $1",
            token_hash,
        )
    if user_row is None:
        log.debug("auth.api_token_not_found")
        return None
    if user_row["revoked_at"] is not None:
        log.debug("auth.api_token_revoked", token_id=user_row["token_id"])
        return None

    # _upsert_navi_user stores permissions via json.dumps. Accept the text
    # form, and also an already-decoded value in case the driver has a JSON
    # codec registered.
    raw_permissions = user_row["permissions"]
    if raw_permissions is None or isinstance(raw_permissions, (str, bytes, bytearray)):
        permissions = json.loads(raw_permissions or "[]")
    else:
        permissions = list(raw_permissions)

    user = User(
        id=user_row["id"],
        email=user_row["email"],
        display_name=user_row["display_name"] or user_row["email"],
        username=user_row["username"],
        first_name=user_row["first_name"],
        last_name=user_row["last_name"],
        phone=user_row["phone"],
        birth_date=user_row["birth_date"],
        country=user_row["country"],
        city=user_row["city"],
        locale=user_row["locale"],
        role=user_row["role"],
        permissions=permissions,
    )
    # Update last_used_at (best-effort bookkeeping, awaited inline).
    try:
        async with pool.acquire() as db_conn:
            await db_conn.execute(
                "UPDATE api_tokens SET last_used_at = $1 WHERE id = $2",
                datetime.now(timezone.utc), user_row["token_id"],
            )
    except Exception:
        # Never fail authentication over bookkeeping, but leave a trace.
        log.debug("auth.api_token_touch_failed", token_id=user_row["token_id"], exc_info=True)
    log.debug("auth.api_token_resolved", user_id=user.id)
    return user


async def get_current_user(request: Request) -> User | None:
    """FastAPI dependency: the user behind a REST *request*, or ``None`` if anonymous.

    Thin wrapper over ``_resolve_user``, which reads the auth session cookie
    and tries an API token when no cookie is present.
    """
    user = await _resolve_user(request)
    return user


async def get_current_user_ws(websocket) -> User | None:
    """Resolve the user of a WebSocket connection, or ``None`` if anonymous.

    Same resolution as the REST dependency: delegates to ``_resolve_user``
    with the WebSocket object as the connection.
    """
    user = await _resolve_user(websocket)
    return user


# ── Helpers that talk to the session store ──────────────────────────────────

async def _get_auth_session(store, session_id: str) -> dict | None:
    """Fetch a row from user_auth_sessions."""
    pool = await store._get_pool()
    async with pool.acquire() as conn:
        row = await conn.fetchrow(
            "SELECT user_id, access_token_enc, refresh_token_enc, expires_at "
            "FROM user_auth_sessions WHERE id = $1",
            session_id,
        )
    if row is None:
        return None
    return {
        "user_id": row["user_id"],
        "access_token_enc": row["access_token_enc"],
        "refresh_token_enc": row["refresh_token_enc"],
        "expires_at": row["expires_at"],
    }


async def _update_auth_session(
    store, session_id: str, access_token_enc: str, refresh_token_enc: str, expires_at: datetime
) -> None:
    """Persist refreshed, already-encrypted tokens and their expiry for *session_id*.

    Also sets ``last_used_at`` to the current UTC time.
    """
    pool = await store._get_pool()
    async with pool.acquire() as db:
        touched_at = datetime.now(timezone.utc)
        await db.execute(
            "UPDATE user_auth_sessions "
            "SET access_token_enc = $1, refresh_token_enc = $2, expires_at = $3, last_used_at = $4 "
            "WHERE id = $5",
            access_token_enc, refresh_token_enc, expires_at, touched_at, session_id,
        )


async def _delete_auth_session(store, session_id: str) -> None:
    """Delete the ``user_auth_sessions`` row identified by *session_id*."""
    statement = "DELETE FROM user_auth_sessions WHERE id = $1"
    pool = await store._get_pool()
    async with pool.acquire() as db:
        await db.execute(statement, session_id)


async def _touch_auth_session(store, session_id: str) -> None:
    """Set ``last_used_at`` of session *session_id* to the current UTC time."""
    pool = await store._get_pool()
    async with pool.acquire() as db:
        touched_at = datetime.now(timezone.utc)
        await db.execute(
            "UPDATE user_auth_sessions SET last_used_at = $1 WHERE id = $2",
            touched_at, session_id,
        )


async def _upsert_navi_user(
    store, user_id: str, email: str, display_name: str | None, role: str, permissions: list[str],
    username: str | None = None, first_name: str | None = None, last_name: str | None = None,
    phone: str | None = None, birth_date: str | None = None, country: str | None = None,
    city: str | None = None, locale: str | None = None,
) -> None:
    """Insert or update the ``navi_users`` row for *user_id*.

    On conflict, every profile column plus ``role`` and ``permissions`` are
    overwritten. ``permissions`` is stored as a JSON string, and
    ``display_name`` falls back to *email*. ``created_at`` is written only on
    insert; ``updated_at`` is written on every call. Both use the current UTC
    time.
    """
    import json

    pool = await store._get_pool()
    async with pool.acquire() as db:
        permissions_json = json.dumps(permissions)
        now = datetime.now(timezone.utc)
        await db.execute(
            """INSERT INTO navi_users (
                   id, email, display_name, username, first_name, last_name,
                   phone, birth_date, country, city, locale,
                   role, permissions, created_at, updated_at
               )
               VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $14)
               ON CONFLICT (id) DO UPDATE
               SET email = EXCLUDED.email,
                   display_name = EXCLUDED.display_name,
                   username = EXCLUDED.username,
                   first_name = EXCLUDED.first_name,
                   last_name = EXCLUDED.last_name,
                   phone = EXCLUDED.phone,
                   birth_date = EXCLUDED.birth_date,
                   country = EXCLUDED.country,
                   city = EXCLUDED.city,
                   locale = EXCLUDED.locale,
                   role = EXCLUDED.role,
                   permissions = EXCLUDED.permissions,
                   updated_at = EXCLUDED.updated_at""",
            user_id, email, display_name or email, username, first_name, last_name,
            phone, birth_date, country, city, locale,
            role, permissions_json, now,
        )


# ── FastAPI Depends helpers ─────────────────────────────────────────────────

async def require_user(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency yielding the authenticated user; rejects anonymous callers with 401."""
    if user is not None:
        return user
    raise HTTPException(status_code=401, detail="Authentication required")


async def require_admin(user: Annotated[User | None, Depends(get_current_user)]) -> User:
    """Dependency yielding an admin user.

    Raises 401 for anonymous callers and 403 when ``user.role`` is not ``"admin"``.
    """
    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")
    if user.role == "admin":
        return user
    raise HTTPException(status_code=403, detail="Admin access required")


def require_permission(permission: str):
    """Build a dependency that requires *permission*; admins always pass.

    The returned coroutine function raises 401 when there is no user and 403
    when a non-admin user's ``permissions`` do not contain *permission*.
    """
    # NOTE(review): check_session_access uses user.has_permission(); confirm
    # whether this plain membership test is meant to match it.
    async def _permission_guard(
        user: Annotated[User | None, Depends(get_current_user)],
    ) -> User:
        if user is None:
            raise HTTPException(status_code=401, detail="Authentication required")
        allowed = user.role == "admin" or permission in user.permissions
        if not allowed:
            raise HTTPException(status_code=403, detail=f"Permission '{permission}' required")
        return user

    return _permission_guard


def check_session_access(session, user: User, permission: str | None = None) -> None:
    """Raise 403 unless *user* may access *session*.

    Access is granted to the session owner, to admins, and — when
    *permission* is given — to users for whom ``user.has_permission`` returns
    true. Legacy sessions (``session.user_id`` is ``None``) are accessible
    only to admins.
    """
    owner_id = session.user_id
    if owner_id is None:
        # Legacy session without an owner: admins only.
        if user.role == "admin":
            return
        raise HTTPException(status_code=403, detail="Access denied")
    if owner_id == user.id or user.role == "admin":
        return
    if permission and user.has_permission(permission):
        return
    raise HTTPException(status_code=403, detail="Access denied")
