from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
from .client import McpClient
from .config import McpServerConfig, load_mcp_servers
logger = logging.getLogger(__name__)


class McpManager:
    """Hold a pool of :class:`McpClient` instances and manage their lifecycle.

    Clients are keyed by server name. Only servers whose client connected
    successfully are kept in the pool; failures are logged and skipped so one
    broken server does not prevent the others from loading.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        # Forwarded as-is to ``load_mcp_servers`` on every (re)load and on
        # every static-config lookup (groups, instruction overlays).
        self.config_path = config_path
        # server name -> connected client
        self._clients: dict[str, McpClient] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """The live client pool (server name -> client).

        This is the manager's internal dict, not a copy; mutating it mutates
        the pool.
        """
        return self._clients

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load them from disk).

        Args:
            configs: Server name -> config. When ``None`` the configs are read
                with ``load_mcp_servers(self.config_path)``.

        Existing clients are disconnected first so that a reload is clean.
        The configs are resolved *before* anything is disconnected, so a
        failure to read them leaves the current pool untouched. A server whose
        client cannot be constructed or connected is logged and skipped.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)
        await self.disconnect_all()
        for name, cfg in configs.items():
            try:
                # Construction is inside the try too: a bad config entry must
                # only skip that one server, not abort loading the rest.
                client = McpClient(name, cfg)
                await client.connect()
            except Exception as exc:
                # NOTE(review): a client whose connect() failed is dropped
                # without disconnect(); confirm McpClient.connect() releases
                # any partially acquired resources on failure.
                logger.warning("MCP server %r failed to connect: %s", name, exc)
                continue
            self._clients[name] = client

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server."""
        configs = load_mcp_servers(self.config_path)
        await self.load_all(configs)

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        All clients are disconnected concurrently; individual disconnect
        errors are logged, not raised.
        """
        if not self._clients:
            return
        # Snapshot and empty the pool before awaiting. Names then stay paired
        # with their gather() results even if the pool is touched while we
        # wait, and the pool is cleared no matter how the await ends.
        snapshot = list(self._clients.items())
        self._clients.clear()
        try:
            results = await asyncio.gather(
                *(client.disconnect() for _, client in snapshot),
                return_exceptions=True,
            )
        except (asyncio.CancelledError, RuntimeError):
            # Shutdown-time cancellation from anyio task scopes — safe to ignore
            return
        # gather() returns results in argument order, matching the snapshot.
        for (name, _), result in zip(snapshot, results):
            if isinstance(result, Exception):
                logger.warning("MCP server %r disconnect error: %s", name, result)

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers are queried one after another in pool order. Servers that
        fail to list tools are logged and skipped.
        """
        out: list[tuple[str, Any]] = []
        for name, client in self._clients.items():
            try:
                tools = await client.list_tools()
                out.extend((name, t) for t in tools)
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
        return out

    def resolve_group(self, server_name: str, group_name: str) -> list[str]:
        """Return the list of tool names in a server group.

        Reads from the static config (``mcp_servers.json``), not from the live
        server, so it works even when the server is temporarily disconnected.
        Returns ``[]`` for an unknown server or group.
        """
        configs = load_mcp_servers(self.config_path)
        cfg = configs.get(server_name)
        if cfg is None:
            return []
        return list(cfg.groups.get(group_name, []))

    def get_instructions(
        self, server_names: list[str] | set[str] | None = None
    ) -> dict[str, str]:
        """Return combined instructions for the selected servers.

        Args:
            server_names: Servers to include. ``None`` means every connected
                server.

        Server-provided instructions (from MCP initialize handshake) are merged
        with the overlay ``instructions`` field from ``mcp_servers.json``,
        separated by a blank line. If a selected server is disconnected, only
        the config overlay is returned; servers with neither are omitted.

        The result order is deterministic: pool order for ``None``, the
        caller's order (duplicates dropped) for a sequence, sorted for a set.
        """
        configs = load_mcp_servers(self.config_path)
        if server_names is None:
            names: list[str] = list(self._clients)
        elif isinstance(server_names, (set, frozenset)):
            # Set iteration order follows per-process randomized str hashing;
            # sort so the combined output is reproducible between runs.
            names = sorted(server_names)
        else:
            names = list(dict.fromkeys(server_names))
        out: dict[str, str] = {}
        for name in names:
            parts: list[str] = []
            client = self._clients.get(name)
            if client and client.instructions:
                parts.append(client.instructions)
            cfg = configs.get(name)
            if cfg and cfg.instructions:
                if parts:
                    parts.append("")  # blank line between server text and overlay
                parts.append(cfg.instructions)
            if parts:
                out[name] = "\n".join(parts)
        return out

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
    ) -> str:
        """Proxy a tool call to the named server.

        Raises:
            RuntimeError: if *server_name* is not in the connected pool.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)