"""Unit tests for navi.auth.deps — user resolution, role checks, permission guards."""
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from navi.auth import User
from navi.auth.deps import (
_ANONYMOUS_USER,
_refresh_locks,
_user_cache,
check_session_access,
get_current_user,
require_admin,
require_permission,
require_user,
)
from navi.auth.encrypt import TokenEncryptor
from tests.conftest_factory import FakeConnection, FakePool
class FakeRequest:
    """Bare-bones substitute for a FastAPI ``Request``.

    Exposes only a ``cookies`` mapping (empty when none is given) and a
    ``state`` namespace whose ``user`` attribute starts out as ``None``.
    """

    def __init__(self, cookies=None):
        self.cookies = {} if not cookies else cookies
        state = MagicMock()
        state.user = None
        self.state = state
class FakeSessionStore:
    """Session-store double whose ``_get_pool()`` hands back a FakePool.

    The pool is built once at construction time around *conn*, which is
    passed through unchanged (including ``None``).
    """

    def __init__(self, conn=None):
        pool = FakePool(conn)
        self._pool = pool

    async def _get_pool(self):
        """Return the pool created in ``__init__``."""
        return self._pool
class FakeAuthUser:
    """Test double for a gnexus-auth user object.

    Holds the identity fields plus exactly one client-access entry for
    ``"test-client-id"`` carrying the given role and permission ids (empty
    lists when omitted). Any extra keyword arguments become ``profile``.
    """

    def __init__(self, user_id, email, role_ids=None, permission_ids=None, **profile):
        self.user_id, self.email = user_id, email
        self.avatar_url = None
        self.profile = profile

        entry = MagicMock()
        entry.client_id = "test-client-id"
        entry.role_ids = role_ids if role_ids else []
        entry.permission_ids = permission_ids if permission_ids else []
        self.client_access_list = [entry]
class FakeTokenSet:
    """Plain token-set double: an access token plus optional refresh token and expiry."""

    def __init__(self, access_token, refresh_token=None, expires_at=None):
        self.access_token, self.refresh_token, self.expires_at = (
            access_token,
            refresh_token,
            expires_at,
        )
# ── Fixtures ────────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _auth_env():
    """Patch settings and singletons used by auth deps.

    Applied to every test in this module. It provides:
    - a ``settings`` mock with auth enabled and gnexus-auth configured;
    - a real ``TokenEncryptor`` behind ``get_encryptor``;
    - an ``asyncio.to_thread`` that calls its target inline;
    - empty ``_user_cache`` / ``_refresh_locks``.

    Yields:
        dict with keys ``settings`` (the settings mock) and ``get_client``,
        ``get_encryptor``, ``get_store`` (the patched factories), so tests can
        plug in fake auth clients and session stores.
    """
    with (
        patch("navi.auth.deps.settings") as mock_settings,
        patch("navi.auth.deps.get_gauth_client") as mock_get_client,
        patch("navi.auth.deps.get_encryptor") as mock_get_encryptor,
        patch("navi.api.deps.get_session_store") as mock_get_store,
        # Call whatever is handed to to_thread directly, in the current thread.
        patch(
            "navi.auth.deps.asyncio.to_thread",
            side_effect=lambda f, *a, **k: f(*a, **k),
        ),
    ):
        mock_settings.navi_auth_enabled = True
        mock_settings.gnauth_client_id = "test-client-id"
        mock_settings.gnauth_client_secret = "test-secret"
        mock_settings.navi_auth_cookie_name = "navi_session"
        mock_settings.gnauth_admin_role_slug = "admin"
        mock_settings.gnauth_base_url = "https://auth.test"
        mock_settings.gnauth_redirect_uri = "https://navi.test/callback"
        encryptor = TokenEncryptor(
            "hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE="
        )
        mock_get_encryptor.return_value = encryptor
        _user_cache.clear()
        _refresh_locks.clear()
        yield {
            "settings": mock_settings,
            "get_client": mock_get_client,
            "get_encryptor": mock_get_encryptor,
            "get_store": mock_get_store,
        }
        # Both caches are module-level objects in navi.auth.deps. Entries
        # created by this test would otherwise outlive it and leak into test
        # modules that run later, so reset them on teardown as well.
        _user_cache.clear()
        _refresh_locks.clear()
def _make_conn(user_id="u1", expired=False, extra_rows=0):
    """Return a FakeConnection pre-loaded with auth-session rows for *user_id*.

    One row is always queued. An expired session (expiry one hour in the
    past) gets a second identical row, because _resolve_user re-reads the
    session inside the refresh lock. A non-expired session expires one hour
    from now. ``extra_rows`` queues that many further copies for tests that
    resolve the user more than once (cache, retry).
    """
    enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=")
    now = datetime.now(timezone.utc)
    one_hour = timedelta(hours=1)
    session_row = {
        "user_id": user_id,
        "access_token_enc": enc.encrypt("access-token"),
        "refresh_token_enc": enc.encrypt("refresh-token"),
        "expires_at": now - one_hour if expired else now + one_hour,
    }

    conn = FakeConnection()
    base_copies = 2 if expired else 1
    for _ in range(base_copies):
        conn.enqueue(session_row)
    for _ in range(extra_rows):
        conn.enqueue(session_row)
    return conn
# ── get_current_user tests ──────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_unconfigured_returns_none(_auth_env):
    """Without a gnexus-auth client id, get_current_user yields None."""
    settings = _auth_env["settings"]
    settings.gnauth_client_id = None
    result = await get_current_user(FakeRequest())
    assert result is None
@pytest.mark.asyncio
async def test_no_cookie_returns_none(_auth_env):
    """A request carrying no session cookie resolves to no user."""
    request = FakeRequest()
    assert await get_current_user(request) is None
@pytest.mark.asyncio
async def test_invalid_session_returns_none(_auth_env):
    """A session id with no matching row resolves to no user."""
    conn = FakeConnection()
    conn.enqueue(None)  # fetchrow returns None
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    request = FakeRequest(cookies={"navi_session": "bad-sess"})
    assert await get_current_user(request) is None
@pytest.mark.asyncio
async def test_expired_token_refreshes(_auth_env):
    """An expired access token is refreshed and the user still resolves."""
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", expired=True)
    )

    new_expiry = datetime.now(timezone.utc) + timedelta(hours=1)
    client = MagicMock()
    client.refresh_token.return_value = FakeTokenSet(
        access_token="new-access",
        refresh_token="new-refresh",
        expires_at=new_expiry,
    )
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))

    assert user is not None
    assert user.id == "u1"
    # The expired session must have triggered a refresh.
    client.refresh_token.assert_called()
@pytest.mark.asyncio
async def test_refresh_failure_returns_none_without_delete(_auth_env):
    """A failing token refresh yields None but keeps the session row.

    The failure may be transient, so the session must survive for a later
    attempt instead of being deleted.
    """
    conn = _make_conn(user_id="u1", expired=True)
    _auth_env["get_store"].return_value = FakeSessionStore(conn)
    client = MagicMock()
    client.refresh_token.side_effect = Exception("refresh failed")
    _auth_env["get_client"].return_value = client
    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    assert user is None
    # Guard against a vacuous pass: the None must come from the failed
    # refresh, not from bailing out before a refresh was even attempted.
    assert client.refresh_token.called
    # Session must NOT be deleted on transient refresh failure
    delete_calls = [c for c in conn.calls if c[0] == "execute" and "DELETE" in c[1]]
    assert not delete_calls
@pytest.mark.asyncio
async def test_refresh_skipped_when_already_refreshed_by_parallel_request(_auth_env):
    """Reuse a refresh done by a concurrent request while we awaited the lock.

    The first fetchrow shows an expired session. The re-read performed after
    acquiring the lock shows an un-expired one, so no refresh call is made.
    """
    enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=")
    now = datetime.now(timezone.utc)
    hour = timedelta(hours=1)

    def session_row(access_token, expires_at):
        return {
            "user_id": "u1",
            "access_token_enc": enc.encrypt(access_token),
            "refresh_token_enc": enc.encrypt("refresh-token"),
            "expires_at": expires_at,
        }

    conn = FakeConnection()
    conn.enqueue(session_row("old-access", now - hour))  # initial read: expired
    conn.enqueue(session_row("new-access", now + hour))  # re-read: already fresh
    _auth_env["get_store"].return_value = FakeSessionStore(conn)

    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))

    assert user is not None
    assert user.id == "u1"
    # No refresh: the second fetchrow showed an un-expired session.
    assert not client.refresh_token.called
@pytest.mark.asyncio
async def test_fetch_user_failure_returns_none(_auth_env):
    """If fetch_user keeps failing, get_current_user resolves to None."""
    conn = _make_conn(user_id="u1")
    _auth_env["get_store"].return_value = FakeSessionStore(conn)
    client = MagicMock()
    client.fetch_user.side_effect = Exception("fetch failed")
    _auth_env["get_client"].return_value = client
    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    assert user is None
    # Guard against a vacuous pass: the None must come from the failing
    # fetch_user, not from an earlier exit before the user was fetched.
    assert client.fetch_user.called
@pytest.mark.asyncio
async def test_fetch_user_retry_succeeds(_auth_env):
    """A single transient fetch_user failure is retried and then succeeds."""
    good_user = FakeAuthUser(user_id="u1", email="u@test.com", display_name="User")
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", extra_rows=2)
    )

    client = MagicMock()
    client.fetch_user.side_effect = [Exception("transient fail"), good_user]
    _auth_env["get_client"].return_value = client

    user = await get_current_user(FakeRequest(cookies={"navi_session": "sess1"}))

    assert user is not None
    assert user.id == "u1"
    assert client.fetch_user.call_count == 2
@pytest.mark.asyncio
async def test_user_cache_avoids_second_fetch_user(_auth_env):
    """A repeat lookup for the same session is served from the user cache."""
    _auth_env["get_store"].return_value = FakeSessionStore(
        _make_conn(user_id="u1", extra_rows=3)
    )
    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", display_name="User"
    )
    _auth_env["get_client"].return_value = client

    def new_request():
        return FakeRequest(cookies={"navi_session": "sess1"})

    first = await get_current_user(new_request())
    assert first is not None
    assert client.fetch_user.call_count == 1

    second = await get_current_user(new_request())
    assert second is not None
    assert second.id == "u1"
    assert client.fetch_user.call_count == 1  # cache hit
@pytest.mark.asyncio
async def test_user_cache_expires(_auth_env):
    """With a negative cache TTL, every lookup goes back to fetch_user."""
    with patch("navi.auth.deps._USER_CACHE_TTL", -1):
        _auth_env["get_store"].return_value = FakeSessionStore(
            _make_conn(user_id="u1", extra_rows=3)
        )
        client = MagicMock()
        client.fetch_user.return_value = FakeAuthUser(
            user_id="u1", email="u@test.com", display_name="User"
        )
        _auth_env["get_client"].return_value = client

        # The first call fetches once. The cached entry is already stale by
        # the second call, so fetch_user runs again.
        for expected_fetches in (1, 2):
            user = await get_current_user(
                FakeRequest(cookies={"navi_session": "sess1"})
            )
            assert user is not None
            assert client.fetch_user.call_count == expected_fetches
@pytest.mark.asyncio
async def test_admin_role_detected(_auth_env):
    """A user whose role_ids contain the admin slug ("admin") gets role "admin"."""
    conn = _make_conn(user_id="admin1")
    _auth_env["get_store"].return_value = FakeSessionStore(conn)
    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="admin1",
        email="admin@test.com",
        role_ids=["admin"],
    )
    _auth_env["get_client"].return_value = client
    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    # Fail with a clear assertion rather than an AttributeError on None.
    assert user is not None
    assert user.role == "admin"
@pytest.mark.asyncio
async def test_user_role_default(_auth_env):
    """A user with no role ids resolves with the plain "user" role."""
    conn = _make_conn(user_id="u1")
    _auth_env["get_store"].return_value = FakeSessionStore(conn)
    client = MagicMock()
    client.fetch_user.return_value = FakeAuthUser(
        user_id="u1", email="u@test.com", role_ids=[]
    )
    _auth_env["get_client"].return_value = client
    req = FakeRequest(cookies={"navi_session": "sess1"})
    user = await get_current_user(req)
    # Fail with a clear assertion rather than an AttributeError on None.
    assert user is not None
    assert user.role == "user"
# ── WebSocket auth-disabled test ────────────────────────────────────────────
class FakeWebSocket:
    """Bare-bones substitute for a FastAPI ``WebSocket``.

    Exposes only a ``cookies`` mapping (empty when none is given) and a
    ``state`` namespace whose ``user`` attribute starts out as ``None``.
    """

    def __init__(self, cookies=None):
        self.cookies = {} if not cookies else cookies
        state = MagicMock()
        state.user = None
        self.state = state
@pytest.mark.asyncio
async def test_get_current_user_ws_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, the WebSocket resolver yields the anonymous admin."""
    from navi.auth.deps import get_current_user_ws

    _auth_env["settings"].navi_auth_enabled = False
    socket = FakeWebSocket()
    resolved = await get_current_user_ws(socket)
    assert resolved.id == "anonymous"
    assert resolved.role == "admin"
# ── Shared fake used by check_session_access tests ───────────────────────────
class FakeChatSession:
    """Chat-session double exposing only the owning ``user_id`` (may be None)."""

    def __init__(self, user_id):
        self.user_id = user_id
# ── require_user / require_admin tests ──────────────────────────────────────
@pytest.mark.asyncio
async def test_require_user_raises_401():
    """With auth enabled (autouse fixture), a missing user is rejected with 401."""
    with pytest.raises(Exception) as caught:
        await require_user(None)
    assert caught.value.status_code == 401
@pytest.mark.asyncio
async def test_require_admin_raises_403():
    """A non-admin user is rejected by require_admin with 403."""
    regular = User(id="u1", email="u@test.com", role="user")
    with pytest.raises(Exception) as caught:
        await require_admin(regular)
    assert caught.value.status_code == 403
# ── auth-disabled tests ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_auth_disabled_returns_anonymous(_auth_env):
    """With auth disabled, get_current_user yields the anonymous admin."""
    _auth_env["settings"].navi_auth_enabled = False
    resolved = await get_current_user(FakeRequest())
    assert resolved.id == "anonymous"
    assert resolved.role == "admin"
@pytest.mark.asyncio
async def test_require_user_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, require_user accepts None and returns the anonymous admin."""
    _auth_env["settings"].navi_auth_enabled = False
    resolved = await require_user(None)
    assert resolved.id == "anonymous"
    assert resolved.role == "admin"
@pytest.mark.asyncio
async def test_require_admin_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, require_admin accepts None and returns the anonymous admin."""
    _auth_env["settings"].navi_auth_enabled = False
    resolved = await require_admin(None)
    assert resolved.id == "anonymous"
    assert resolved.role == "admin"
@pytest.mark.asyncio
async def test_require_permission_returns_anonymous_when_disabled(_auth_env):
    """With auth disabled, a permission guard lets None through as the anonymous admin."""
    _auth_env["settings"].navi_auth_enabled = False
    check = require_permission("navi.sessions.read_all")
    resolved = await check(None)
    assert resolved.id == "anonymous"
    assert resolved.role == "admin"
def test_check_session_access_allows_everything_when_disabled(_auth_env):
    """With auth disabled, even a non-owner non-admin is granted access.

    check_session_access is called synchronously (no await), so this is a plain
    test rather than an asyncio coroutine.
    """
    _auth_env["settings"].navi_auth_enabled = False
    user = User(id="u1", email="u@test.com", role="user")
    session = FakeChatSession(user_id="u2")
    # should not raise even for non-owner non-admin
    check_session_access(session, user)
def test_check_session_access_owner():
    """The session's owner is granted access (plain sync test: nothing to await)."""
    user = User(id="u1", email="u@test.com")
    session = FakeChatSession(user_id="u1")
    # should not raise
    check_session_access(session, user)
def test_check_session_access_admin_bypass():
    """An admin may access another user's session (plain sync test: nothing to await)."""
    user = User(id="admin1", email="admin@test.com", role="admin")
    session = FakeChatSession(user_id="u2")
    # should not raise
    check_session_access(session, user)
def test_check_session_access_legacy_anonymous():
    """A session with no owner (user_id=None) is open to admins but not regular users.

    Plain sync test: check_session_access is not awaited.
    """
    session = FakeChatSession(user_id=None)

    admin = User(id="admin1", email="admin@test.com", role="admin")
    check_session_access(session, admin)

    regular = User(id="u1", email="u@test.com", role="user")
    with pytest.raises(Exception) as exc_info:
        check_session_access(session, regular)
    assert exc_info.value.status_code == 403
def test_check_session_access_denied():
    """A non-owner non-admin is rejected with 403 (plain sync test: nothing to await)."""
    user = User(id="u1", email="u@test.com", role="user")
    session = FakeChatSession(user_id="u2")
    with pytest.raises(Exception) as exc_info:
        check_session_access(session, user)
    assert exc_info.value.status_code == 403