Newer
Older
navi-1 / navi / mcp / client.py
from __future__ import annotations

import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any

import anyio
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import Tool

from .config import McpServerConfig

logger = logging.getLogger(__name__)


class McpClient:
    """Lightweight wrapper around the official Python MCP client SDK.

    Manages a single server connection (stdio or SSE), exposes
    ``list_tools()`` and ``call_tool()``, and handles lifecycle
    (connect / disconnect).

    The transport and the ``ClientSession`` are both entered on one
    ``AsyncExitStack``, so a single ``aclose()`` tears them down in
    reverse order of entry.
    """

    def __init__(self, name: str, config: McpServerConfig) -> None:
        self.name = name
        self.config = config
        self._session: ClientSession | None = None
        # Owns the transport and session contexts. It is replaced with a
        # fresh stack after every cleanup so the client can reconnect.
        self._exit_stack = AsyncExitStack()
        self._connected = False
        self._instructions: str | None = None

    @property
    def connected(self) -> bool:
        """True after a successful ``connect()`` and until the next cleanup."""
        return self._connected and self._session is not None

    @property
    def instructions(self) -> str | None:
        """Server-provided instructions from MCP initialize handshake."""
        return self._instructions

    async def connect(self) -> None:
        """Open transport, initialise session, and store it.

        No-op if already connected.

        Raises:
            ValueError: if the config lacks the field its transport needs
                (``command`` for stdio, ``url`` for SSE) or names an
                unknown transport.
            Exception: anything raised by the transport or by the
                initialize handshake is re-raised after any partially
                opened contexts have been closed.
        """
        if self._connected:
            return

        try:
            read_stream, write_stream = await self._open_transport()
            session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            init_result = await session.initialize()
            self._instructions = getattr(init_result, "instructions", None)
            self._session = session
            self._connected = True
            logger.info(
                "MCP server %r connected (%s)",
                self.name,
                self.config.transport,
            )
        except BaseException:
            # BaseException (not just Exception) so that a cancelled connect
            # also closes whatever was already entered on the exit stack;
            # otherwise e.g. a spawned stdio server would outlive the attempt.
            await self._cleanup()
            raise

    async def _open_transport(self) -> tuple[Any, Any]:
        """Enter the configured transport on the exit stack.

        Returns:
            The ``(read_stream, write_stream)`` pair yielded by the
            transport context manager.

        Raises:
            ValueError: on a missing ``command``/``url`` or an unknown
                transport.
        """
        if self.config.is_stdio:
            if not self.config.command:
                raise ValueError("stdio transport requires 'command'")
            params = StdioServerParameters(
                command=self.config.command,
                args=self.config.args or [],
                env=self.config.env,
                cwd=self.config.cwd,
            )
            return await self._exit_stack.enter_async_context(
                stdio_client(params)
            )
        if self.config.is_sse:
            if not self.config.url:
                raise ValueError("sse transport requires 'url'")
            return await self._exit_stack.enter_async_context(
                sse_client(
                    self.config.url,
                    headers=self.config.headers,
                )
            )
        raise ValueError(f"unknown transport: {self.config.transport}")

    async def disconnect(self) -> None:
        """Close transport and reset state. No-op if not connected."""
        if not self._connected:
            return
        try:
            await self._cleanup()
        except (asyncio.CancelledError, RuntimeError):
            # Graceful shutdown during app teardown — SSE transport teardown
            # throws CancelledError / RuntimeError from anyio task scopes.
            # NOTE(review): swallowing CancelledError also hides a genuine
            # cancellation of the calling task; confirm that is intended.
            pass

    async def _cleanup(self) -> None:
        """Close every context on the exit stack and reset connection state.

        Best-effort: ordinary exceptions raised while closing are logged at
        debug level and suppressed. State is reset in ``finally``, so it is
        cleared even if closing is interrupted by a ``BaseException``.
        """
        try:
            await self._exit_stack.aclose()
        except Exception:
            logger.debug(
                "Error while closing MCP server %r", self.name, exc_info=True
            )
        finally:
            self._session = None
            self._instructions = None
            self._connected = False
            self._exit_stack = AsyncExitStack()

    async def list_tools(self) -> list[Tool]:
        """Return the tools exposed by the remote MCP server.

        Raises:
            RuntimeError: if there is no active session.
        """
        if not self._session:
            raise RuntimeError("Not connected")
        result = await self._session.list_tools()
        return list(result.tools)

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
        """Execute a remote tool and return its output as a string.

        Text content is concatenated (newline-separated). Images are
        reported as a placeholder with their MIME type and decoded size.
        Any other content type becomes ``[<type>]``.

        Raises:
            RuntimeError: if there is no active session.
        """
        if not self._session:
            raise RuntimeError("Not connected")

        result = await self._session.call_tool(tool_name, arguments or {})
        if getattr(result, "isError", False):
            # The server flagged this result as a tool error. The content is
            # still returned unchanged, but make the failure visible.
            logger.warning(
                "MCP tool %r on server %r reported an error",
                tool_name,
                self.name,
            )

        parts: list[str] = []
        for item in result.content:
            if item.type == "text":
                parts.append(item.text)
            elif item.type == "image":
                size = self._decoded_size(item.data)
                parts.append(f"[image: {item.mimeType} ({size} bytes)]")
            else:
                parts.append(f"[{item.type}]")

        return "\n".join(parts)

    @staticmethod
    def _decoded_size(data: str) -> int:
        """Return the byte length of base64 *data* without decoding it."""
        # Image payloads arrive base64-encoded, so len(data) is a character
        # count, not a byte count. Every 4 significant chars carry 3 bytes;
        # whitespace and trailing '=' padding carry none.
        significant = "".join(data.split()).rstrip("=")
        return len(significant) * 3 // 4