from __future__ import annotations

import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any

import anyio
import httpx
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamable_http_client
from mcp.types import Tool

from .config import McpServerConfig
# Module logger; McpClient reports connects, reconnect attempts, throttled
# reconnects, and disconnect timeouts through it.
logger = logging.getLogger(__name__)
class McpClient:
    """Lightweight wrapper around the official Python MCP client SDK.

    Manages a single server connection (stdio, SSE, or streamable HTTP),
    exposes ``list_tools()`` and ``call_tool()``, and handles lifecycle
    (connect / disconnect).

    Reconnect uses exponential backoff (base 1s, max 30s, ±20% jitter)
    so that a flapping server does not hammer the transport: the first
    retry waits about 1s and every failed attempt doubles the wait, up
    to the cap.
    """

    def __init__(self, name: str, config: McpServerConfig) -> None:
        self.name = name
        self.config = config
        self._session: ClientSession | None = None
        # Owns every async context entered by connect() (transport, HTTP
        # client, session) so _cleanup() can close them together.
        self._exit_stack = AsyncExitStack()
        self._connected = False
        self._instructions: str | None = None
        # Exponential backoff state for reconnect.
        # Monotonic time the last reconnect attempt started; None after a
        # successful connect, meaning no wait is required.
        self._last_reconnect_attempt: float | None = None
        # Base wait in seconds for the next attempt; doubled on each failure.
        self._reconnect_backoff: float = 1.0
        self._max_reconnect_backoff: float = 30.0
        # Jitter as a fraction of the base wait (0.2 -> ±20%).
        self._reconnect_jitter: float = 0.2
        # Jittered wait drawn once when the last attempt started, so every
        # backoff check compares against the same deadline.
        self._next_reconnect_delay: float = 0.0

    @property
    def connected(self) -> bool:
        """True once the initialize handshake succeeded and a session is held."""
        return self._connected and self._session is not None

    @property
    def instructions(self) -> str | None:
        """Server-provided instructions from MCP initialize handshake."""
        return self._instructions

    async def connect(self) -> None:
        """Open the transport, run the MCP initialize handshake, and store the session.

        Does nothing when already connected. A successful connect resets
        the reconnect backoff.

        Raises:
            ValueError: if the configured transport is unknown or is missing
                its required ``command`` (stdio) or ``url`` (SSE / HTTP).
            Exception: anything raised by the transport or the handshake;
                partially opened resources are closed before it propagates.
        """
        if self._connected:
            return
        try:
            read_stream, write_stream = await self._open_transport()
            session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            init_result = await session.initialize()
        except BaseException:
            # BaseException (not Exception) so a cancelled connect also closes
            # the contexts already entered on the exit stack instead of
            # leaving them open until some later cleanup.
            await self._cleanup()
            raise
        self._instructions = getattr(init_result, "instructions", None)
        self._session = session
        self._connected = True
        # Reset backoff on successful connect.
        self._reconnect_backoff = 1.0
        self._last_reconnect_attempt = None
        self._next_reconnect_delay = 0.0
        logger.info(
            "MCP server %r connected (%s)",
            self.name,
            self.config.transport,
        )

    async def _open_transport(self) -> tuple[Any, Any]:
        """Enter the configured transport on the exit stack; return (read, write) streams."""
        cfg = self.config
        if cfg.is_stdio:
            if not cfg.command:
                raise ValueError("stdio transport requires 'command'")
            params = StdioServerParameters(
                command=cfg.command,
                args=cfg.args or [],
                env=cfg.env,
                cwd=cfg.cwd,
            )
            read_stream, write_stream = await self._exit_stack.enter_async_context(
                stdio_client(params)
            )
            return read_stream, write_stream
        if cfg.is_sse:
            if not cfg.url:
                raise ValueError("sse transport requires 'url'")
            read_stream, write_stream = await self._exit_stack.enter_async_context(
                sse_client(cfg.url, headers=cfg.headers)
            )
            return read_stream, write_stream
        if cfg.is_streamable_http:
            if not cfg.url:
                raise ValueError("streamable_http transport requires 'url'")
            # httpx defaults every timeout phase to 5s, which can abort slow
            # tool calls and idle SSE response streams (and then trigger a
            # reconnect + retry). Allow 30s for connect/write/pool and 300s
            # between reads.
            http_client = await self._exit_stack.enter_async_context(
                httpx.AsyncClient(
                    headers=cfg.headers or {},
                    timeout=httpx.Timeout(30.0, read=300.0),
                )
            )
            streams = await self._exit_stack.enter_async_context(
                streamable_http_client(cfg.url, http_client=http_client)
            )
            # streamable_http_client returns (read, write, get_session_id).
            # We only need read and write for ClientSession.
            return streams[0], streams[1]
        raise ValueError(f"unknown transport: {cfg.transport}")

    async def disconnect(self) -> None:
        """Close the transport and reset state.

        Does nothing when not connected. Teardown is bounded to 5 seconds;
        a timeout is logged, and CancelledError / RuntimeError raised during
        teardown are suppressed.
        """
        if not self._connected:
            return
        try:
            await asyncio.wait_for(self._cleanup(), timeout=5.0)
        except asyncio.TimeoutError:
            # When wait_for cancels _cleanup(), its ``finally`` still resets
            # the connection state, so the client ends up disconnected.
            logger.warning("MCP server %r disconnect timed out, forcing cleanup", self.name)
        except (asyncio.CancelledError, RuntimeError):
            # Graceful shutdown during app teardown — SSE transport teardown
            # throws CancelledError / RuntimeError from anyio task scopes.
            # NOTE(review): on Python < 3.12 asyncio.wait_for runs _cleanup()
            # in a separate task, which anyio may reject when exiting cancel
            # scopes entered elsewhere — a likely source of these errors.
            pass

    async def _cleanup(self) -> None:
        """Close every context on the exit stack and reset connection state.

        Best-effort: errors raised while closing are logged at debug level
        and suppressed. State is reset even if closing is cancelled.
        """
        try:
            await self._exit_stack.aclose()
        except Exception:
            logger.debug(
                "MCP server %r: error while closing transport (ignored)",
                self.name,
                exc_info=True,
            )
        finally:
            self._session = None
            self._instructions = None
            self._connected = False
            self._exit_stack = AsyncExitStack()

    def _backoff_remaining(self) -> float:
        """Seconds until another reconnect attempt is allowed (<= 0 means now)."""
        if self._last_reconnect_attempt is None:
            return 0.0
        elapsed = time.monotonic() - self._last_reconnect_attempt
        return self._next_reconnect_delay - elapsed

    def _check_backoff(self) -> bool:
        """Return True if enough time has passed since the last reconnect attempt."""
        return self._backoff_remaining() <= 0.0

    def _schedule_reconnect_attempt(self) -> None:
        """Mark a reconnect attempt as starting now and fix the wait before the next one."""
        # Random jitter (not clock-derived) so that many clients losing the
        # same server do not retry in lockstep.
        spread = self._reconnect_backoff * self._reconnect_jitter
        self._next_reconnect_delay = self._reconnect_backoff + random.uniform(-spread, spread)
        self._last_reconnect_attempt = time.monotonic()

    async def _ensure_connected(self) -> None:
        """Reconnect if the underlying transport is dead, respecting backoff.

        Raises:
            RuntimeError: if disconnected and the backoff window has not
                elapsed yet.
            Exception: whatever connect() raises; the backoff base is then
                doubled (capped at the maximum) before re-raising.
        """
        if self.connected:
            return
        remaining = self._backoff_remaining()
        if remaining > 0.0:
            logger.warning(
                "MCP server %r reconnect blocked by backoff (%.1fs remaining)",
                self.name,
                remaining,
            )
            raise RuntimeError(
                f"MCP server {self.name!r} is disconnected and reconnect is throttled"
            )
        self._schedule_reconnect_attempt()
        logger.warning("MCP server %r disconnected, reconnecting...", self.name)
        try:
            await self._cleanup()
            await self.connect()
        except Exception:
            # Double the backoff for the next attempt (capped at max).
            self._reconnect_backoff = min(
                self._reconnect_backoff * 2,
                self._max_reconnect_backoff,
            )
            raise

    async def _run_with_reconnect(
        self,
        action: str,
        op: Callable[[ClientSession], Awaitable[Any]],
    ) -> Any:
        """Run ``op(session)``; if it raises, reconnect once and run it again.

        ``action`` only labels the warning logged for the first failure.

        NOTE(review): the retry re-sends the request, so a non-idempotent tool
        whose first attempt reached the server before failing may run twice.
        """
        await self._ensure_connected()
        try:
            return await op(self._session)
        except Exception:
            logger.warning(
                "MCP server %r: %s failed; reconnecting and retrying once",
                self.name,
                action,
                exc_info=True,
            )
            await self._cleanup()
            await self._ensure_connected()
            return await op(self._session)

    async def list_tools(self) -> list[Tool]:
        """Return the tools exposed by the remote MCP server."""
        result = await self._run_with_reconnect(
            "list_tools",
            lambda session: session.list_tools(),
        )
        return list(result.tools)

    async def call_tool(
        self, tool_name: str, arguments: dict[str, Any] | None = None
    ) -> tuple[str, bool]:
        """Execute a remote tool and return its output as a string.

        Returns (output_text, is_error). ``is_error`` comes from the MCP
        ``CallToolResult.isError`` field so the caller can set ``success=False``.
        Text content is joined with newlines; images are reported as a
        placeholder with their MIME type and decoded size; any other content
        type becomes ``[<type>]``.
        """
        result = await self._run_with_reconnect(
            f"call_tool {tool_name!r}",
            lambda session: session.call_tool(tool_name, arguments or {}),
        )
        is_error = bool(getattr(result, "isError", False))
        return self._format_content(result.content), is_error

    @staticmethod
    def _format_content(items: Any) -> str:
        """Render MCP content items as newline-joined text with placeholders."""
        parts: list[str] = []
        for item in items:
            if item.type == "text":
                parts.append(item.text)
            elif item.type == "image":
                size = McpClient._base64_decoded_size(item.data)
                parts.append(f"[image: {item.mimeType} ({size} bytes)]")
            else:
                parts.append(f"[{item.type}]")
        return "\n".join(parts)

    @staticmethod
    def _base64_decoded_size(data: str) -> int:
        """Return the number of bytes encoded by the base64 string ``data``.

        Assumes standard base64 without embedded whitespace; trailing ``=``
        padding is optional.
        """
        return len(data) * 3 // 4 - data[-2:].count("=")