Newer
Older
gnexus-creds / gnexus_creds / mcp.py
"""MCP HTTP/SSE adapter.

This module intentionally keeps MCP transport details isolated from the domain
services. The first implementation exposes tool calls over HTTP and provides an
SSE discovery stream. If the Python MCP SDK transport API changes, only this
adapter should need replacement.
"""

import json
from uuid import UUID

from fastapi import APIRouter, Depends
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from sqlalchemy.orm import Session

from gnexus_creds.auth import actor_from_request
from gnexus_creds.db import get_db
from gnexus_creds.errors import AppError
from gnexus_creds.mcp_descriptions import LEGACY_TOOLS, MCP_SERVER_INSTRUCTIONS
from gnexus_creds.schemas import Scope, SecretCreate, SecretStatus, SecretUpdate
from gnexus_creds.services import (
    Actor,
    create_secret,
    get_secret,
    list_secrets,
    reveal_secret,
    update_secret,
)

# Every MCP endpoint in this module is registered on this router (prefix /mcp).
router = APIRouter(prefix="/mcp", tags=["mcp"])


# Bare tool names, in LEGACY_TOOLS order, taken from each descriptor's "name" key.
TOOLS = [descriptor["name"] for descriptor in LEGACY_TOOLS]


class ToolCall(BaseModel):
    # Request body accepted by POST /mcp/tools/{tool_name} (see call_tool).
    # ``arguments`` carries the tool's keyword arguments; when the field is
    # omitted it defaults to a fresh empty dict (default_factory, so instances
    # never share one dict).
    arguments: dict = Field(default_factory=dict)


def _mcp_actor(actor: Actor) -> Actor:
    """Authorize *actor* for MCP use and mark it as an MCP caller.

    The actor must carry an API token whose scopes include ``Scope.mcp``.
    On success the same object is returned after its ``channel`` attribute
    is set to ``"mcp"`` (the actor is modified in place).

    Raises:
        AppError: ``forbidden`` (HTTP 403) when there is no API token or the
            token lacks the MCP scope.
    """
    token = actor.api_token
    if token is not None and Scope.mcp.value in token.scopes:
        actor.channel = "mcp"
        return actor
    raise AppError("forbidden", "MCP scope is required.", status_code=403)


def _require_arg(args: dict, key: str):
    """Return ``args[key]``, failing the tool call if the key is absent.

    Only a missing key is an error; a key present with a ``None`` value is
    returned as-is.

    Raises:
        AppError: ``bad_request`` (HTTP 400) naming the missing key.
    """
    if key in args:
        return args[key]
    raise AppError("bad_request", f"Missing required argument: {key}.", status_code=400)


def _parse_uuid(value) -> UUID:
    """Coerce a tool argument into a ``UUID``.

    Accepts a ``UUID`` instance (returned unchanged) or any string form that
    ``uuid.UUID`` understands. Any other type (a JSON number, list, object,
    ``None`` ...) is rejected up front. CPython's ``UUID()`` calls
    ``.replace`` on its argument, so such values used to raise
    ``AttributeError``, which escaped the ``except`` clause and surfaced as
    a server error instead of a client error.

    Raises:
        AppError: ``bad_request`` (HTTP 400) for non-string or malformed input.
    """
    if isinstance(value, UUID):
        return value
    if not isinstance(value, str):
        raise AppError("bad_request", "Invalid UUID format.", status_code=400)
    try:
        return UUID(value)
    except ValueError as exc:
        raise AppError("bad_request", "Invalid UUID format.", status_code=400) from exc


def _parse_status(value):
    """Convert *value* into the matching ``SecretStatus`` member.

    Raises:
        AppError: ``bad_request`` (HTTP 400) when *value* is not a valid
            ``SecretStatus`` value (the enum raises ``ValueError``).
    """
    try:
        parsed = SecretStatus(value)
    except ValueError as exc:
        raise AppError("bad_request", "Invalid status value.", status_code=400) from exc
    return parsed


@router.get("/sse")
async def mcp_sse(actor: Actor = Depends(actor_from_request)) -> StreamingResponse:
    # SSE discovery endpoint for MCP clients.
    #
    # After the MCP scope check (_mcp_actor raises 403 otherwise), the response
    # streams exactly one "ready" event whose data is a JSON object with the
    # server instructions and the LEGACY_TOOLS descriptors, then the stream ends.
    # NOTE(review): the stream closes after that single event; EventSource-style
    # clients may reconnect on close — confirm this one-shot shape is intended.
    _mcp_actor(actor)

    def _events():
        # Serialized lazily, when the response body is first iterated.
        announcement = {
            "type": "tools",
            "instructions": MCP_SERVER_INSTRUCTIONS,
            "tools": LEGACY_TOOLS,
        }
        encoded = json.dumps(announcement)
        yield "event: ready\ndata: " + encoded + "\n\n"

    return StreamingResponse(_events(), media_type="text/event-stream")


def _parse_int(args: dict, key: str, default: int) -> int:
    """Read an optional integer tool argument.

    A missing key or an explicit ``None`` (JSON ``null``) yields *default*;
    any other value is converted with ``int()``.

    Raises:
        AppError: ``bad_request`` (HTTP 400) when ``int()`` cannot convert the
            value (e.g. ``"abc"``, a list, NaN or infinity), instead of the
            conversion error escaping as a server error.
    """
    value = args.get(key)
    if value is None:
        return default
    try:
        return int(value)
    except (ValueError, TypeError, OverflowError) as exc:
        raise AppError(
            "bad_request", f"Invalid integer argument: {key}.", status_code=400
        ) from exc


def _update_visible_secret(db: Session, actor: Actor, secret_id: UUID, changes: dict):
    """Apply *changes* to a secret after an MCP-scoped lookup, then commit.

    ``get_secret(..., mcp=True)`` runs before the ``SecretUpdate`` is built,
    matching the original order of checks; presumably it raises when the
    secret is not visible to MCP callers (TODO confirm in services).
    Returns the updated secret dumped with ``model_dump(mode="json")``.
    """
    get_secret(db, actor, secret_id, mcp=True)
    result = update_secret(db, actor, secret_id, SecretUpdate(**changes))
    db.commit()
    return result.model_dump(mode="json")


@router.post("/tools/{tool_name}")
async def call_tool(
    tool_name: str,
    call: ToolCall,
    db: Session = Depends(get_db),
    actor: Actor = Depends(actor_from_request),
):
    # Execute one MCP tool selected by ``tool_name`` and return a JSON-ready result.
    #
    # The caller must hold the MCP scope (_mcp_actor raises 403 otherwise).
    # Bad arguments raise AppError "bad_request" (400); an unknown tool name
    # raises AppError "mcp_tool_not_found" (404). Mutating tools and
    # reveal_secret commit the session before returning.
    #
    # NOTE(review): this is an ``async def`` handler that calls a synchronous
    # SQLAlchemy Session, so DB work runs on the event-loop thread. A plain
    # ``def`` handler would let FastAPI run it in its threadpool; confirm that
    # nothing awaits this function directly before switching.
    # NOTE(review): SecretCreate(**...) / SecretUpdate(**...) raise pydantic's
    # ValidationError on bad input; how that is reported depends on the app's
    # exception handlers (HTTP 500 if none is registered) — confirm.
    actor = _mcp_actor(actor)
    args = call.arguments

    if tool_name == "search_secrets":
        status_raw = args.get("status")
        items, total = list_secrets(
            db,
            actor,
            q=args.get("q"),
            category=args.get("category"),
            # Missing, empty or whitespace-only status means "no status filter".
            status=_parse_status(status_raw) if status_raw and str(status_raw).strip() else None,
            include_archived=False,
            offset=_parse_int(args, "offset", 0),
            limit=_parse_int(args, "limit", 20),
            mcp=True,
        )
        return {"items": [item.model_dump(mode="json") for item in items], "total": total}

    if tool_name == "get_secret":
        result = get_secret(db, actor, _parse_uuid(_require_arg(args, "secret_id")), mcp=True)
        return result.model_dump(mode="json")

    if tool_name == "reveal_secret":
        result = reveal_secret(db, actor, _parse_uuid(_require_arg(args, "secret_id")), mcp=True)
        # Unlike get_secret, revealing commits — presumably reveal_secret records
        # state (e.g. an audit entry) that must persist; TODO confirm.
        db.commit()
        return result.model_dump(mode="json")

    if tool_name == "create_secret":
        result = create_secret(db, actor, SecretCreate(**args))
        db.commit()
        return result.model_dump(mode="json")

    if tool_name == "update_secret":
        secret_id = _parse_uuid(_require_arg(args, "secret_id"))
        # secret_id addresses the record; every other argument is a field change.
        # Copied rather than popped so call.arguments is left untouched.
        changes = {key: value for key, value in args.items() if key != "secret_id"}
        return _update_visible_secret(db, actor, secret_id, changes)

    if tool_name == "set_secret_status":
        secret_id = _parse_uuid(_require_arg(args, "secret_id"))
        status = _parse_status(_require_arg(args, "status"))
        return _update_visible_secret(db, actor, secret_id, {"status": status})

    if tool_name == "archive_secret":
        secret_id = _parse_uuid(_require_arg(args, "secret_id"))
        return _update_visible_secret(db, actor, secret_id, {"archived": True})

    raise AppError("mcp_tool_not_found", "MCP tool not found.", status_code=404)