diff --git a/.gitignore b/.gitignore index 81d7737..87c7e9d 100644 --- a/.gitignore +++ b/.gitignore @@ -8,3 +8,5 @@ frontend/node_modules/ frontend/dist/ + +var/ diff --git a/gnexus_creds/oauth.py b/gnexus_creds/oauth.py index 0b386be..a854605 100644 --- a/gnexus_creds/oauth.py +++ b/gnexus_creds/oauth.py @@ -14,7 +14,7 @@ from gnexus_creds.db import get_db from gnexus_creds.errors import AppError from gnexus_creds.models import OAuthState, User, utcnow -from gnexus_creds.services import create_session, delete_session, upsert_user_from_auth +from gnexus_creds.services import create_session, delete_session, is_expired, upsert_user_from_auth router = APIRouter(prefix="/auth", tags=["auth"]) webhook_router = APIRouter(prefix="/webhooks", tags=["webhooks"]) @@ -59,7 +59,7 @@ async def callback(code: str, state: str, db: Session = Depends(get_db)) -> RedirectResponse: settings = get_settings() saved = db.get(OAuthState, state) - if saved is None or saved.expires_at <= datetime.now(UTC): + if saved is None or is_expired(saved.expires_at, now=datetime.now(UTC)): raise AppError("invalid_oauth_state", "Invalid or expired OAuth state.", status_code=400) config = _config() token_set = HttpTokenEndpoint(config).exchange_authorization_code(code, saved.pkce_verifier) diff --git a/gnexus_creds/services.py b/gnexus_creds/services.py index 56bd23a..ca2e6de 100644 --- a/gnexus_creds/services.py +++ b/gnexus_creds/services.py @@ -142,10 +142,10 @@ return row -def _expired_at(start: datetime, delta: timedelta, now: datetime) -> bool: +def is_expired(start: datetime, *, now: datetime, delta: timedelta | None = None) -> bool: if start.tzinfo is None and now.tzinfo is not None: now = now.replace(tzinfo=None) - return start + delta <= now + return start + (delta or timedelta(seconds=0)) <= now def check_rate_limit( @@ -163,7 +163,7 @@ if row is None: db.add(RateLimitBucket(key=key, count=1, window_start=now)) return - if _expired_at(row.window_start, timedelta(seconds=window_seconds), now): + if is_expired(row.window_start, delta=timedelta(seconds=window_seconds), now=now): row.count = 1 row.window_start = now return @@ -188,7 +188,7 @@ now = utcnow() bucket_key = f"failed_access:{key}" row = db.get(RateLimitBucket, bucket_key) - if row is None or _expired_at(row.window_start, timedelta(minutes=5), now): + if row is None or is_expired(row.window_start, delta=timedelta(minutes=5), now=now): if row is None: db.add(RateLimitBucket(key=bucket_key, count=1, window_start=now)) else: @@ -574,7 +574,7 @@ if not session_id: return None row = db.get(SessionRecord, session_id) - if row is None or _expired_at(row.expires_at, timedelta(seconds=0), utcnow()): + if row is None or is_expired(row.expires_at, now=utcnow()): return None user = db.get(User, row.user_id) if user is None or user.status == "disabled": diff --git a/tests/test_core.py b/tests/test_core.py index b3e5f8b..348e366 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -1,5 +1,13 @@ +from datetime import UTC, datetime, timedelta + from gnexus_creds.schemas import SecretCreate, SecretFieldIn, SecretUpdate -from gnexus_creds.services import create_secret, list_versions, reveal_secret, update_secret +from gnexus_creds.services import ( + create_secret, + is_expired, + list_versions, + reveal_secret, + update_secret, +) def test_secret_versioning_and_reveal(db_session, actor): @@ -42,3 +50,9 @@ revealed = reveal_secret(db_session, actor, created.id) values = {field.name: field.value for field in revealed.fields} assert values["password"] == "new-secret" + + +def test_is_expired_handles_sqlite_naive_datetime(): + now = datetime.now(UTC) + naive_start = (now - timedelta(minutes=2)).replace(tzinfo=None) + assert is_expired(naive_start, now=now, delta=timedelta(minutes=1))