diff --git a/gnexus_creds/api.py b/gnexus_creds/api.py index 001792b..1bab470 100644 --- a/gnexus_creds/api.py +++ b/gnexus_creds/api.py @@ -18,6 +18,7 @@ Page, Scope, SecretCreate, + SecretFieldIn, SecretRead, SecretReveal, SecretStatus, @@ -27,6 +28,7 @@ from gnexus_creds.services import ( Actor, audit, + check_actor_rate_limit, create_api_token, create_secret, delete_secret, @@ -56,6 +58,33 @@ } +def _export_secret(row: Secret, actor: Actor, db: Session) -> dict: + revealed = serialize_secret(row, reveal=True, db=db, user=actor.user) + return SecretCreate( + title=revealed.title, + purpose=revealed.purpose, + category=revealed.category, + source=revealed.source, + notes=revealed.notes, + tags=revealed.tags, + status=revealed.status, + archived=revealed.archived, + allow_ui=revealed.allow_ui, + allow_rest_api=revealed.allow_rest_api, + allow_mcp=revealed.allow_mcp, + fields=[ + SecretFieldIn( + name=field.name, + value=field.value or "", + encrypted=field.encrypted, + masked=field.masked, + position=field.position, + ) + for field in revealed.fields + ], + ).model_dump(mode="json") + + @router.get("/me") async def me(actor: Actor = Depends(actor_from_request)) -> dict: return { @@ -319,17 +348,14 @@ actor: Actor = Depends(actor_from_request), ) -> dict: actor.require(Scope.reveal) + check_actor_rate_limit(db, actor, "export") rows = db.scalars( select(Secret) .where(Secret.user_id == actor.user.id) .options(selectinload(Secret.versions), selectinload(Secret.tags)) .order_by(Secret.created_at) ).unique() - exported = [] - for row in rows: - exported.append( - serialize_secret(row, reveal=True, db=db, user=actor.user).model_dump(mode="json") - ) + exported = [_export_secret(row, actor, db) for row in rows] audit(db, actor, action="export.created") db.commit() return { @@ -347,6 +373,11 @@ actor: Actor = Depends(actor_from_request), ) -> dict[str, int]: actor.require(Scope.write) + check_actor_rate_limit(db, actor, "import") + if payload.format != "gnexus-creds-export" or payload.version != 1: + from gnexus_creds.errors import AppError + + raise AppError("unsupported_import_format", "Unsupported import format.", status_code=400) created = 0 for item in payload.secrets: create_secret(db, actor, item) diff --git a/gnexus_creds/auth.py b/gnexus_creds/auth.py index 9bd7129..b2810f7 100644 --- a/gnexus_creds/auth.py +++ b/gnexus_creds/auth.py @@ -9,7 +9,12 @@ from gnexus_creds.errors import AppError from gnexus_creds.models import User from gnexus_creds.schemas import Scope -from gnexus_creds.services import Actor, authenticate_api_token, get_session_user +from gnexus_creds.services import ( + Actor, + authenticate_api_token, + get_session_user, + log_failed_access_once, +) bearer = HTTPBearer(auto_error=False) @@ -35,6 +40,19 @@ ip_address=request.client.host if request.client else None, user_agent=request.headers.get("user-agent"), ) + failed_key = "anonymous" + if credentials: + failed_key = f"bearer:{credentials.credentials[:12]}" + elif session_id: + failed_key = f"session:{session_id[:12]}" + log_failed_access_once( + db, + key=failed_key, + channel="rest", + ip_address=request.client.host if request.client else None, + user_agent=request.headers.get("user-agent"), + ) + db.commit() raise AppError("unauthorized", "Authentication required.", status_code=401) diff --git a/gnexus_creds/mcp.py b/gnexus_creds/mcp.py index 91a0528..79353ee 100644 --- a/gnexus_creds/mcp.py +++ b/gnexus_creds/mcp.py @@ -98,20 +98,25 @@ return result.model_dump(mode="json") if tool_name == "update_secret": secret_id = UUID(args.pop("secret_id")) + get_secret(db, actor, secret_id, mcp=True) result = update_secret(db, actor, secret_id, SecretUpdate(**args)) db.commit() return result.model_dump(mode="json") if tool_name == "set_secret_status": + secret_id = UUID(args["secret_id"]) + get_secret(db, actor, secret_id, mcp=True) result = update_secret( db, actor, - UUID(args["secret_id"]), + secret_id, SecretUpdate(status=SecretStatus(args["status"])), ) db.commit() return result.model_dump(mode="json") if tool_name == "archive_secret": - result = update_secret(db, actor, UUID(args["secret_id"]), SecretUpdate(archived=True)) + secret_id = UUID(args["secret_id"]) + get_secret(db, actor, secret_id, mcp=True) + result = update_secret(db, actor, secret_id, SecretUpdate(archived=True)) db.commit() return result.model_dump(mode="json") raise AppError("mcp_tool_not_found", "MCP tool not found.", status_code=404) diff --git a/gnexus_creds/services.py b/gnexus_creds/services.py index 9cf8433..56bd23a 100644 --- a/gnexus_creds/services.py +++ b/gnexus_creds/services.py @@ -3,7 +3,7 @@ import secrets as py_secrets import uuid from dataclasses import dataclass -from datetime import timedelta +from datetime import datetime, timedelta from typing import Any from sqlalchemy import func, or_, select @@ -15,6 +15,7 @@ from gnexus_creds.models import ( ApiToken, AuditEvent, + RateLimitBucket, Secret, SecretTag, SecretVersion, @@ -141,6 +142,75 @@ return row +def _expired_at(start: datetime, delta: timedelta, now: datetime) -> bool: + if start.tzinfo is None and now.tzinfo is not None: + now = now.replace(tzinfo=None) + return start + delta <= now + + +def check_rate_limit( + db: Session, + *, + key: str, + max_requests: int | None = None, + window_seconds: int | None = None, +) -> None: + settings = get_settings() + max_requests = max_requests or settings.rate_limit_max_sensitive_requests + window_seconds = window_seconds or settings.rate_limit_window_seconds + now = utcnow() + row = db.get(RateLimitBucket, key) + if row is None: + db.add(RateLimitBucket(key=key, count=1, window_start=now)) + return + if _expired_at(row.window_start, timedelta(seconds=window_seconds), now): + row.count = 1 + row.window_start = now + return + row.count += 1 + if row.count > max_requests: + raise AppError("rate_limited", "Too many requests.", status_code=429) + + +def check_actor_rate_limit(db: Session, actor: Actor, action: str) -> None: + token_part = actor.api_token.public_id if actor.api_token else "session" + check_rate_limit(db, key=f"{action}:{actor.user.id}:{token_part}") + + +def log_failed_access_once( + db: Session, + *, + key: str, + channel: str, + ip_address: str | None, + user_agent: str | None, +) -> None: + now = utcnow() + bucket_key = f"failed_access:{key}" + row = db.get(RateLimitBucket, bucket_key) + if row is None or _expired_at(row.window_start, timedelta(minutes=5), now): + if row is None: + db.add(RateLimitBucket(key=bucket_key, count=1, window_start=now)) + else: + row.count = 1 + row.window_start = now + db.add( + AuditEvent( + user_id=None, + actor_user_id=None, + api_token_id=None, + secret_id=None, + channel=channel, + action="access.failed", + ip_address=ip_address, + user_agent=user_agent, + audit_metadata={"key": key}, + ) + ) + else: + row.count += 1 + + def _store_fields(db: Session, user: User, fields: list[SecretFieldIn]) -> list[dict]: key_id, key = get_user_key(db, user) stored = [] @@ -344,6 +414,7 @@ mcp: bool = False, ) -> SecretReveal: actor.require(Scope.reveal) + check_actor_rate_limit(db, actor, "reveal") secret = _load_secret(db, actor.user.id, secret_id) if mcp and (secret.archived or not secret.allow_mcp): raise AppError("secret_not_found", "Secret not found.", status_code=404) @@ -453,6 +524,7 @@ def create_api_token(db: Session, actor: Actor, payload: ApiTokenCreate): actor.require(Scope.admin) + check_actor_rate_limit(db, actor, "api_token_create") public_id = py_secrets.token_hex(8) secret = crypto.token_secret() full_token = f"gcr_{public_id}_{secret}" @@ -502,7 +574,7 @@ if not session_id: return None row = db.get(SessionRecord, session_id) - if row is None or row.expires_at <= utcnow(): + if row is None or _expired_at(row.expires_at, timedelta(seconds=0), utcnow()): return None user = db.get(User, row.user_id) if user is None or user.status == "disabled": diff --git a/tests/conftest.py b/tests/conftest.py index 7aa2990..7172f0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -66,3 +66,15 @@ app.dependency_overrides[get_db] = override_db app.dependency_overrides[actor_from_request] = override_actor return app + + +@pytest.fixture() +def auth_app(session_factory): + app = create_app() + + async def override_db(): + with session_factory() as session: + yield session + + app.dependency_overrides[get_db] = override_db + return app diff --git a/tests/test_api.py b/tests/test_api.py index 17fbeec..53fc6ab 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -1,6 +1,11 @@ import pytest from httpx import ASGITransport, AsyncClient +from gnexus_creds import crypto +from gnexus_creds.models import ApiToken, AuditEvent +from gnexus_creds.schemas import SecretCreate, SecretFieldIn +from gnexus_creds.services import Actor, create_secret + @pytest.mark.anyio async def test_rest_create_list_reveal(app): @@ -30,6 +35,13 @@ response = await client.get("/api/v1/secrets?q=deploy") assert response.status_code == 200 assert response.json()["total"] == 1 + response = await client.get("/api/v1/secrets?q=pass123") + assert response.status_code == 200 + assert response.json()["total"] == 0 + response = await client.get("/api/v1/secrets?q=password") + assert response.status_code == 200 + assert response.json()["total"] == 1 + response = await client.get("/api/v1/secrets?q=deploy") fields = response.json()["items"][0]["fields"] assert {field["name"]: field["value"] for field in fields}["username"] == "deploy" assert {field["name"]: field["value"] for field in fields}["password"] is None @@ -39,3 +51,107 @@ assert {field["name"]: field["value"] for field in response.json()["fields"]}[ "password" ] == "pass123" + + +@pytest.mark.anyio +async def test_export_import_round_trip(app): + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + create_response = await client.post( + "/api/v1/secrets", + json={ + "title": "Card PIN", + "purpose": "bank card", + "category": "finance", + "tags": ["card"], + "fields": [ + {"name": "card", "value": "1111", "encrypted": False}, + {"name": "pin", "value": "1234", "encrypted": True, "masked": True}, + ], + }, + ) + assert create_response.status_code == 200 + + export_response = await client.post("/api/v1/export") + assert export_response.status_code == 200 + payload = export_response.json() + assert payload["format"] == "gnexus-creds-export" + assert payload["secrets"][0]["fields"][1]["value"] == "1234" + + delete_response = await client.delete("/api/v1/account-data") + assert delete_response.status_code == 204 + list_response = await client.get("/api/v1/secrets") + assert list_response.json()["total"] == 0 + + import_response = await client.post("/api/v1/import", json=payload) + assert import_response.status_code == 200 + assert import_response.json()["created"] == 1 + reveal_id = (await client.get("/api/v1/secrets")).json()["items"][0]["id"] + reveal_response = await client.post(f"/api/v1/secrets/{reveal_id}/reveal") + assert {field["name"]: field["value"] for field in reveal_response.json()["fields"]}[ + "pin" + ] == "1234" + + +@pytest.mark.anyio +async def test_failed_access_is_aggregated(auth_app, db_session): + async with AsyncClient(transport=ASGITransport(app=auth_app), base_url="http://test") as client: + assert (await client.get("/api/v1/me")).status_code == 401 + assert (await client.get("/api/v1/me")).status_code == 401 + + events = db_session.query(AuditEvent).filter(AuditEvent.action == "access.failed").all() + assert len(events) == 1 + + +@pytest.mark.anyio +async def test_api_token_scopes(auth_app, db_session, user): + created = create_secret( + db_session, + Actor(user=user, channel="ui"), + SecretCreate( + title="Token scoped", + fields=[SecretFieldIn(name="password", value="secret", encrypted=True)], + ), + ) + read_token = "gcr_read_secret" + reveal_token = "gcr_reveal_secret" + db_session.add_all( + [ + ApiToken( + user_id=user.id, + public_id="read", + name="read", + token_hash=crypto.token_hash(read_token), + scopes=["read"], + ), + ApiToken( + user_id=user.id, + public_id="reveal", + name="reveal", + token_hash=crypto.token_hash(reveal_token), + scopes=["read", "reveal"], + ), + ] + ) + db_session.commit() + + async with AsyncClient(transport=ASGITransport(app=auth_app), base_url="http://test") as client: + list_response = await client.get( + "/api/v1/secrets", headers={"Authorization": f"Bearer {read_token}"} + ) + assert list_response.status_code == 200 + assert list_response.json()["total"] == 1 + + denied_response = await client.post( + f"/api/v1/secrets/{created.id}/reveal", + headers={"Authorization": f"Bearer {read_token}"}, + ) + assert denied_response.status_code == 403 + + reveal_response = await client.post( + f"/api/v1/secrets/{created.id}/reveal", + headers={"Authorization": f"Bearer {reveal_token}"}, + ) + assert reveal_response.status_code == 200 + assert {field["name"]: field["value"] for field in reveal_response.json()["fields"]}[ + "password" + ] == "secret" diff --git a/tests/test_mcp.py b/tests/test_mcp.py index 5602251..4da90d0 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -2,6 +2,8 @@ from httpx import ASGITransport, AsyncClient from gnexus_creds.models import ApiToken +from gnexus_creds.schemas import SecretCreate, SecretFieldIn +from gnexus_creds.services import Actor, create_secret @pytest.mark.anyio @@ -17,3 +19,32 @@ async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: response = await client.post("/mcp/tools/search_secrets", json={"arguments": {}}) assert response.status_code == 403 + + +@pytest.mark.anyio +async def test_mcp_update_requires_secret_allow_mcp(app, db_session, actor): + actor.channel = "mcp" + actor.api_token = ApiToken( + user_id=actor.user.id, + public_id="mcp", + name="mcp", + token_hash="hash", + scopes=["mcp", "read", "write"], + ) + secret = create_secret( + db_session, + Actor(user=actor.user, channel="ui"), + SecretCreate( + title="UI only", + allow_mcp=False, + fields=[SecretFieldIn(name="username", value="demo", encrypted=False)], + ), + ) + db_session.commit() + + async with AsyncClient(transport=ASGITransport(app=app), base_url="http://test") as client: + response = await client.post( + "/mcp/tools/update_secret", + json={"arguments": {"secret_id": str(secret.id), "title": "Changed"}}, + ) + assert response.status_code == 404