# navi-1 / navi / mcp / manager.py
from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Any, Awaitable, Callable

from .client import McpClient
from .config import McpServerConfig, load_mcp_servers

logger = logging.getLogger(__name__)


class McpManager:
    """Holds a pool of :class:`McpClient` instances and manages their lifecycle.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()

    The pool is keyed by server name.  Servers that failed to connect stay in
    the pool so the optional background health check can retry them.
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        """Create an empty manager.

        Args:
            config_path: Passed verbatim to ``load_mcp_servers`` whenever the
                configuration has to be read from disk.
        """
        self.config_path = config_path
        self._clients: dict[str, McpClient] = {}
        # ``None`` means "not loaded yet"; see ``_get_configs``.
        self._configs: dict[str, McpServerConfig] | None = None

        # Callback invoked after a server successfully connects (or reconnects).
        # Signature: async def callback(server_name: str) -> None
        self._on_server_connected: Callable[[str], Awaitable[None]] | None = None

        # Background health-check task
        self._health_check_task: asyncio.Task | None = None
        self._health_check_interval: float = 30.0

        # Last known connected status per server — used to suppress duplicate
        # "connected" toast notifications on every health-check poll.
        self._connected_status: dict[str, bool] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """Return the internal name → client pool (the live dict, not a copy)."""
        return self._clients

    def set_on_server_connected(self, callback: Callable[[str], Awaitable[None]] | None) -> None:
        """Set a callback that is called whenever an MCP server comes online.

        Pass ``None`` to remove a previously installed callback.
        """
        self._on_server_connected = callback

    def start_health_check(self) -> None:
        """Start the background health-check loop if not already running.

        A task that has already finished (e.g. it was cancelled from outside)
        no longer counts as running, so a fresh loop is started in that case.
        """
        task = self._health_check_task
        if task is None or task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def stop_health_check(self) -> None:
        """Stop the background health-check loop and wait for it to finish."""
        task = self._health_check_task
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        finally:
            # Drop the reference even if the task ended with an unexpected
            # error, so a later start_health_check() is never blocked.
            self._health_check_task = None

    def _get_configs(self) -> dict[str, McpServerConfig]:
        """Return cached configs, falling back to disk if not yet loaded."""
        if self._configs is None:
            self._configs = load_mcp_servers(self.config_path)
        return self._configs

    async def _notify_connected(self, name: str) -> None:
        """Run the on-connected callback for *name*, logging (not raising) its errors."""
        callback = self._on_server_connected
        if callback is None:
            return
        try:
            await callback(name)
        except Exception:
            # A failing listener must not be mistaken for a failed connection.
            logger.exception("MCP on-connected callback failed for %r", name)

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load from disk).

        Existing clients are disconnected first so that a reload is clean.
        Servers that fail to connect are still added to the pool so the
        health-check loop can retry them later.

        Args:
            configs: Server configs keyed by server name.  When ``None`` they
                are read via ``load_mcp_servers(self.config_path)``.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)
        self._configs = configs

        # disconnect old
        await self.disconnect_all()

        # connect new
        for name, cfg in configs.items():
            client = McpClient(name, cfg)
            try:
                await client.connect()
            except Exception as exc:
                logger.warning("MCP server %r failed to connect: %s", name, exc)
                connected = False
            else:
                connected = True

            # Pool the client in both cases; a failed one is retried by the
            # health check.  Registration happens before the callback runs so
            # the callback already sees the client in the pool.
            self._clients[name] = client
            self._connected_status[name] = connected
            if connected:
                await self._notify_connected(name)

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server."""
        self._configs = None  # bust cache
        configs = self._get_configs()
        await self.load_all(configs)

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        Errors raised by an individual ``disconnect()`` are logged and do not
        stop the remaining clients from being closed.
        """
        if not self._clients:
            return
        for name, client in list(self._clients.items()):
            try:
                await client.disconnect()
            except asyncio.CancelledError:
                # NOTE(review): cancellation raised while disconnecting is
                # swallowed so the remaining clients are still closed.  This
                # also suppresses cancellation of the calling task — confirm
                # that is intended.
                pass
            except Exception as exc:
                logger.warning("MCP server %r disconnect error: %s", name, exc)
        self._clients.clear()
        self._connected_status.clear()

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers that fail to list tools are skipped gracefully.
        """
        out: list[tuple[str, Any]] = []
        # Iterate a snapshot: the pool may be rebuilt (reload) while we are
        # suspended in ``list_tools()``, and mutating a dict during iteration
        # raises RuntimeError.
        for name, client in list(self._clients.items()):
            try:
                tools = await client.list_tools()
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
                continue
            out.extend((name, t) for t in tools)
        return out

    def resolve_group(self, server_name: str, group_name: str) -> list[str]:
        """Return the list of tool names in a server group.

        Reads from the static config (``mcp_servers.d/*.json``), not from the live
        server, so it works even when the server is temporarily disconnected.
        An unknown server or group yields an empty list.  The result is a new
        list, so callers may mutate it freely.
        """
        cfg = self._get_configs().get(server_name)
        if cfg is None:
            return []
        return list(cfg.groups.get(group_name, []))

    def get_instructions(self, server_names: list[str] | set[str] | None = None) -> dict[str, str]:
        """Return combined instructions per server.

        Server-provided instructions (from MCP initialize handshake) are merged
        with the overlay ``instructions`` field from ``mcp_servers.d/*.json``,
        separated by a blank line.
        If a selected server is disconnected, only the config overlay is returned.

        Args:
            server_names: Servers to include.  ``None`` selects every server
                currently in the pool.

        Returns:
            Mapping of server name to instruction text.  Servers with no
            instructions at all are omitted.  Entries follow the order of
            *server_names* (or of the pool), with duplicates removed.
        """
        configs = self._get_configs()
        selected = self._clients if server_names is None else server_names
        # dict.fromkeys de-duplicates while keeping first-seen order, so the
        # result is stable across runs (a plain set() is not, for strings).
        names = list(dict.fromkeys(selected))

        out: dict[str, str] = {}
        for name in names:
            parts: list[str] = []
            client = self._clients.get(name)
            if client and client.instructions:
                parts.append(client.instructions)
            cfg = configs.get(name)
            if cfg and cfg.instructions:
                if parts:
                    parts.append("")  # blank line between the two sources
                parts.append(cfg.instructions)
            if parts:
                out[name] = "\n".join(parts)
        return out

    async def call_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None) -> tuple[str, bool]:
        """Proxy a tool call to the named server.

        Returns (output_text, is_error) so the caller knows whether the MCP
        tool itself reported a failure.

        Raises:
            RuntimeError: If no client for *server_name* is in the pool.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)

    # ── Health check ─────────────────────────────────────────────────────────

    async def _health_check_loop(self) -> None:
        """Background task that periodically probes every configured server.

        Runs until cancelled.  Any other error from a probe round is logged
        and the loop carries on after a short pause.
        """
        while True:
            try:
                await asyncio.sleep(self._health_check_interval)
                await self._run_health_check()
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("MCP health-check loop error")
                # Brief sleep to avoid tight error loops
                await asyncio.sleep(5.0)

    async def _run_health_check(self) -> None:
        """Probe all servers: reconnect dead ones, verify live ones are still alive."""
        # Imported here rather than at module level, as in the original code.
        from navi.core.event_bus import get_event_bus
        from navi.core.events import McpStatusUpdate

        for name, client in list(self._clients.items()):
            if not client.connected:
                # Dead server — try to bring it back
                try:
                    await client.connect()
                except Exception as exc:
                    logger.debug("MCP health-check reconnect failed for %r: %s", name, exc)
                    continue
                logger.info("MCP server %r reconnected by health check", name)
                self._connected_status[name] = True
                await self._notify_connected(name)
                continue

            # Live server — make sure it still responds
            try:
                tools = await client.list_tools()
            except Exception:
                logger.warning("MCP server %r dropped during health check", name)
                client.mark_disconnected()
                self._connected_status[name] = False
                await get_event_bus().publish(
                    McpStatusUpdate(server_name=name, status="disconnected")
                )
                continue

            # Server is still connected — only notify if it was previously
            # known as disconnected (recovery), not on every routine poll.
            if self._connected_status.get(name):
                continue
            self._connected_status[name] = True
            try:
                # Reuse the tool list from the liveness probe above instead of
                # asking the server a second time.
                await get_event_bus().publish(
                    McpStatusUpdate(
                        server_name=name,
                        status="connected",
                        tool_count=len(tools),
                    )
                )
            except Exception:
                # Best-effort notification: record the failure, keep polling.
                logger.debug("MCP status publish failed for %r", name, exc_info=True)