from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any, Awaitable, Callable
from .client import McpClient
from .config import McpServerConfig, load_mcp_servers
# Module logger for connection, disconnection and health-check diagnostics.
logger = logging.getLogger(name=__name__)
class McpManager:
    """Holds a pool of :class:`McpClient` instances and manages their lifecycle.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()

    Servers that fail to connect stay in the pool. The optional background
    health-check task (see :meth:`start_health_check`) periodically retries
    them and probes the ones that are connected.
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        # Passed straight to ``load_mcp_servers`` when configs are read from disk.
        self.config_path = config_path
        # server name -> client; failed clients are kept so they can be retried.
        self._clients: dict[str, McpClient] = {}
        # Cached configs; ``None`` means "not loaded yet" (see ``_get_configs``).
        self._configs: dict[str, McpServerConfig] | None = None
        # Callback invoked after a server successfully connects (or reconnects).
        # Signature: async def callback(server_name: str) -> None
        self._on_server_connected: Callable[[str], Awaitable[None]] | None = None
        # Background health-check task
        self._health_check_task: asyncio.Task | None = None
        # Seconds slept between health-check passes.
        self._health_check_interval: float = 30.0
        # Last known connected status per server — used to suppress duplicate
        # "connected" toast notifications on every health-check poll.
        self._connected_status: dict[str, bool] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """The client pool keyed by server name (the live dict, not a copy)."""
        return self._clients

    def set_on_server_connected(self, callback: Callable[[str], Awaitable[None]] | None) -> None:
        """Set a callback that is called whenever an MCP server comes online.

        Pass ``None`` to clear it. Exceptions raised by the callback are logged
        and do not change the server's recorded connection status.
        """
        self._on_server_connected = callback

    def start_health_check(self) -> None:
        """Start the background health-check loop if it is not already running.

        A previous task that has already finished (e.g. it was cancelled
        elsewhere) is replaced by a fresh one. Requires a running event loop.
        """
        task = self._health_check_task
        if task is None or task.done():
            self._health_check_task = asyncio.create_task(self._health_check_loop())

    async def stop_health_check(self) -> None:
        """Cancel the background health-check loop and wait for it to exit."""
        task = self._health_check_task
        if task is None:
            return
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass
        self._health_check_task = None

    def _get_configs(self) -> dict[str, McpServerConfig]:
        """Return cached configs, loading them from disk on first use."""
        if self._configs is None:
            self._configs = load_mcp_servers(self.config_path)
        return self._configs

    def _is_current(self, name: str, client: McpClient) -> bool:
        """Return True if *client* is still the pooled client for *name*.

        ``load_all()``/``reload_all()`` may swap the pool while the health
        check is suspended at an ``await``; stale clients must not update
        shared state.
        """
        return self._clients.get(name) is client

    async def _notify_connected(self, name: str) -> None:
        """Run the on-connected callback, logging (not raising) its failures.

        Kept apart from ``client.connect()`` so that a failing callback is not
        reported or recorded as a failed connection.
        """
        if self._on_server_connected is None:
            return
        try:
            await self._on_server_connected(name)
        except Exception:
            logger.exception("MCP on-connected callback failed for %r", name)

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load them from disk).

        Existing clients are disconnected first so that a reload is clean.
        Servers that fail to connect are still added to the pool so the
        health-check loop can retry them later.

        Args:
            configs: server name -> config. ``None`` re-reads ``config_path``.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)
        self._configs = configs

        await self.disconnect_all()

        for name, cfg in configs.items():
            client = McpClient(name, cfg)
            try:
                await client.connect()
            except Exception as exc:
                logger.warning("MCP server %r failed to connect: %s", name, exc)
                connected = False
            else:
                connected = True
            # Added only after the attempt, so the health check never races a
            # connect that is still in flight; failed clients stay for retry.
            self._clients[name] = client
            self._connected_status[name] = connected
            if connected:
                await self._notify_connected(name)

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server."""
        self._configs = None  # bust cache
        await self.load_all(self._get_configs())

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        Per-client disconnect errors are logged and ignored so the remaining
        clients are still closed.
        """
        if not self._clients:
            return
        for name, client in list(self._clients.items()):
            try:
                await client.disconnect()
            except asyncio.CancelledError:
                # NOTE(review): this also swallows a cancellation aimed at the
                # calling task; presumably deliberate so one client's teardown
                # cannot abort the rest — confirm.
                pass
            except Exception as exc:
                logger.warning("MCP server %r disconnect error: %s", name, exc)
        self._clients.clear()
        self._connected_status.clear()

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers whose ``list_tools()`` raises are logged and skipped.
        """
        out: list[tuple[str, Any]] = []
        for name, client in self._clients.items():
            try:
                tools = await client.list_tools()
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
                continue
            out.extend((name, t) for t in tools)
        return out

    def resolve_group(self, server_name: str, group_name: str) -> list[str]:
        """Return the list of tool names in a server group.

        Reads from the static config (``mcp_servers.d/*.json``), not from the
        live server, so it works even when the server is disconnected. Unknown
        servers or groups yield an empty list.
        """
        cfg = self._get_configs().get(server_name)
        if cfg is None:
            return []
        return list(cfg.groups.get(group_name, []))

    def get_instructions(self, server_names: list[str] | set[str] | None = None) -> dict[str, str]:
        """Return combined instructions per server.

        Server-provided instructions (from the MCP initialize handshake) are
        followed by a blank line and the ``instructions`` overlay from
        ``mcp_servers.d/*.json``. A server with neither is omitted; a server
        without a client (or without server instructions) contributes only
        its config overlay.

        Args:
            server_names: servers to include; ``None`` means every pooled client.

        Returns:
            server name -> instructions, in sorted server-name order.
        """
        configs = self._get_configs()
        names = self._clients.keys() if server_names is None else server_names
        out: dict[str, str] = {}
        # Sorted so the result order does not depend on set hashing.
        for name in sorted(set(names)):
            parts: list[str] = []
            client = self._clients.get(name)
            if client and client.instructions:
                parts.append(client.instructions)
            cfg = configs.get(name)
            if cfg and cfg.instructions:
                if parts:
                    parts.append("")
                parts.append(cfg.instructions)
            if parts:
                out[name] = "\n".join(parts)
        return out

    async def call_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None) -> tuple[str, bool]:
        """Proxy a tool call to the named server.

        Returns (output_text, is_error) so the caller knows whether the MCP
        tool itself reported a failure.

        Raises:
            RuntimeError: if *server_name* is not in the client pool.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)

    # ── Health check ─────────────────────────────────────────────────────────

    async def _health_check_loop(self) -> None:
        """Background task that periodically probes every configured server."""
        while True:
            try:
                await asyncio.sleep(self._health_check_interval)
                await self._run_health_check()
            except asyncio.CancelledError:
                raise
            except Exception:
                logger.exception("MCP health-check loop error")
                # Brief sleep to avoid tight error loops
                await asyncio.sleep(5.0)

    async def _try_reconnect(self, name: str, client: McpClient) -> None:
        """Try to bring a disconnected *client* back; notify on success."""
        try:
            await client.connect()
        except Exception as exc:
            logger.debug("MCP health-check reconnect failed for %r: %s", name, exc)
            return
        if not self._is_current(name, client):
            # The pool was reloaded while we were connecting: this client is no
            # longer tracked, so close it instead of leaking the connection.
            try:
                await client.disconnect()
            except asyncio.CancelledError:
                pass  # same tolerance as disconnect_all()
            except Exception as exc:
                logger.warning("MCP server %r disconnect error: %s", name, exc)
            return
        logger.info("MCP server %r reconnected by health check", name)
        self._connected_status[name] = True
        await self._notify_connected(name)

    async def _run_health_check(self) -> None:
        """Probe all servers once: reconnect dead ones, verify live ones still respond.

        Publishes ``McpStatusUpdate(status="disconnected")`` when a live server
        stops responding, and ``status="connected"`` (with the tool count) for a
        responding server whose last recorded status was not connected.
        """
        from navi.core.event_bus import get_event_bus
        from navi.core.events import McpStatusUpdate

        for name, client in list(self._clients.items()):
            # Skip clients replaced/removed by a concurrent reload.
            if not self._is_current(name, client):
                continue

            if not client.connected:
                await self._try_reconnect(name, client)
                continue

            # Live server — make sure it still responds.
            tools: list[Any] = []
            try:
                tools = await client.list_tools()
            except Exception:
                alive = False
            else:
                alive = True

            if not self._is_current(name, client):
                continue

            if not alive:
                logger.warning("MCP server %r dropped during health check", name)
                client.mark_disconnected()
                self._connected_status[name] = False
                await get_event_bus().publish(
                    McpStatusUpdate(server_name=name, status="disconnected")
                )
                continue

            # Server is still connected — only notify if it was previously
            # known as disconnected (recovery), not on every routine poll.
            if self._connected_status.get(name):
                continue
            self._connected_status[name] = True
            try:
                await get_event_bus().publish(
                    McpStatusUpdate(
                        server_name=name,
                        status="connected",
                        tool_count=len(tools),
                    )
                )
            except Exception:
                # Best-effort notification; record why it was lost.
                logger.debug("MCP 'connected' status publish failed for %r", name, exc_info=True)