diff --git a/navi/auth/deps.py b/navi/auth/deps.py index 11c4a7c..7a51bf2 100644 --- a/navi/auth/deps.py +++ b/navi/auth/deps.py @@ -15,6 +15,9 @@ log = structlog.get_logger() +# Per-session-id lock to prevent concurrent token refresh races. +_refresh_locks: dict[str, asyncio.Lock] = {} + async def _resolve_user(conn) -> User | None: """Shared logic to resolve user from a connection object (Request or WebSocket). @@ -70,24 +73,35 @@ # Refresh if expired if datetime.now(timezone.utc) > expires_at: - try: - refresh_token = encryptor.decrypt(row["refresh_token_enc"]) - token_set = await asyncio.to_thread(client.refresh_token, refresh_token) - access_token = token_set.access_token - # Update DB with new tokens - await _update_auth_session( - store, - session_id, - encryptor.encrypt(access_token), - encryptor.encrypt(token_set.refresh_token or refresh_token), - token_set.expires_at or datetime.now(timezone.utc), - ) - log.info("auth.token_refreshed", user_id=row["user_id"]) - except Exception: - log.warning("auth.refresh_failed", session_id=session_id[:8]) - # Refresh failed — treat as unauthenticated - await _delete_auth_session(store, session_id) - return None + lock = _refresh_locks.setdefault(session_id, asyncio.Lock()) + async with lock: + # Re-read session inside lock — another request may have refreshed it already + row = await _get_auth_session(store, session_id) + if row is None: + log.debug("auth.resolve_session_gone_during_refresh", session_id=session_id[:8]) + return None + expires_at = row["expires_at"] + if datetime.now(timezone.utc) <= expires_at: + access_token = encryptor.decrypt(row["access_token_enc"]) + else: + try: + refresh_token = encryptor.decrypt(row["refresh_token_enc"]) + token_set = await asyncio.to_thread(client.refresh_token, refresh_token) + access_token = token_set.access_token + await _update_auth_session( + store, + session_id, + encryptor.encrypt(access_token), + encryptor.encrypt(token_set.refresh_token or refresh_token), + token_set.expires_at or datetime.now(timezone.utc), + ) + log.info("auth.token_refreshed", user_id=row["user_id"]) + except Exception: + log.warning("auth.refresh_failed", session_id=session_id[:8]) + # Do NOT delete the session — transient errors (network, + # race-condition with parallel refresh) should not force + # a full re-login. Next request will retry. + return None # Fetch user from gnexus-auth try: diff --git a/tests/unit/auth/test_deps.py b/tests/unit/auth/test_deps.py index 5ee1f72..d258a8d 100644 --- a/tests/unit/auth/test_deps.py +++ b/tests/unit/auth/test_deps.py @@ -91,16 +91,23 @@ def _make_conn(user_id="u1", expired=False): - """Build a FakeConnection with one auth-session row.""" + """Build a FakeConnection with one auth-session row. + + For expired sessions we enqueue the same row twice because _resolve_user + re-reads the session inside the refresh lock. + """ conn = FakeConnection() enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=") expires = datetime.now(timezone.utc) - timedelta(hours=1) if expired else datetime.now(timezone.utc) + timedelta(hours=1) - conn.enqueue({ + row = { "user_id": user_id, "access_token_enc": enc.encrypt("access-token"), "refresh_token_enc": enc.encrypt("refresh-token"), "expires_at": expires, - }) + } + conn.enqueue(row) + if expired: + conn.enqueue(row) return conn @@ -158,7 +165,7 @@ @pytest.mark.asyncio -async def test_refresh_failure_deletes_session(_auth_env): +async def test_refresh_failure_returns_none_without_delete(_auth_env): conn = _make_conn(user_id="u1", expired=True) _auth_env["get_store"].return_value = FakeSessionStore(conn) @@ -169,9 +176,44 @@ req = FakeRequest(cookies={"navi_session": "sess1"}) user = await get_current_user(req) assert user is None - # Ensure DELETE was issued + # Session must NOT be deleted on transient refresh failure delete_calls = [c for c in conn.calls if c[0] == "execute" and "DELETE" in c[1]] - assert len(delete_calls) == 1 + assert len(delete_calls) == 0 + + +@pytest.mark.asyncio +async def test_refresh_skipped_when_already_refreshed_by_parallel_request(_auth_env): + """If another request refreshed the token while we waited for the lock, + we should re-read the session and skip the refresh call.""" + enc = TokenEncryptor("hocAswNbSlUFITrAPnpv-3ky9EpiZBqZs0km73FR5nE=") + conn = FakeConnection() + conn.enqueue({ + "user_id": "u1", + "access_token_enc": enc.encrypt("old-access"), + "refresh_token_enc": enc.encrypt("refresh-token"), + "expires_at": datetime.now(timezone.utc) - timedelta(hours=1), + }) + conn.enqueue({ + "user_id": "u1", + "access_token_enc": enc.encrypt("new-access"), + "refresh_token_enc": enc.encrypt("refresh-token"), + "expires_at": datetime.now(timezone.utc) + timedelta(hours=1), + }) + _auth_env["get_store"].return_value = FakeSessionStore(conn) + + client = MagicMock() + client.fetch_user.return_value = FakeAuthUser( + user_id="u1", email="u@test.com", display_name="User" + ) + _auth_env["get_client"].return_value = client + + req = FakeRequest(cookies={"navi_session": "sess1"}) + user = await get_current_user(req) + assert user is not None + assert user.id == "u1" + # refresh_token should NOT have been called because the second + # fetchrow showed an un-expired session + assert not client.refresh_token.called @pytest.mark.asyncio