diff --git a/navi/auth/deps.py b/navi/auth/deps.py
index 52e82f0..3ef9435 100644
--- a/navi/auth/deps.py
+++ b/navi/auth/deps.py
@@ -2,6 +2,7 @@
 
 import asyncio
 import hashlib
+import time
 from datetime import datetime, timezone
 from typing import Annotated
 
@@ -19,6 +20,62 @@
 # Per-session-id lock to prevent concurrent token refresh races.
 _refresh_locks: dict[str, asyncio.Lock] = {}
 
+# In-memory cache for fetch_user results: session_id -> (User, timestamp).
+# Reduces hammering gnexus-auth and protects against eventual-consistency
+# flakes right after a token refresh.
+_USER_CACHE_TTL: float = 30.0
+_user_cache: dict[str, tuple[User, float]] = {}
+
+
+def _get_cached_user(session_id: str) -> User | None:
+    """Return a cached User if it is still fresh."""
+    entry = _user_cache.get(session_id)
+    if entry is None:
+        return None
+    user, ts = entry
+    if time.monotonic() - ts > _USER_CACHE_TTL:
+        # Expired: drop the entry so the caller re-fetches.
+        _user_cache.pop(session_id, None)
+        return None
+    return user
+
+
+def _set_cached_user(session_id: str, user: User) -> None:
+    """Store a resolved User in the short-lived cache.
+
+    Expired entries are swept first; otherwise a session that is never
+    looked up again would keep its entry for the life of the process.
+    """
+    now = time.monotonic()
+    stale = [
+        sid for sid, (_, ts) in _user_cache.items() if now - ts > _USER_CACHE_TTL
+    ]
+    for sid in stale:
+        del _user_cache[sid]
+    _user_cache[session_id] = (user, now)
+
+
+async def _fetch_user_with_retry(client, access_token: str, session_id: str) -> "User | None":
+    """Call fetch_user, retry once after 1.5 s on transient failure.
+
+    ``client.fetch_user`` runs in a worker thread. Any exception counts as a
+    failure; after the second one this returns ``None`` instead of raising.
+    """
+    for attempt in range(2):
+        try:
+            return await asyncio.to_thread(client.fetch_user, access_token)
+        except Exception:
+            if attempt == 0:
+                log.debug(
+                    "auth.fetch_user_retry",
+                    session_id=session_id[:8],
+                    attempt=attempt + 1,
+                )
+                await asyncio.sleep(1.5)
+            else:
+                log.warning("auth.fetch_user_failed", session_id=session_id[:8])
+    return None
+
 
 async def _resolve_user(conn) -> User | None:
     """Shared logic to resolve user from a connection object (Request or WebSocket).
@@ -105,14 +147,23 @@
                 log.warning("auth.refresh_failed", session_id=session_id[:8])
                 # Do NOT delete the session — transient errors (network,
                 # race-condition with parallel refresh) should not force
-                # a full re-login. Next request will retry.
-                return None
+                # a full re-login.
+                # Fallback: try API token before giving up.
+                api_user = await _resolve_user_from_api_token(conn)
+                if api_user is not None:
+                    conn.state.user = api_user
+                    return api_user
 
-    # Fetch user from gnexus-auth
-    try:
-        auth_user = await asyncio.to_thread(client.fetch_user, access_token)
-    except Exception:
-        log.warning("auth.fetch_user_failed", session_id=session_id[:8])
+    # Check short-lived fetch_user cache before hitting gnexus-auth
+    cached = _get_cached_user(session_id)
+    if cached is not None:
+        log.debug("auth.resolve_user_cache_hit", user_id=cached.id)
+        conn.state.user = cached
+        return cached
+
+    # Fetch user from gnexus-auth (with one retry on transient failure)
+    auth_user = await _fetch_user_with_retry(client, access_token, session_id)
+    if auth_user is None:
         return None
 
     log.debug("auth.resolve_user_fetched", user_id=auth_user.user_id)
@@ -160,8 +211,9 @@
         permissions=permissions,
     )
 
-    # Update last_used_at
+    # Update last_used_at and cache
     await _touch_auth_session(store, session_id)
+    _set_cached_user(session_id, user)
     conn.state.user = user
     log.debug("auth.resolve_success", user_id=user.id)
     return user
diff --git a/tests/unit/auth/test_deps.py b/tests/unit/auth/test_deps.py
index d258a8d..d542275 100644
--- a/tests/unit/auth/test_deps.py
+++ b/tests/unit/auth/test_deps.py
@@ -7,6 +7,8 @@
 
 from navi.auth import User
 from navi.auth.deps import (
+    _refresh_locks,
+    _user_cache,
     check_session_access,
     get_current_user,
     require_admin,
@@ -82,6 +84,9 @@
         )
         mock_get_encryptor.return_value = encryptor
 
+        _user_cache.clear()
+        _refresh_locks.clear()
+
         yield {
             "settings": mock_settings,
             "get_client": mock_get_client,
@@ -90,11 +95,12 @@
         }
 
 
-def _make_conn(user_id="u1", expired=False):
+def _make_conn(user_id="u1", expired=False, extra_rows=0):
     """Build a FakeConnection with one auth-session row.
 
     For expired sessions we enqueue the same row twice because
     _resolve_user re-reads the session inside the refresh lock.
+    `extra_rows` enqueues additional copies for multi-call tests (cache, retry).
     """
     conn = FakeConnection()
     enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=")
@@ -108,6 +114,8 @@
     conn.enqueue(row)
     if expired:
         conn.enqueue(row)
+    for _ in range(extra_rows):
+        conn.enqueue(row)
     return conn
 
 
@@ -231,6 +239,77 @@
 
 
 @pytest.mark.asyncio
+async def test_fetch_user_retry_succeeds(_auth_env):
+    """A single transient fetch_user failure is retried and then succeeds."""
+    conn = _make_conn(user_id="u1", extra_rows=2)
+    _auth_env["get_store"].return_value = FakeSessionStore(conn)
+
+    client = MagicMock()
+    client.fetch_user.side_effect = [
+        Exception("transient fail"),
+        FakeAuthUser(user_id="u1", email="u@test.com", display_name="User"),
+    ]
+    _auth_env["get_client"].return_value = client
+
+    req = FakeRequest(cookies={"navi_session": "sess1"})
+    # Stub the back-off sleep so the test does not actually wait 1.5 s.
+    with patch("navi.auth.deps.asyncio.sleep") as fake_sleep:
+        user = await get_current_user(req)
+    assert user is not None
+    assert user.id == "u1"
+    assert client.fetch_user.call_count == 2
+    fake_sleep.assert_awaited_once_with(1.5)
+
+
+@pytest.mark.asyncio
+async def test_user_cache_avoids_second_fetch_user(_auth_env):
+    """A second request within the TTL is served from the cache."""
+    conn = _make_conn(user_id="u1", extra_rows=3)
+    _auth_env["get_store"].return_value = FakeSessionStore(conn)
+
+    client = MagicMock()
+    client.fetch_user.return_value = FakeAuthUser(
+        user_id="u1", email="u@test.com", display_name="User"
+    )
+    _auth_env["get_client"].return_value = client
+
+    req1 = FakeRequest(cookies={"navi_session": "sess1"})
+    user1 = await get_current_user(req1)
+    assert user1 is not None
+    assert client.fetch_user.call_count == 1
+
+    req2 = FakeRequest(cookies={"navi_session": "sess1"})
+    user2 = await get_current_user(req2)
+    assert user2 is not None
+    assert user2.id == "u1"
+    assert client.fetch_user.call_count == 1  # cache hit
+
+
+@pytest.mark.asyncio
+async def test_user_cache_expires(_auth_env):
+    """With a negative TTL every entry is already expired, so fetch_user runs again."""
+    with patch("navi.auth.deps._USER_CACHE_TTL", -1):
+        conn = _make_conn(user_id="u1", extra_rows=3)
+        _auth_env["get_store"].return_value = FakeSessionStore(conn)
+
+        client = MagicMock()
+        client.fetch_user.return_value = FakeAuthUser(
+            user_id="u1", email="u@test.com", display_name="User"
+        )
+        _auth_env["get_client"].return_value = client
+
+        req1 = FakeRequest(cookies={"navi_session": "sess1"})
+        user1 = await get_current_user(req1)
+        assert user1 is not None
+        assert client.fetch_user.call_count == 1
+
+        req2 = FakeRequest(cookies={"navi_session": "sess1"})
+        user2 = await get_current_user(req2)
+        assert user2 is not None
+        assert client.fetch_user.call_count == 2  # cache expired
+
+
+@pytest.mark.asyncio
 async def test_admin_role_detected(_auth_env):
     conn = _make_conn(user_id="admin1")
     _auth_env["get_store"].return_value = FakeSessionStore(conn)
diff --git a/tests/unit/core/test_registry.py b/tests/unit/core/test_registry.py
index bf92eec..4801f25 100644
--- a/tests/unit/core/test_registry.py
+++ b/tests/unit/core/test_registry.py
@@ -100,6 +100,7 @@
         llm_stream_first_chunk_timeout=180,
     )
     monkeypatch.setattr(registry_mod, "settings", test_settings)
+    monkeypatch.setattr("navi.config.settings", test_settings)
     monkeypatch.setattr(fallback_mod, "OllamaBackend", FakeOllamaBackend)
 
     discovered = registry_mod._discover_backends()