"""Official MCP Streamable HTTP transport."""

from typing import Any
from uuid import UUID

from mcp.server.auth.middleware.auth_context import get_access_token
from mcp.server.auth.provider import AccessToken
from mcp.server.auth.settings import AuthSettings
from mcp.server.fastmcp import Context, FastMCP
from sqlalchemy.orm import Session

from gnexus_creds.auth import require_enabled_user
from gnexus_creds.config import get_settings
from gnexus_creds.db import SessionLocal
from gnexus_creds.mcp_descriptions import MCP_SERVER_INSTRUCTIONS
from gnexus_creds.schemas import Scope, SecretCreate, SecretStatus, SecretUpdate
from gnexus_creds.services import (
    Actor,
    authenticate_api_token,
    create_secret,
    get_secret,
    list_secrets,
    reveal_secret,
    update_secret,
)


class ApiTokenVerifier:
    """Validate gnexus-creds API tokens for MCP SDK auth middleware.

    An instance is passed to FastMCP as ``token_verifier``; the SDK awaits
    :meth:`verify_token` with the bearer token of each incoming request.
    """

    async def verify_token(self, token: str) -> AccessToken | None:
        """Return an AccessToken for a valid MCP-scoped API token, else None.

        Args:
            token: The raw bearer token taken from the request.

        Returns:
            An ``AccessToken`` whose ``client_id`` is the owning user's id and
            whose ``scopes`` are the API token's scopes, or ``None`` when the
            token does not authenticate, has no associated API token record,
            or lacks the ``Scope.mcp`` scope.
        """
        with SessionLocal() as db:
            actor = authenticate_api_token(db, token)
            # Same acceptance rule as _actor_from_mcp_context: an actor without
            # an API token record must be rejected, not dereferenced.
            api_token = actor.api_token if actor is not None else None
            if api_token is None or Scope.mcp.value not in api_token.scopes:
                return None
            # Persist whatever authenticate_api_token changed in the session
            # (NOTE(review): presumably token usage bookkeeping — confirm).
            db.commit()
            return AccessToken(
                token=token,
                client_id=str(actor.user.id),
                # Copy so the returned object does not alias ORM-managed state.
                scopes=list(api_token.scopes),
            )


def _actor_from_mcp_context(
    db: Session,
    *,
    ip_address: str | None = None,
    user_agent: str | None = None,
) -> Actor:
    """Resolve the authenticated Actor for the current MCP tool call.

    The bearer token recorded in the MCP auth context is re-authenticated
    against *db*. The result must carry an API token with the ``Scope.mcp``
    scope and belong to an enabled user. The actor is then tagged with the
    ``"mcp"`` channel and the caller's client details, and the session is
    committed.

    Args:
        db: Open database session used for authentication.
        ip_address: Client IP address of the HTTP request, if known.
        user_agent: ``User-Agent`` header of the HTTP request, if known.

    Returns:
        The resolved and annotated Actor.

    Raises:
        ValueError: If no MCP auth context is present, or the token is
            invalid or lacks the MCP scope. Any exception raised by
            ``require_enabled_user`` for a disabled user also propagates.
    """
    token_info = get_access_token()
    if token_info is None:
        raise ValueError("MCP authentication context is missing.")

    resolved = authenticate_api_token(db, token_info.token)
    scope_ok = (
        resolved is not None
        and resolved.api_token is not None
        and Scope.mcp.value in resolved.api_token.scopes
    )
    if not scope_ok:
        raise ValueError("MCP token is invalid or missing required scope.")

    require_enabled_user(resolved.user)

    # Record how and from where this call arrived (the reveal_secret tool
    # documents audit events carrying channel=mcp).
    resolved.channel = "mcp"
    resolved.ip_address = ip_address
    resolved.user_agent = user_agent

    db.commit()
    return resolved


def _request_from_ctx(ctx: Context | None) -> Any:
    """Return the HTTP request underlying an MCP tool call, or None.

    ``None`` is returned when *ctx* is missing. It is also returned when
    reading the request context raises ``ValueError``, which is the error
    the MCP SDK presumably uses when no request is active (confirm against
    the SDK version in use).
    """
    if ctx is not None:
        try:
            return ctx.request_context.request
        except ValueError:
            pass
    return None


def _client_info(ctx: Context | None) -> tuple[str | None, str | None]:
    """Return ``(ip_address, user_agent)`` for the HTTP request behind *ctx*.

    Either element is ``None`` when the request, its client address, or the
    ``user-agent`` header is unavailable.
    """
    request = _request_from_ctx(ctx)
    if not request:
        return None, None
    ip = request.client.host if request.client else None
    return ip, request.headers.get("user-agent")


def _parse_status(status: str | None) -> SecretStatus | None:
    """Convert an optional, possibly padded status string into a SecretStatus.

    ``None`` and blank strings yield ``None`` ("no status supplied").
    Surrounding whitespace is always stripped before conversion. Values the
    ``SecretStatus`` constructor does not accept raise from that constructor.
    """
    if status is None or not status.strip():
        return None
    return SecretStatus(status.strip())


def _parse_secret_id(secret_id: str) -> UUID:
    """Parse a tool-supplied secret id, ignoring surrounding whitespace.

    Raises:
        ValueError: If *secret_id* is not a valid UUID string; the message
            names the offending value so MCP clients can correct it.
    """
    try:
        return UUID(secret_id.strip())
    except ValueError as err:
        raise ValueError(f"secret_id is not a valid UUID: {secret_id!r}") from err


def create_mcp_protocol_server() -> FastMCP:
    """Build the FastMCP server that exposes gnexus-creds secret tools.

    The server is configured for stateless Streamable HTTP with JSON
    responses at path ``/``. Bearer tokens are checked by ApiTokenVerifier,
    and the ``Scope.mcp`` scope is required. Each tool opens its own database
    session and re-resolves the caller via ``_actor_from_mcp_context``,
    passing the request's IP address and user agent.

    Returns:
        The configured FastMCP instance.
    """
    settings = get_settings()
    server = FastMCP(
        "gnexus-creds",
        instructions=MCP_SERVER_INSTRUCTIONS,
        stateless_http=True,
        json_response=True,
        streamable_http_path="/",
        token_verifier=ApiTokenVerifier(),
        auth=AuthSettings(
            issuer_url=settings.auth_base_url,
            resource_server_url=settings.mcp_resource_url,
            required_scopes=[Scope.mcp.value],
        ),
    )

    # NOTE: tool docstrings below are published to MCP clients as tool
    # descriptions; edit them deliberately.

    @server.tool()
    async def search_secrets(
        q: str | None = None,
        category: str | None = None,
        status: str | None = None,
        offset: int = 0,
        limit: int = 20,
        ctx: Context = None,  # type: ignore[assignment]
    ) -> dict[str, Any]:
        """Search MCP-available, non-archived secrets.

        Searches metadata and unencrypted fields only. Does not decrypt
        encrypted values. Use this before get_secret or reveal_secret to find
        candidate secrets.
        """
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            items, total = list_secrets(
                db,
                actor,
                q=q,
                category=category,
                status=_parse_status(status),
                include_archived=False,
                offset=offset,
                limit=limit,
                mcp=True,
            )
            return {"items": [item.model_dump(mode="json") for item in items], "total": total}

    @server.tool(name="get_secret")
    async def get_secret_tool(secret_id: str, ctx: Context = None) -> dict[str, Any]:  # type: ignore[assignment]
        """Get metadata and public or masked fields for one MCP-available secret.

        This does not decrypt encrypted values. Use reveal_secret only when the
        user explicitly needs the actual secret value.
        """
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            result = get_secret(db, actor, _parse_secret_id(secret_id), mcp=True)
            return result.model_dump(mode="json")

    @server.tool(name="reveal_secret")
    async def reveal_secret_tool(secret_id: str, ctx: Context = None) -> dict[str, Any]:  # type: ignore[assignment]
        """Return decrypted field values for one MCP-available secret.

        Use only when the user explicitly needs secret values. This creates an
        audit event with channel=mcp.
        """
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            result = reveal_secret(db, actor, _parse_secret_id(secret_id), mcp=True)
            db.commit()
            return result.model_dump(mode="json")

    @server.tool(name="create_secret")
    async def create_secret_tool(
        title: str,
        purpose: str | None = None,
        category: str | None = None,
        source: str | None = None,
        notes: str | None = None,
        tags: list[str] | None = None,
        allow_ui: bool = True,
        allow_rest_api: bool = True,
        allow_mcp: bool = False,
        fields: list[dict[str, Any]] | None = None,
        ctx: Context = None,  # type: ignore[assignment]
    ) -> dict[str, Any]:
        """Create a secret through MCP.

        The fields argument is a list of objects with name, value, encrypted,
        masked, and optional position. Passwords, tokens, PINs, private keys,
        recovery codes, and similar sensitive values must use encrypted=true.
        Non-sensitive identifiers such as usernames can remain unencrypted for
        search.
        """
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            result = create_secret(
                db,
                actor,
                SecretCreate(
                    title=title,
                    purpose=purpose,
                    category=category,
                    source=source,
                    notes=notes,
                    tags=tags or [],
                    allow_ui=allow_ui,
                    allow_rest_api=allow_rest_api,
                    allow_mcp=allow_mcp,
                    fields=fields or [],
                ),
            )
            db.commit()
            return result.model_dump(mode="json")

    @server.tool(name="update_secret")
    async def update_secret_tool(
        secret_id: str,
        title: str | None = None,
        purpose: str | None = None,
        category: str | None = None,
        source: str | None = None,
        notes: str | None = None,
        tags: list[str] | None = None,
        status: str | None = None,
        archived: bool | None = None,
        allow_ui: bool | None = None,
        allow_rest_api: bool | None = None,
        allow_mcp: bool | None = None,
        fields: list[dict[str, Any]] | None = None,
        ctx: Context = None,  # type: ignore[assignment]
    ) -> dict[str, Any]:
        """Update metadata or fields for one MCP-available secret.

        Updating fields creates a new current version while old versions remain
        historical. Use only with explicit user intent.
        """
        ip, ua = _client_info(ctx)
        # None means "leave unchanged", so only explicitly supplied values are
        # forwarded to SecretUpdate. The status is stripped before conversion,
        # consistent with search_secrets and set_secret_status.
        requested: dict[str, Any] = {
            "title": title,
            "purpose": purpose,
            "category": category,
            "source": source,
            "notes": notes,
            "tags": tags,
            "status": _parse_status(status),
            "archived": archived,
            "allow_ui": allow_ui,
            "allow_rest_api": allow_rest_api,
            "allow_mcp": allow_mcp,
            "fields": fields,
        }
        payload = {key: value for key, value in requested.items() if value is not None}
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            secret_uuid = _parse_secret_id(secret_id)
            # update_secret takes no mcp flag, so get_secret(..., mcp=True) is
            # called first as an access gate (presumably it raises for secrets
            # that are not MCP-available — confirm in services).
            get_secret(db, actor, secret_uuid, mcp=True)
            result = update_secret(db, actor, secret_uuid, SecretUpdate(**payload))
            db.commit()
            return result.model_dump(mode="json")

    @server.tool()
    async def set_secret_status(
        secret_id: str,
        status: str,
        ctx: Context = None,  # type: ignore[assignment]
    ) -> dict[str, Any]:
        """Set a secret status to actual, outdated, or archived through MCP.

        Use only when the user explicitly asks for a status change.
        """
        new_status = _parse_status(status)
        if new_status is None:
            # A blank status used to be forwarded as SecretUpdate(status=None),
            # which silently did nothing useful; reject it explicitly instead.
            raise ValueError("status must not be blank.")
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            secret_uuid = _parse_secret_id(secret_id)
            # Access gate: see update_secret_tool.
            get_secret(db, actor, secret_uuid, mcp=True)
            result = update_secret(db, actor, secret_uuid, SecretUpdate(status=new_status))
            db.commit()
            return result.model_dump(mode="json")

    @server.tool()
    async def archive_secret(secret_id: str, ctx: Context = None) -> dict[str, Any]:  # type: ignore[assignment]
        """Archive one MCP-available secret.

        Archived secrets are unavailable through normal MCP access. Use only
        when the user explicitly asks to archive it.
        """
        ip, ua = _client_info(ctx)
        with SessionLocal() as db:
            actor = _actor_from_mcp_context(db, ip_address=ip, user_agent=ua)
            secret_uuid = _parse_secret_id(secret_id)
            # Access gate: see update_secret_tool.
            get_secret(db, actor, secret_uuid, mcp=True)
            result = update_secret(db, actor, secret_uuid, SecretUpdate(archived=True))
            db.commit()
            return result.model_dump(mode="json")

    return server
