# navi-1 / navi / mcp / client.py
# Last change (Eugene Sukhodolskiy, 24 May): Fix MCP tool registration at startup
from __future__ import annotations

import asyncio
import logging
import random
import time
from collections.abc import Awaitable, Callable
from contextlib import AsyncExitStack
from typing import Any

import anyio
import httpx
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.client.streamable_http import streamable_http_client
from mcp.types import Tool

from .config import McpServerConfig

logger = logging.getLogger(__name__)


class McpClient:
    """Lightweight wrapper around the official Python MCP client SDK.

    Manages a single server connection over one of the configured transports
    (stdio, SSE, or streamable HTTP), exposes ``list_tools()`` and
    ``call_tool()``, and handles lifecycle (connect / disconnect).

    Reconnects are throttled with exponential backoff: after a failed
    reconnect the client waits ~1s before trying again, doubling after each
    further failure up to a 30s base, with every wait randomised by +/-20%
    so that a flapping server does not get hammered in lock-step. A
    successful connect resets the backoff.
    """

    def __init__(self, name: str, config: McpServerConfig) -> None:
        self.name = name
        self.config = config
        self._session: ClientSession | None = None
        # Owns every async context entered by connect() (transport, HTTP
        # client, session); _cleanup() closes it and installs a fresh one.
        self._exit_stack = AsyncExitStack()
        self._connected = False
        self._instructions: str | None = None

        # Exponential backoff state for reconnect
        self._last_reconnect_attempt: float | None = None  # time.monotonic() stamp
        self._reconnect_backoff: float = 1.0  # un-jittered delay for the next failure
        self._max_reconnect_backoff: float = 30.0
        self._reconnect_jitter: float = 0.2  # +/- fraction applied to each delay
        # Jittered wait required after the last failed attempt. Drawn once per
        # failure so the throttle window is stable across repeated checks.
        self._reconnect_delay: float = self._reconnect_backoff

    @property
    def connected(self) -> bool:
        """True when a session is stored and not marked disconnected."""
        return self._connected and self._session is not None

    @property
    def instructions(self) -> str | None:
        """Server-provided instructions from MCP initialize handshake."""
        return self._instructions

    async def connect(self) -> None:
        """Open the transport, initialise the MCP session, and store it.

        Does nothing if already connected. A session left behind by
        :meth:`mark_disconnected` is closed first so its transport is not
        orphaned.

        Raises:
            ValueError: if the config lacks ``command``/``url`` for its
                transport, or names an unknown transport.
            Exception: anything raised by the transport or by
                ``ClientSession.initialize()``.

        Whatever goes wrong, including cancellation, every context already
        entered is closed before the exception propagates.
        """
        if self._connected:
            return
        if self._session is not None:
            # Stale session from mark_disconnected(): close it before opening
            # a new transport instead of stacking a second one on top.
            await self._cleanup()

        try:
            read_stream, write_stream = await self._open_transport()
            session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            init_result = await session.initialize()
            # Tolerate result objects without an ``instructions`` attribute.
            self._instructions = getattr(init_result, "instructions", None)
            self._session = session
            self._connected = True
            # Reset backoff on successful connect
            self._reconnect_backoff = 1.0
            self._reconnect_delay = self._reconnect_backoff
            self._last_reconnect_attempt = None
            logger.info(
                "MCP server %r connected (%s)",
                self.name,
                self.config.transport,
            )
        except BaseException:
            # BaseException rather than Exception so a cancelled / timed-out
            # connect also releases what it already opened. _cleanup() only
            # re-raises SystemExit/KeyboardInterrupt, so the original error
            # is what propagates otherwise.
            await self._cleanup()
            raise

    async def _open_transport(self) -> tuple[Any, Any]:
        """Enter the configured transport on the exit stack; return (read, write) streams."""
        cfg = self.config
        if cfg.is_stdio:
            if not cfg.command:
                raise ValueError("stdio transport requires 'command'")
            params = StdioServerParameters(
                command=cfg.command,
                args=cfg.args or [],
                env=cfg.env,
                cwd=cfg.cwd,
            )
            streams = await self._exit_stack.enter_async_context(stdio_client(params))
        elif cfg.is_sse:
            if not cfg.url:
                raise ValueError("sse transport requires 'url'")
            streams = await self._exit_stack.enter_async_context(
                sse_client(cfg.url, headers=cfg.headers)
            )
        elif cfg.is_streamable_http:
            if not cfg.url:
                raise ValueError("streamable_http transport requires 'url'")
            # NOTE(review): this AsyncClient uses httpx's default timeouts;
            # confirm they suit long-running tool calls on this transport.
            http_client = await self._exit_stack.enter_async_context(
                httpx.AsyncClient(headers=cfg.headers or {})
            )
            streams = await self._exit_stack.enter_async_context(
                streamable_http_client(cfg.url, http_client=http_client)
            )
            # streamable_http_client returns (read, write, get_session_id);
            # ClientSession only needs read and write.
            streams = streams[:2]
        else:
            raise ValueError(f"unknown transport: {cfg.transport}")
        read_stream, write_stream = streams
        return read_stream, write_stream

    async def disconnect(self) -> None:
        """Close the transport and reset state.

        Also closes a transport that :meth:`mark_disconnected` flagged as dead
        but left open. Does nothing when no session is stored and the client
        is not connected (this includes a connect still in flight, whose
        session is only stored once initialisation succeeds).
        """
        if not self._connected and self._session is None:
            return
        try:
            # Do NOT use asyncio.wait_for here — it spawns a new task,
            # and MCP transports (stdio, sse, streamable_http) use anyio
            # task groups that require __aexit__ in the same task as __aenter__.
            await self._cleanup()
        except asyncio.CancelledError:
            # Graceful shutdown during app teardown. (Defensive: _cleanup()
            # already swallows cancellation raised while closing.)
            pass

    def mark_disconnected(self) -> None:
        """Force the client into a disconnected state without closing transport.

        Used by the health-check loop when a server drops silently. The
        transport stays open until the next reconnect, ``connect()`` or
        ``disconnect()`` closes it.
        """
        self._connected = False

    async def _cleanup(self) -> None:
        """Best-effort close of everything on the exit stack, then reset connection state."""
        try:
            await self._exit_stack.aclose()
        except (SystemExit, KeyboardInterrupt):
            raise
        except BaseException:
            # Graceful shutdown noise from MCP transport task groups —
            # anyio raises RuntimeError / BaseExceptionGroup when an async
            # generator is closed from a different task than it was created in.
            # Logged at debug so genuine close failures remain diagnosable.
            logger.debug(
                "MCP server %r: error while closing transport",
                self.name,
                exc_info=True,
            )
        finally:
            self._session = None
            self._instructions = None
            self._connected = False
            self._exit_stack = AsyncExitStack()

    def _check_backoff(self) -> bool:
        """Return True if enough time has passed since the last reconnect attempt."""
        if self._last_reconnect_attempt is None:
            return True
        elapsed = time.monotonic() - self._last_reconnect_attempt
        return elapsed >= self._reconnect_delay

    def _schedule_next_reconnect(self) -> None:
        """Record a failed reconnect: draw the jittered wait, then double the backoff (capped)."""
        spread = self._reconnect_backoff * self._reconnect_jitter
        # Random jitter prevents a thundering herd of clients retrying together.
        self._reconnect_delay = self._reconnect_backoff + random.uniform(-spread, spread)
        self._reconnect_backoff = min(
            self._reconnect_backoff * 2,
            self._max_reconnect_backoff,
        )

    async def _ensure_connected(self) -> None:
        """Reconnect if the underlying transport is dead, respecting backoff.

        Raises:
            RuntimeError: if disconnected and the backoff window has not
                elapsed since the last failed attempt.
            Exception: whatever :meth:`connect` raises on a failed attempt
                (the backoff is extended before re-raising).
        """
        if self.connected:
            return

        if not self._check_backoff():
            elapsed = time.monotonic() - (self._last_reconnect_attempt or 0.0)
            logger.warning(
                "MCP server %r reconnect blocked by backoff (%.1fs remaining)",
                self.name,
                max(0, self._reconnect_delay - elapsed),
            )
            raise RuntimeError(
                f"MCP server {self.name!r} is disconnected and reconnect is throttled"
            )

        self._last_reconnect_attempt = time.monotonic()
        logger.warning("MCP server %r disconnected, reconnecting...", self.name)
        try:
            await self._cleanup()
            await self.connect()
        except Exception:
            self._schedule_next_reconnect()
            raise

    async def _request(self, send: Callable[[ClientSession], Awaitable[Any]]) -> Any:
        """Run *send* on the live session, reconnecting and retrying once on failure.

        A second failure (or a throttled / failed reconnect) propagates to
        the caller.
        """
        await self._ensure_connected()
        try:
            return await send(self._session)
        except Exception:
            # The session may be dead even though it still looks connected
            # (e.g. the server went away); rebuild it and try exactly once more.
            logger.warning(
                "MCP server %r request failed, retrying after reconnect",
                self.name,
                exc_info=True,
            )
            await self._cleanup()
            await self._ensure_connected()
            return await send(self._session)

    async def list_tools(self) -> list[Tool]:
        """Return the tools exposed by the remote MCP server."""
        result = await self._request(lambda session: session.list_tools())
        return list(result.tools)

    @staticmethod
    def _base64_decoded_size(data: str) -> int:
        """Return the decoded byte length of base64 text without decoding it.

        Handles padded and unpadded input; assumes no embedded whitespace.
        Non-``str`` payloads fall back to ``len(data)``.
        """
        if not isinstance(data, str):
            return len(data)
        text = data.strip()
        padding = len(text) - len(text.rstrip("="))
        return len(text) * 3 // 4 - padding

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None) -> tuple[str, bool]:
        """Execute a remote tool and return its output as a string.

        Returns (output_text, is_error).  ``is_error`` comes from the MCP
        ``CallToolResult.isError`` field so the caller can set ``success=False``.

        Text content is concatenated with newlines; images are reported as a
        placeholder with their MIME type and decoded size; any other content
        type is reported as ``[<type>]``.

        NOTE(review): on failure the call is retried once after a reconnect.
        If the first attempt reached the server before failing, the tool may
        run twice — confirm tools are safe to retry.
        """
        args = arguments or {}
        result = await self._request(lambda session: session.call_tool(tool_name, args))

        parts: list[str] = []
        for item in result.content:
            if item.type == "text":
                parts.append(item.text)
            elif item.type == "image":
                # ImageContent.data is base64 text; report the real byte count.
                size = self._base64_decoded_size(item.data)
                parts.append(f"[image: {item.mimeType} ({size} bytes)]")
            else:
                parts.append(f"[{item.type}]")

        is_error = bool(getattr(result, "isError", False))
        return "\n".join(parts), is_error