diff --git a/navi/mcp/client.py b/navi/mcp/client.py index 94f5e3c..21c1d0c 100644 --- a/navi/mcp/client.py +++ b/navi/mcp/client.py @@ -2,6 +2,7 @@ import asyncio import logging +import time from contextlib import AsyncExitStack from typing import Any @@ -22,6 +23,9 @@ Manages a single server connection (stdio or SSE), exposes ``list_tools()`` and ``call_tool()``, and handles lifecycle (connect / disconnect). + + Reconnect uses exponential backoff (base 1s, max 30s, ±20% jitter) + so that a flapping server does not hammer the transport. """ def __init__(self, name: str, config: McpServerConfig) -> None: @@ -32,6 +36,12 @@ self._connected = False self._instructions: str | None = None + # Exponential backoff state for reconnect + self._last_reconnect_attempt: float | None = None + self._reconnect_backoff: float = 1.0 + self._max_reconnect_backoff: float = 30.0 + self._reconnect_jitter: float = 0.2 + @property def connected(self) -> bool: return self._connected and self._session is not None @@ -79,6 +89,9 @@ self._instructions = init_result.instructions if hasattr(init_result, "instructions") else None self._session = session self._connected = True + # Reset backoff on successful connect + self._reconnect_backoff = 1.0 + self._last_reconnect_attempt = None logger.info( "MCP server %r connected (%s)", self.name, @@ -112,12 +125,43 @@ self._connected = False self._exit_stack = AsyncExitStack() + def _check_backoff(self) -> bool: + """Return True if enough time has passed since the last reconnect attempt.""" + if self._last_reconnect_attempt is None: + return True + elapsed = time.monotonic() - self._last_reconnect_attempt + # Add jitter to prevent thundering herd + jitter = self._reconnect_backoff * self._reconnect_jitter * (2 * (time.monotonic() % 1) - 1) + return elapsed >= (self._reconnect_backoff + jitter) + async def _ensure_connected(self) -> None: - """Reconnect if the underlying transport is dead.""" - if not self._connected or self._session is None: - logger.warning("MCP server %r disconnected, reconnecting...", self.name) + """Reconnect if the underlying transport is dead, respecting backoff.""" + if self._connected and self._session is not None: + return + + if not self._check_backoff(): + remaining = self._reconnect_backoff - (time.monotonic() - (self._last_reconnect_attempt or 0)) + logger.warning( + "MCP server %r reconnect blocked by backoff (%.1fs remaining)", + self.name, + max(0, remaining), + ) + raise RuntimeError( + f"MCP server {self.name!r} is disconnected and reconnect is throttled" + ) + + self._last_reconnect_attempt = time.monotonic() + logger.warning("MCP server %r disconnected, reconnecting...", self.name) + try: await self._cleanup() await self.connect() + except Exception: + # Double the backoff for the next attempt (capped at max) + self._reconnect_backoff = min( + self._reconnect_backoff * 2, + self._max_reconnect_backoff, + ) + raise async def list_tools(self) -> list[Tool]: """Return the tools exposed by the remote MCP server.""" @@ -126,7 +170,7 @@ result = await self._session.list_tools() except Exception: await self._cleanup() - await self.connect() + await self._ensure_connected() result = await self._session.list_tools() return list(result.tools) @@ -143,7 +187,7 @@ result = await self._session.call_tool(tool_name, arguments or {}) except Exception: await self._cleanup() - await self.connect() + await self._ensure_connected() result = await self._session.call_tool(tool_name, arguments or {}) parts: list[str] = [] diff --git a/navi/mcp/manager.py b/navi/mcp/manager.py index 63a4b98..2379e7a 100644 --- a/navi/mcp/manager.py +++ b/navi/mcp/manager.py @@ -25,11 +25,18 @@ def __init__(self, config_path: str | Path | None = None) -> None: self.config_path = config_path self._clients: dict[str, McpClient] = {} + self._configs: dict[str, McpServerConfig] | None = None @property def clients(self) -> dict[str, McpClient]: return self._clients + def _get_configs(self) -> dict[str, McpServerConfig]: + """Return cached configs, falling back to disk if not yet loaded.""" + if self._configs is None: + self._configs = load_mcp_servers(self.config_path) + return self._configs + async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None: """Connect to every server in *configs* (or load from disk). @@ -37,6 +44,7 @@ """ if configs is None: configs = load_mcp_servers(self.config_path) + self._configs = configs # disconnect old await self.disconnect_all() @@ -52,7 +60,8 @@ async def reload_all(self) -> None: """Re-read the config file and reconnect every server.""" - configs = load_mcp_servers(self.config_path) + self._configs = None # bust cache + configs = self._get_configs() await self.load_all(configs) async def disconnect_all(self) -> None: @@ -92,7 +101,7 @@ Reads from the static config (``mcp_servers.d/*.json``), not from the live server, so it works even when the server is temporarily disconnected. """ - configs = load_mcp_servers(self.config_path) + configs = self._get_configs() cfg = configs.get(server_name) if cfg is None: return [] @@ -105,7 +114,7 @@ with the overlay ``instructions`` field from ``mcp_servers.d/*.json``. If a selected server is disconnected, only the config overlay is returned. """ - configs = load_mcp_servers(self.config_path) + configs = self._get_configs() out: dict[str, str] = {} if server_names is None: names = set(self._clients.keys())