Newer
Older
navi-1 / navi / mcp / manager.py
from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Any

from .client import McpClient
from .config import McpServerConfig, load_mcp_servers

logger = logging.getLogger(__name__)


class McpManager:
    """Holds a pool of :class:`McpClient` instances and manages their lifecycle.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()

    Methods that await while walking the pool work on a snapshot of it. As a
    result, a concurrent reload cannot break an in-flight :meth:`get_all_tools`
    or :meth:`disconnect_all` with
    ``RuntimeError: dictionary changed size during iteration``.
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        """Create an empty manager.

        :param config_path: Passed to ``load_mcp_servers`` whenever configs are
            read from disk; ``None`` is forwarded unchanged.
        """
        self.config_path = config_path
        # server name -> client; only clients whose connect() succeeded are stored.
        self._clients: dict[str, McpClient] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """The live pool of connected clients, keyed by server name.

        This is the internal dict itself, not a copy.
        """
        return self._clients

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load from disk).

        Existing clients are disconnected first so that a reload is clean.
        A server that fails to connect is logged and left out of the pool.
        Its failure never aborts the remaining connections.

        :param configs: Mapping of server name to config. When ``None``, the
            mapping is read via ``load_mcp_servers(self.config_path)``.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)

        # disconnect old
        await self.disconnect_all()

        # connect new
        for name, cfg in configs.items():
            client = McpClient(name, cfg)
            try:
                await client.connect()
            except Exception as exc:
                logger.warning("MCP server %r failed to connect: %s", name, exc)
                # A failed connect may have left partially-opened resources
                # behind; release them best-effort instead of just dropping
                # the client.
                # NOTE(review): assumes disconnect() is safe to call after a
                # failed connect(); any error it raises is logged and ignored.
                await self._close_quietly(name, client)
                continue
            self._clients[name] = client

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server."""
        configs = load_mcp_servers(self.config_path)
        await self.load_all(configs)

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        The pool is emptied *before* the disconnects are awaited. Concurrent
        callers therefore never see (or route calls to) a client that is being
        torn down, and the pool is cleared even if this coroutine is cancelled
        while waiting. Individual disconnect errors are logged, never raised.
        """
        if not self._clients:
            return
        # Detach the pool up front. The gather below yields to the event loop,
        # and other coroutines may touch self._clients meanwhile. Results are
        # paired with this snapshot, never with the live dict.
        pending = list(self._clients.items())
        self._clients.clear()
        results = await asyncio.gather(
            *(client.disconnect() for _, client in pending),
            return_exceptions=True,
        )
        for (name, _), result in zip(pending, results):
            # BaseException so that a child's CancelledError is reported too.
            if isinstance(result, BaseException):
                logger.warning("MCP server %r disconnect error: %s", name, result)

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers that fail to list tools are skipped gracefully. Servers are
        queried one at a time, in pool order.
        """
        out: list[tuple[str, Any]] = []
        # Iterate a snapshot. Each await yields to the event loop, and a
        # concurrent reload may add or remove entries in self._clients.
        for name, client in list(self._clients.items()):
            try:
                tools = await client.list_tools()
                out.extend((name, t) for t in tools)
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
        return out

    async def call_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
        """Proxy a tool call to the named server.

        :param server_name: Key of a connected server in the pool.
        :param tool_name: Name of the tool to invoke on that server.
        :param arguments: Tool arguments, forwarded unchanged (may be ``None``).
        :returns: The result of ``McpClient.call_tool`` for that server.
        :raises RuntimeError: If *server_name* is not currently in the pool.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)

    @staticmethod
    async def _close_quietly(name: str, client: McpClient) -> None:
        """Disconnect *client* after a failed connect, logging (not raising) errors."""
        try:
            await client.disconnect()
        except Exception as exc:
            logger.debug("MCP server %r cleanup after failed connect raised: %s", name, exc)