from __future__ import annotations
import asyncio
import logging
from pathlib import Path
from typing import Any
from .client import McpClient
from .config import McpServerConfig, load_mcp_servers
# Module-level logger. McpManager reports per-server connect, disconnect and
# list_tools failures through it as warnings instead of raising them.
logger = logging.getLogger(__name__)
class McpManager:
    """Holds a pool of :class:`McpClient` instances and manages their lifecycle.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()

    Clients are keyed by the server name from the configuration. Only servers
    whose ``connect()`` succeeded are kept in the pool.
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        """Create an empty manager.

        Args:
            config_path: Location of the MCP server config. It is passed
                unchanged to :func:`load_mcp_servers`, which decides how a
                ``None`` value is resolved.
        """
        self.config_path = config_path
        # server name -> successfully connected client
        self._clients: dict[str, McpClient] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """The live client pool, keyed by server name.

        This is the manager's own dict, not a copy. It is emptied in place by
        :meth:`disconnect_all`.
        """
        return self._clients

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load them from disk).

        Existing clients are disconnected first so that a reload is clean.
        The configs are loaded *before* the old pool is torn down. An error
        from :func:`load_mcp_servers` therefore propagates and leaves the
        current connections untouched.

        A server that fails to connect is logged and left out of the pool. It
        does not stop the remaining servers from connecting.

        Args:
            configs: Mapping of server name to config. When ``None``, the
                configs are read via :func:`load_mcp_servers` from
                :attr:`config_path`.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)

        await self.disconnect_all()

        for name, cfg in configs.items():
            client = McpClient(name, cfg)
            try:
                await client.connect()
            except Exception as exc:
                # NOTE(review): the failed client is dropped without calling
                # disconnect(); confirm McpClient.connect() releases any
                # partially-opened transport when it raises.
                logger.warning("MCP server %r failed to connect: %s", name, exc)
            else:
                self._clients[name] = client

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server.

        Equivalent to calling :meth:`load_all` without explicit configs.
        """
        configs = load_mcp_servers(self.config_path)
        await self.load_all(configs)

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        All disconnects run concurrently. Errors are logged per server and
        never raised, so one misbehaving server cannot keep the others open.
        """
        if not self._clients:
            return
        # Snapshot and empty the pool *before* awaiting. If we cleared it
        # afterwards, any coroutine that mutated the pool while we were
        # suspended would cause two problems: names would be mis-paired with
        # results, and clients it had just added would be discarded without
        # ever being disconnected. clear() keeps the dict identity exposed by
        # the `clients` property.
        pending = list(self._clients.items())
        self._clients.clear()
        results = await asyncio.gather(
            *(client.disconnect() for _, client in pending),
            return_exceptions=True,
        )
        for (name, _), result in zip(pending, results):
            # gather(return_exceptions=True) also hands back CancelledError,
            # which derives from BaseException (not Exception) on Python 3.8+.
            if isinstance(result, BaseException):
                logger.warning("MCP server %r disconnect error: %s", name, result)

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers that fail to list tools are logged and skipped.
        """
        out: list[tuple[str, Any]] = []
        # Iterate over a snapshot. The pool may be mutated (for example by a
        # reload) while this coroutine is suspended on list_tools(). Iterating
        # the live dict would then raise "dictionary changed size during
        # iteration" outside the try block.
        for name, client in list(self._clients.items()):
            try:
                tools = await client.list_tools()
                out.extend((name, tool) for tool in tools)
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
        return out

    async def call_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
        """Proxy a tool call to the named server.

        Args:
            server_name: Key of a connected client in the pool.
            tool_name: Tool to invoke on that server.
            arguments: Tool arguments, forwarded as-is (may be ``None``).

        Returns:
            Whatever the client's ``call_tool`` returns.

        Raises:
            RuntimeError: If no connected client exists for *server_name*.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)