Newer
Older
gnexus-creds / gnexus_creds / mcp_protocol.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy 2 days ago 7 KB Improve MCP tool instructions
"""Official MCP Streamable HTTP transport."""

from collections.abc import Callable
from typing import Any
from uuid import UUID

from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.provider import AccessToken
from mcp.server.auth.settings import AuthSettings
from mcp.server.fastmcp import FastMCP
from sqlalchemy.orm import Session
from starlette.applications import Starlette

from gnexus_creds.auth import require_enabled_user
from gnexus_creds.config import get_settings
from gnexus_creds.db import SessionLocal
from gnexus_creds.mcp_descriptions import MCP_SERVER_INSTRUCTIONS
from gnexus_creds.schemas import Scope, SecretCreate, SecretStatus, SecretUpdate
from gnexus_creds.services import (
    Actor,
    authenticate_api_token,
    create_secret,
    get_secret,
    list_secrets,
    reveal_secret,
    update_secret,
)


class ApiTokenVerifier:
    """Validate gnexus-creds API tokens for MCP SDK auth middleware.

    Passed to ``FastMCP`` as ``token_verifier``; the SDK calls
    :meth:`verify_token` with each presented bearer token.
    """

    async def verify_token(self, token: str) -> AccessToken | None:
        """Return an ``AccessToken`` for a valid MCP-scoped API token.

        Args:
            token: The raw bearer token presented by the MCP client.

        Returns:
            An ``AccessToken`` carrying the token, the owning user's id as
            ``client_id`` and the token's scopes; or ``None`` when the token
            does not authenticate, has no API-token record, or lacks the
            ``mcp`` scope.
        """
        with SessionLocal() as db:
            actor = authenticate_api_token(db, token)
            # Reject an actor without an API-token record (the same check
            # _actor_from_mcp_context performs) instead of letting the
            # `.scopes` access below raise AttributeError in the middleware.
            if (
                actor is None
                or actor.api_token is None
                or Scope.mcp.value not in actor.api_token.scopes
            ):
                return None
            # Persist whatever authenticate_api_token changed on the session
            # (presumably token usage bookkeeping -- confirm in services).
            db.commit()
            return AccessToken(
                token=token,
                client_id=str(actor.user.id),
                scopes=actor.api_token.scopes,
            )


def _actor_from_mcp_context(db: Session) -> Actor:
    """Resolve the ``Actor`` behind the current MCP tool call.

    Reads the access token stored in the request context by the SDK auth
    middleware, re-authenticates it against ``db``, commits the session, and
    checks the user with ``require_enabled_user``. The returned actor has its
    ``channel`` set to ``"mcp"``.

    Raises:
        ValueError: if no access token is in context, or the token does not
            authenticate, has no API-token record, or lacks the ``mcp`` scope.
        Whatever ``require_enabled_user`` raises for the resolved user.
    """
    token_info = get_access_token()
    if token_info is None:
        raise ValueError("MCP authentication context is missing.")

    resolved = authenticate_api_token(db, token_info.token)
    has_mcp_scope = (
        resolved is not None
        and resolved.api_token is not None
        and Scope.mcp.value in resolved.api_token.scopes
    )
    if not has_mcp_scope:
        raise ValueError("MCP token is invalid or missing required scope.")

    # Commit before the enabled-user check, exactly as authentication left it.
    db.commit()
    require_enabled_user(resolved.user)
    resolved.channel = "mcp"
    return resolved


def create_mcp_protocol_server() -> FastMCP:
    """Build the ``FastMCP`` server that exposes gnexus-creds secrets over MCP.

    The server is named ``"gnexus-creds"``. It runs the Streamable HTTP
    transport statelessly with JSON responses mounted at ``"/"``. Bearer
    tokens are checked by ``ApiTokenVerifier``, and the ``mcp`` scope is
    declared as required in ``AuthSettings``. Every tool opens its own
    database session and resolves the caller via ``_actor_from_mcp_context``.

    FastMCP uses each tool function's docstring as the tool description sent
    to clients, so the tool docstrings below are part of the public contract.

    Returns:
        The configured ``FastMCP`` instance with all tools registered.
    """
    settings = get_settings()
    server = FastMCP(
        "gnexus-creds",
        instructions=MCP_SERVER_INSTRUCTIONS,
        stateless_http=True,
        json_response=True,
        streamable_http_path="/",
        token_verifier=ApiTokenVerifier(),
        auth=AuthSettings(
            issuer_url=settings.auth_base_url,
            resource_server_url=settings.mcp_resource_url,
            required_scopes=[Scope.mcp.value],
        ),
    )

    def _update_via_mcp(
        secret_id: str,
        build_update: Callable[[], SecretUpdate],
    ) -> dict[str, Any]:
        """Apply an update to one MCP-available secret and return it as JSON.

        Shared by the mutating tools (not registered as a tool itself).
        ``build_update`` is called only after the caller is authenticated and
        MCP access to the secret has been checked. Payload validation errors
        are therefore reported only to callers who passed those checks.
        """
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db)
            secret_uuid = UUID(secret_id)
            # The result is discarded: this call is made for its access check
            # (presumably it raises when the secret is not MCP-available),
            # since update_secret is not given an mcp flag here.
            get_secret(db, actor, secret_uuid, mcp=True)
            result = update_secret(db, actor, secret_uuid, build_update())
            db.commit()
            return result.model_dump(mode="json")

    @server.tool()
    async def search_secrets(
        q: str | None = None,
        category: str | None = None,
        status: str | None = None,
        offset: int = 0,
        limit: int = 20,
    ) -> dict[str, Any]:
        """Search MCP-available, non-archived secrets.

        Searches metadata and unencrypted fields only. Does not decrypt
        encrypted values. Use this before get_secret or reveal_secret to find
        candidate secrets.
        """
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db)
            items, total = list_secrets(
                db,
                actor,
                q=q,
                category=category,
                # An empty string is treated the same as "no status filter".
                status=SecretStatus(status) if status else None,
                include_archived=False,
                offset=offset,
                limit=limit,
                mcp=True,
            )
            return {"items": [item.model_dump(mode="json") for item in items], "total": total}

    @server.tool(name="get_secret")
    async def get_secret_tool(secret_id: str) -> dict[str, Any]:
        """Get metadata and public or masked fields for one MCP-available secret.

        This does not decrypt encrypted values. Use reveal_secret only when the
        user explicitly needs the actual secret value.
        """
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db)
            result = get_secret(db, actor, UUID(secret_id), mcp=True)
            return result.model_dump(mode="json")

    @server.tool(name="reveal_secret")
    async def reveal_secret_tool(secret_id: str) -> dict[str, Any]:
        """Return decrypted field values for one MCP-available secret.

        Use only when the user explicitly needs secret values. This creates an
        audit event with channel=mcp.
        """
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db)
            result = reveal_secret(db, actor, UUID(secret_id), mcp=True)
            # Commit so the reveal's audit record is persisted.
            db.commit()
            return result.model_dump(mode="json")

    @server.tool(name="create_secret")
    async def create_secret_tool(
        title: str,
        purpose: str | None = None,
        category: str | None = None,
        source: str | None = None,
        notes: str | None = None,
        tags: list[str] | None = None,
        allow_ui: bool = True,
        allow_rest_api: bool = True,
        allow_mcp: bool = False,
        fields: list[dict[str, Any]] | None = None,
    ) -> dict[str, Any]:
        """Create a secret through MCP.

        The fields argument is a list of objects with name, value, encrypted,
        masked, and optional position. Passwords, tokens, PINs, private keys,
        recovery codes, and similar sensitive values must use encrypted=true.
        Non-sensitive identifiers such as usernames can remain unencrypted for
        search.
        """
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db)
            result = create_secret(
                db,
                actor,
                SecretCreate(
                    title=title,
                    purpose=purpose,
                    category=category,
                    source=source,
                    notes=notes,
                    tags=tags or [],
                    allow_ui=allow_ui,
                    allow_rest_api=allow_rest_api,
                    # Defaults to False: presumably a secret created over MCP
                    # is not MCP-available unless the caller opts in.
                    allow_mcp=allow_mcp,
                    fields=fields or [],
                ),
            )
            db.commit()
            return result.model_dump(mode="json")

    @server.tool(name="update_secret")
    async def update_secret_tool(secret_id: str, payload: dict[str, Any]) -> dict[str, Any]:
        """Update metadata or fields for one MCP-available secret.

        Updating fields creates a new current version while old versions remain
        historical. Use only with explicit user intent.
        """
        return _update_via_mcp(secret_id, lambda: SecretUpdate(**payload))

    @server.tool()
    async def set_secret_status(secret_id: str, status: str) -> dict[str, Any]:
        """Set a secret status to actual, outdated, or archived through MCP.

        Use only when the user explicitly asks for a status change.
        """
        return _update_via_mcp(
            secret_id,
            lambda: SecretUpdate(status=SecretStatus(status)),
        )

    @server.tool()
    async def archive_secret(secret_id: str) -> dict[str, Any]:
        """Archive one MCP-available secret.

        Archived secrets are unavailable through normal MCP access. Use only
        when the user explicitly asks to archive it.
        """
        return _update_via_mcp(secret_id, lambda: SecretUpdate(archived=True))

    return server


def create_mcp_protocol_app() -> Starlette:
    """Return the Starlette app serving the MCP Streamable HTTP endpoint.

    A fresh ``FastMCP`` server is built on every call.
    """
    mcp_server = create_mcp_protocol_server()
    return mcp_server.streamable_http_app()