"""Unit tests for navi.auth.deps — user resolution, role checks, permission guards."""

from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from navi.auth import User
from navi.auth.deps import (
    _ANONYMOUS_USER,
    _refresh_locks,
    _user_cache,
    check_session_access,
    get_current_user,
    require_admin,
    require_permission,
    require_user,
)
from navi.auth.encrypt import TokenEncryptor
from tests.conftest_factory import FakeConnection, FakePool


class FakeRequest:
    """Lightweight substitute for a FastAPI ``Request``.

    Only two attributes are provided: ``cookies`` (a plain dict, empty when
    none are given) and ``state`` (a MagicMock whose ``user`` starts as None).
    """

    def __init__(self, cookies=None):
        request_state = MagicMock()
        request_state.user = None
        self.cookies = cookies if cookies else {}
        self.state = request_state


class FakeSessionStore:
    """Session-store double whose ``_get_pool()`` coroutine yields a FakePool.

    The pool wraps the (optional) FakeConnection passed to the constructor.
    """

    def __init__(self, conn=None):
        pool = FakePool(conn)
        self._pool = pool

    async def _get_pool(self):
        pool = self._pool
        return pool


class FakeAuthUser:
    """Double for the user object returned by the gnexus-auth client.

    Extra keyword arguments are collected into ``profile``. The user carries
    exactly one client-access entry for client id ``"test-client-id"``, whose
    ``role_ids`` / ``permission_ids`` default to empty lists.
    """

    def __init__(self, user_id, email, role_ids=None, permission_ids=None, **profile):
        self.user_id = user_id
        self.email = email
        self.avatar_url = None
        self.profile = profile
        # Mock(**kwargs) applies the kwargs as attributes via configure_mock.
        client_grant = MagicMock(
            client_id="test-client-id",
            role_ids=role_ids or [],
            permission_ids=permission_ids or [],
        )
        self.client_access_list = [client_grant]


class FakeTokenSet:
    """Double for the token set returned by a refresh call.

    Holds ``access_token`` plus optional ``refresh_token`` and ``expires_at``
    (both None unless supplied).
    """

    def __init__(self, access_token, refresh_token=None, expires_at=None):
        self.access_token, self.refresh_token, self.expires_at = (
            access_token,
            refresh_token,
            expires_at,
        )


# ── Fixtures ────────────────────────────────────────────────────────────────


@pytest.fixture(autouse=True)
def _auth_env():
    """Patch the settings and singletons that navi.auth.deps consults.

    Every test in this module runs with:
      * auth enabled and all gnauth settings populated;
      * a real TokenEncryptor using the same test key as ``_make_conn``;
      * ``asyncio.to_thread`` replaced by an inline call, so the MagicMock
        auth client is invoked directly.

    Yields a dict of the mocks so tests can reconfigure them:
    ``settings``, ``get_client``, ``get_encryptor``, ``get_store``.

    The module-level user cache and refresh-lock registry are cleared before
    AND after each test, so state does not leak between tests or into other
    test modules that share navi.auth.deps.
    """
    with (
        patch("navi.auth.deps.settings") as mock_settings,
        patch("navi.auth.deps.get_gauth_client") as mock_get_client,
        patch("navi.auth.deps.get_encryptor") as mock_get_encryptor,
        # NOTE(review): patched on navi.api.deps rather than navi.auth.deps,
        # presumably because deps looks it up from navi.api.deps at call time.
        patch("navi.api.deps.get_session_store") as mock_get_store,
        # NOTE(review): `navi.auth.deps.asyncio` is presumably the stdlib
        # module, so this replaces asyncio.to_thread process-wide while the
        # test runs.
        patch("navi.auth.deps.asyncio.to_thread", side_effect=lambda f, *a, **k: f(*a, **k)),
    ):
        mock_settings.navi_auth_enabled = True
        mock_settings.gnauth_client_id = "test-client-id"
        mock_settings.gnauth_client_secret = "test-secret"
        mock_settings.navi_auth_cookie_name = "navi_session"
        mock_settings.gnauth_admin_role_slug = "admin"
        mock_settings.gnauth_base_url = "https://auth.test"
        mock_settings.gnauth_redirect_uri = "https://navi.test/callback"

        encryptor = TokenEncryptor(
            "hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE="
        )
        mock_get_encryptor.return_value = encryptor

        # Start each test with no cached users and no per-session locks.
        _user_cache.clear()
        _refresh_locks.clear()

        yield {
            "settings": mock_settings,
            "get_client": mock_get_client,
            "get_encryptor": mock_get_encryptor,
            "get_store": mock_get_store,
        }

        # Teardown: pytest resumes this generator after the test, whether it
        # passed or failed. Drop whatever the test cached so entries, and any
        # locks possibly tied to this test's event loop, don't outlive it.
        _user_cache.clear()
        _refresh_locks.clear()


def _make_conn(user_id="u1", expired=False, extra_rows=0):
    """Return a FakeConnection pre-loaded with copies of one auth-session row.

    The row's tokens ("access-token" / "refresh-token") are encrypted with the
    fixture's test key. Its ``expires_at`` lies one hour in the past when
    ``expired`` is truthy, otherwise one hour in the future.

    Row copies enqueued:
      * one, always;
      * one more for an expired session, because _resolve_user re-reads the
        session inside the refresh lock;
      * ``extra_rows`` more, for tests that resolve the user several times
        (cache, retry).
    """
    conn = FakeConnection()
    encryptor = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=")
    ttl_shift = timedelta(hours=-1) if expired else timedelta(hours=1)
    session_row = {
        "user_id": user_id,
        "access_token_enc": encryptor.encrypt("access-token"),
        "refresh_token_enc": encryptor.encrypt("refresh-token"),
        "expires_at": datetime.now(timezone.utc) + ttl_shift,
    }
    base_copies = 2 if expired else 1
    for _ in range(base_copies + max(extra_rows, 0)):
        conn.enqueue(session_row)
    return conn


# ── get_current_user tests ──────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_unconfigured_returns_none(_auth_env):
    """With ``gnauth_client_id`` unset, get_current_user resolves no user."""
    _auth_env["settings"].gnauth_client_id = None
    assert await get_current_user(FakeRequest()) is None


@pytest.mark.asyncio
async def test_no_cookie_returns_none(_auth_env):
    """A request without a session cookie resolves to no user."""
    cookieless_request = FakeRequest(cookies={})
    assert await get_current_user(cookieless_request) is None


@pytest.mark.asyncio
async def test_invalid_session_returns_none(_auth_env):
    """A cookie whose session id has no stored row resolves to no user."""
    conn = FakeConnection()
    conn.enqueue(None)  # the session lookup (fetchrow) comes back empty
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    resolved = await get_current_user(FakeRequest(cookies={"navi_session": "bad-sess"}))
    assert resolved is None


@pytest.mark.asyncio
async def test_expired_token_refreshes(_auth_env):
    """An expired access token is refreshed and the user is still resolved."""
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", expired=True)
    )

    fresh_tokens = FakeTokenSet(
        access_token="new-access",
        refresh_token="new-refresh",
        expires_at=datetime.now(timezone.utc) + timedelta(hours=1),
    )
    client = MagicMock()
    client.refresh_token.return_value = fresh_tokens
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))
    assert user is not None
    assert user.id == "u1"
    client.refresh_token.assert_called()


@pytest.mark.asyncio
async def test_refresh_failure_returns_none_without_delete(_auth_env):
    """A failing token refresh yields no user but leaves the session row alone.

    The refresh error is treated as transient, so no DELETE may be issued.
    """
    conn = _make_conn(user_id="u1", expired=True)
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    client = MagicMock()
    client.refresh_token.side_effect = Exception("refresh failed")
    _auth_env["get_client"].return_value = client

    assert await get_current_user(FakeRequest(cookies={"navi_session": "sess1"})) is None
    issued_delete = any(
        entry[0] == "execute" and "DELETE" in entry[1] for entry in conn.calls
    )
    assert not issued_delete


@pytest.mark.asyncio
async def test_refresh_skipped_when_already_refreshed_by_parallel_request(_auth_env):
    """Skip the refresh when a concurrent request has already refreshed.

    The first session read is expired. The re-read done after acquiring the
    refresh lock shows a fresh token, so no refresh call may happen.
    """
    enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=")
    conn = FakeConnection()
    # (access token, expiry offset in hours): stale read, then fresh re-read.
    for access_token, hours in (("old-access", -1), ("new-access", 1)):
        conn.enqueue({
            "user_id": "u1",
            "access_token_enc": enc.encrypt(access_token),
            "refresh_token_enc": enc.encrypt("refresh-token"),
            "expires_at": datetime.now(timezone.utc) + timedelta(hours=hours),
        })
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))
    assert user is not None
    assert user.id == "u1"
    client.refresh_token.assert_not_called()


@pytest.mark.asyncio
async def test_fetch_user_failure_returns_none(_auth_env):
    """When every fetch_user call raises, no user is resolved."""
    _auth_env["get_store"].return_value = FakeSessionStore(_make_conn(user_id="u1"))

    failing_client = MagicMock()
    failing_client.fetch_user.side_effect = Exception("fetch failed")
    _auth_env["get_client"].return_value = failing_client

    assert await get_current_user(FakeRequest(cookies={"navi_session": "sess1"})) is None


@pytest.mark.asyncio
async def test_fetch_user_retry_succeeds(_auth_env):
    """A single transient fetch_user failure is retried, and the retry succeeds."""
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", extra_rows=2)
    )

    # First call raises, second returns the user.
    outcomes = [
        Exception("transient fail"),
        FakeAuthUser(user_id="u1", email="u@test.com", display_name="User"),
    ]
    client = MagicMock()
    client.fetch_user.side_effect = outcomes
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))
    assert user is not None
    assert user.id == "u1"
    assert client.fetch_user.call_count == 2


@pytest.mark.asyncio
async def test_user_cache_avoids_second_fetch_user(_auth_env):
    """A second resolution for the same session is served from the user cache."""
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", extra_rows=3)
    )

    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    user = None
    for _ in range(2):
        user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))
        assert user is not None
        # Only the first pass hits the auth service; the second is a cache hit.
        assert client.fetch_user.call_count == 1
    assert user.id == "u1"


@pytest.mark.asyncio
async def test_user_cache_expires(_auth_env):
    """With a negative cache TTL, every resolution calls fetch_user again."""
    with patch("navi.auth.deps._USER_CACHE_TTL", -1):
        _auth_env["get_store"].return_value = FakeSessionStore(
            _make_conn(user_id="u1", extra_rows=3)
        )

        client = MagicMock()
        client.fetch_user.return_value = FakeAuthUser(
            user_id="u1", email="u@test.com", display_name="User"
        )
        _auth_env["get_client"].return_value = client

        # The cached entry from the first pass is already stale on the second.
        for expected_fetches in (1, 2):
            user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))
            assert user is not None
            assert client.fetch_user.call_count == expected_fetches


@pytest.mark.asyncio
async def test_admin_role_detected(_auth_env):
    """A user whose role_ids include the configured admin slug gets role "admin"."""
    conn = _make_conn(user_id="admin1")
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="admin1",
        email="admin@test.com",
        role_ids=["admin"],
    )
    _auth_env["get_client"].return_value = client

    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    # Guard first: a None result would otherwise show up as an opaque
    # AttributeError on `.role` instead of a clear assertion failure.
    assert user is not None
    assert user.id == "admin1"
    assert user.role == "admin"


@pytest.mark.asyncio
async def test_user_role_default(_auth_env):
    """A user with no role ids gets the default role "user"."""
    conn = _make_conn(user_id="u1")
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", role_ids=[]
    )
    _auth_env["get_client"].return_value = client

    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    # Guard first so a failed resolution is reported clearly rather than as
    # an AttributeError on `.role`.
    assert user is not None
    assert user.id == "u1"
    assert user.role == "user"


# ── require_user / require_admin tests ──────────────────────────────────────


@pytest.mark.asyncio
async def test_require_user_raises_401():
    """require_user rejects a missing user with status 401.

    Auth is enabled here by the autouse ``_auth_env`` fixture.
    """
    with pytest.raises(Exception) as exc_info:
        await require_user(None)
    err = exc_info.value
    # getattr keeps an unexpected exception type from surfacing as a
    # confusing AttributeError; the message shows what was actually raised.
    assert getattr(err, "status_code", None) == 401, f"expected HTTP 401, got {err!r}"


@pytest.mark.asyncio
async def test_require_admin_raises_403():
    """require_admin rejects a non-admin user with status 403.

    Auth is enabled here by the autouse ``_auth_env`` fixture.
    """
    user = User(id="u1", email="u@test.com", role="user")
    with pytest.raises(Exception) as exc_info:
        await require_admin(user)
    err = exc_info.value
    # getattr keeps an unexpected exception type from surfacing as a
    # confusing AttributeError; the message shows what was actually raised.
    assert getattr(err, "status_code", None) == 403, f"expected HTTP 403, got {err!r}"


# ── auth-disabled tests ───────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_auth_disabled_returns_anonymous(_auth_env):
    """With auth disabled, every request resolves to the shared anonymous admin."""
    _auth_env["settings"].navi_auth_enabled = False
    resolved = await get_current_user(FakeRequest())
    assert resolved is _ANONYMOUS_USER
    assert (resolved.id, resolved.role) == ("anonymous", "admin")


@pytest.mark.asyncio
async def test_require_user_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, require_user returns the anonymous user instead of raising."""
    _auth_env["settings"].navi_auth_enabled = False
    assert await require_user(None) is _ANONYMOUS_USER


@pytest.mark.asyncio
async def test_require_admin_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, require_admin returns the anonymous user instead of raising."""
    _auth_env["settings"].navi_auth_enabled = False
    assert await require_admin(None) is _ANONYMOUS_USER


@pytest.mark.asyncio
async def test_require_permission_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, a permission guard returns the anonymous user."""
    _auth_env["settings"].navi_auth_enabled = False
    read_all_guard = require_permission("navi.sessions.read_all")
    assert await read_all_guard(None) is _ANONYMOUS_USER


@pytest.mark.asyncio
async def test_check_session_access_allows_everything_when_disabled(_auth_env):
    """With auth disabled, even a non-owner non-admin may access a session."""
    _auth_env["settings"].navi_auth_enabled = False
    outsider = User(id="u1", email="u@test.com", role="user")
    # FakeChatSession is defined further down the module. The name is looked
    # up when the test runs, so the forward reference is fine.
    foreign_session = FakeChatSession(user_id="u2")
    check_session_access(foreign_session, outsider)  # must not raise


# ── check_session_access tests ──────────────────────────────────────────────


class FakeChatSession:
    """Chat-session double that only carries the owning ``user_id`` (may be None)."""

    def __init__(self, user_id):
        self.user_id = user_id


@pytest.mark.asyncio
async def test_check_session_access_owner():
    """The session's owner is granted access."""
    owner = User(id="u1", email="u@test.com")
    own_session = FakeChatSession(user_id="u1")
    check_session_access(own_session, owner)  # must not raise


@pytest.mark.asyncio
async def test_check_session_access_admin_bypass():
    """An admin may access a session owned by someone else."""
    admin = User(id="admin1", email="admin@test.com", role="admin")
    others_session = FakeChatSession(user_id="u2")
    check_session_access(others_session, admin)  # must not raise


@pytest.mark.asyncio
async def test_check_session_access_legacy_anonymous():
    """Sessions with no owner (``user_id`` None) are accessible to admins only.

    NOTE(review): "legacy" presumably refers to sessions created before
    per-user ownership was recorded.
    """
    admin = User(id="admin1", email="admin@test.com", role="admin")
    ownerless_session = FakeChatSession(user_id=None)
    check_session_access(ownerless_session, admin)  # admins may open it

    regular = User(id="u1", email="u@test.com", role="user")
    with pytest.raises(Exception) as exc_info:
        check_session_access(ownerless_session, regular)
    err = exc_info.value
    # getattr keeps an unexpected exception type from surfacing as a
    # confusing AttributeError; the message shows what was actually raised.
    assert getattr(err, "status_code", None) == 403, f"expected HTTP 403, got {err!r}"


@pytest.mark.asyncio
async def test_check_session_access_denied():
    """A non-admin user is refused access to another user's session with 403."""
    user = User(id="u1", email="u@test.com", role="user")
    session = FakeChatSession(user_id="u2")
    with pytest.raises(Exception) as exc_info:
        check_session_access(session, user)
    err = exc_info.value
    # getattr keeps an unexpected exception type from surfacing as a
    # confusing AttributeError; the message shows what was actually raised.
    assert getattr(err, "status_code", None) == 403, f"expected HTTP 403, got {err!r}"
