Newer
Older
navi-1 / navi / mcp / config.py
from __future__ import annotations

import json
import logging
import shutil
from pathlib import Path
from typing import Literal

from pydantic import BaseModel, Field

logger = logging.getLogger(__name__)


class McpServerConfig(BaseModel):
    """Configuration for a single MCP server.

    ``transport`` selects how the server is reached. The remaining fields are
    grouped by the transport they belong to, but no validator here enforces
    that the fields matching ``transport`` are actually set (e.g. ``command``
    for stdio, ``url`` for the network transports).

    Instances are created from JSON with ``model_validate`` and written back
    with ``model_dump`` by the loader/saver functions in this module, so the
    declaration order below is also the key order of the saved JSON files.
    """

    # Connection method; "stdio" unless the config says otherwise.
    transport: Literal["stdio", "sse", "streamable_http"] = "stdio"

    # stdio fields (presumably the command line, environment overrides and
    # working directory used to launch the server process — TODO confirm
    # against the client code that consumes them)
    command: str | None = None
    args: list[str] = Field(default_factory=list)
    env: dict[str, str] | None = None
    cwd: str | None = None

    # Network fields: the server endpoint and extra request headers. Labelled
    # "sse" originally; presumably also used for "streamable_http" since it
    # has no fields of its own — TODO confirm.
    url: str | None = None
    headers: dict[str, str] | None = None

    # tool groups: name -> list of tool names exposed by this server.
    # Profiles reference groups by name instead of listing individual tools.
    groups: dict[str, list[str]] = Field(default_factory=dict)

    # Overlay instructions injected into Navi's system prompt alongside the
    # instructions provided by the MCP server itself during the initialize handshake.
    instructions: str | None = None

    @property
    def is_stdio(self) -> bool:
        """Return True when ``transport`` is ``"stdio"``."""
        return self.transport == "stdio"

    @property
    def is_sse(self) -> bool:
        """Return True when ``transport`` is ``"sse"``."""
        return self.transport == "sse"

    @property
    def is_streamable_http(self) -> bool:
        """Return True when ``transport`` is ``"streamable_http"``."""
        return self.transport == "streamable_http"


def _default_dir() -> Path:
    """Locate the directory holding one ``<name>.json`` file per MCP server.

    The path is relative, so it resolves against the current working
    directory at the moment it is used.
    """
    dirname = "mcp_servers.d"
    return Path(dirname)


def _default_legacy_file() -> Path:
    """Locate the old single-file MCP config (relative to the working directory)."""
    legacy_name = "mcp_servers.json"
    return Path(legacy_name)


def _migrate_if_needed() -> None:
    """Auto-migrate legacy ``mcp_servers.json`` to ``mcp_servers.d/``.

    Called transparently by :func:`load_mcp_servers` when the legacy file
    exists but the directory does not. Each top-level key of the legacy JSON
    object becomes ``<name>.json`` in the new directory, holding that key's
    value verbatim (it is not validated here).

    The files are first written into a sibling staging directory, which is
    renamed into place only after every file was written. On any failure
    (unreadable or malformed JSON, a top level that is not an object, a
    server name that is not a plain filename, an I/O error) the staging
    directory is removed and the legacy file is left in place. The next load
    therefore keeps reading the legacy file instead of an empty or partially
    populated directory. Failures are logged, never raised.
    """
    legacy = _default_legacy_file()
    target_dir = _default_dir()

    if not legacy.exists() or legacy.is_dir():
        return
    if target_dir.exists():
        return

    # Stage next to the target so the final rename stays on one filesystem.
    staging_dir = target_dir.with_name(f"{target_dir.name}.migrating")
    try:
        raw = json.loads(legacy.read_text(encoding="utf-8"))
        if not isinstance(raw, dict):
            raise ValueError(
                f"{legacy} must contain a JSON object, got {type(raw).__name__}"
            )
        # Remove leftovers of an earlier, interrupted attempt.
        shutil.rmtree(staging_dir, ignore_errors=True)
        staging_dir.mkdir(parents=True)
        for name, cfg_data in raw.items():
            # The server name becomes a filename: refuse names that would
            # create a nested path or escape the directory (e.g. "../x").
            if not name or Path(name).name != name:
                raise ValueError(f"MCP server name {name!r} is not a valid filename")
            (staging_dir / f"{name}.json").write_text(
                json.dumps(cfg_data, indent=2, ensure_ascii=False) + "\n",
                encoding="utf-8",
            )
        staging_dir.rename(target_dir)
    except Exception:
        # Never leave a half-written directory behind: it would shadow the
        # intact legacy file on the next load.
        shutil.rmtree(staging_dir, ignore_errors=True)
        logger.warning("MCP config migration failed", exc_info=True)
        return

    logger.info(
        "MCP config migrated from %s to %s (%s servers)",
        legacy,
        target_dir,
        len(raw),
    )
    # Rename the legacy file so it is no longer picked up. The directory is
    # already complete and takes precedence when loading, so a failure here
    # is only worth a warning.
    try:
        legacy.rename(legacy.with_suffix(".json.bak"))
    except OSError:
        logger.warning("Could not rename migrated legacy MCP config %s", legacy, exc_info=True)


def _pick_default_source() -> Path | None:
    """Choose the default location to load from, migrating the legacy file first if possible.

    Prefers the per-server directory, falls back to the legacy monolithic
    file, and returns None when neither is present.
    """
    legacy_file = _default_legacy_file()
    config_dir = _default_dir()

    migration_possible = (
        legacy_file.exists() and not legacy_file.is_dir() and not config_dir.exists()
    )
    if migration_possible:
        _migrate_if_needed()

    if config_dir.is_dir():
        return config_dir
    if legacy_file.is_file():
        return legacy_file
    return None


def load_mcp_servers(path: str | Path | None = None) -> dict[str, McpServerConfig]:
    """Load MCP server configurations, keyed by server name.

    With *path* None the default locations are used. An existing legacy
    ``mcp_servers.json`` is first migrated to ``mcp_servers.d/`` when that
    directory does not exist yet. Then the directory is read if present,
    otherwise the legacy file.

    A directory is read file by file: every ``*.json`` inside it is one
    server named after the file's stem. Files that fail to parse or validate
    are logged and skipped.

    A file is read as the legacy monolithic format (one JSON object mapping
    names to configs). No migration is attempted for an explicitly given
    file. Any parse/validation error propagates to the caller.

    An empty dict is returned when there is nothing to read.
    """
    source = _pick_default_source() if path is None else Path(path)
    if source is None:
        return {}

    if source.is_dir():
        servers: dict[str, McpServerConfig] = {}
        for json_file in sorted(source.glob("*.json")):
            try:
                payload = json.loads(json_file.read_text(encoding="utf-8"))
                servers[json_file.stem] = McpServerConfig.model_validate(payload)
            except Exception:
                logger.warning("Failed to load MCP config from %s", json_file, exc_info=True)
        return servers

    if not source.is_file():
        return {}

    payload = json.loads(source.read_text(encoding="utf-8"))
    return {
        server_name: McpServerConfig.model_validate(entry)
        for server_name, entry in payload.items()
    }


def _is_valid_server_name(name: str) -> bool:
    """Return True if *name* can be used verbatim as one filename stem in a directory.

    Rejects the empty string and anything that is not a single path component
    (separators, drive prefixes, ``.``). Otherwise ``<dir>/<name>.json`` could
    point at a nested path or outside the config directory.
    """
    return bool(name) and Path(name).name == name


def save_mcp_servers(
    configs: dict[str, McpServerConfig],
    path: str | Path | None = None,
) -> None:
    """Write MCP server configurations.

    Directory mode is used when *path* is None (the default ``mcp_servers.d``
    directory), an existing directory, or a not-yet-existing path ending in
    ``.d``. The directory is created if needed, and each server is written to
    its own ``<name>.json`` file. Afterwards every ``*.json`` file in the
    directory whose stem is not a key of *configs* is deleted; subdirectories
    matching the pattern are left alone.

    Otherwise *path* is treated as a file and all servers are written to it
    in the legacy monolithic format (one JSON object keyed by server name).

    Raises:
        ValueError: in directory mode, if any server name is empty or not a
            single path component (e.g. ``"../x"`` or ``"a/b"``). Nothing is
            written in that case.
    """
    path = _default_dir() if path is None else Path(path)

    if path.is_dir() or (not path.exists() and str(path).endswith(".d")):
        # Validate every name before touching the disk so a bad one can
        # neither write outside the directory nor leave it half-updated.
        bad_names = [name for name in configs if not _is_valid_server_name(name)]
        if bad_names:
            raise ValueError(f"Invalid MCP server name(s) for per-server files: {bad_names!r}")
        path.mkdir(parents=True, exist_ok=True)
        for name, cfg in configs.items():
            (path / f"{name}.json").write_text(
                json.dumps(cfg.model_dump(), indent=2, ensure_ascii=False) + "\n",
                encoding="utf-8",
            )
        # Clean up files of servers that are no longer configured. A directory
        # named "*.json" cannot be unlinked and was never a server file.
        current_names = set(configs)
        for file_path in path.glob("*.json"):
            if file_path.stem not in current_names and not file_path.is_dir():
                file_path.unlink()
    else:
        raw = {name: cfg.model_dump() for name, cfg in configs.items()}
        path.write_text(json.dumps(raw, indent=2, ensure_ascii=False) + "\n", encoding="utf-8")