from __future__ import annotations
import asyncio
import logging
from contextlib import AsyncExitStack
from typing import Any
import anyio
from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.client.stdio import StdioServerParameters, stdio_client
from mcp.types import Tool
from .config import McpServerConfig
# Module-scoped logger; records are tagged with this module's dotted name.
logger = logging.getLogger(name=__name__)
class McpClient:
    """Lightweight wrapper around the official Python MCP client SDK.

    Manages a single server connection (stdio or SSE), exposes
    ``list_tools()`` and ``call_tool()``, and handles lifecycle
    (connect / disconnect).

    The transport and session context managers opened by ``connect()`` are
    all entered on one ``AsyncExitStack`` so they are torn down together.
    """

    def __init__(self, name: str, config: McpServerConfig) -> None:
        self.name = name
        self.config = config
        self._session: ClientSession | None = None
        # Owns the transport + session contexts entered in connect().
        self._exit_stack = AsyncExitStack()
        self._connected = False
        self._instructions: str | None = None

    @property
    def connected(self) -> bool:
        """True after a successful ``connect()`` until the client is cleaned up."""
        return self._connected and self._session is not None

    @property
    def instructions(self) -> str | None:
        """Server-provided instructions from MCP initialize handshake."""
        return self._instructions

    async def connect(self) -> None:
        """Open transport, initialise session, and store it.

        Does nothing if already connected. On any failure (including task
        cancellation) everything opened so far is closed before the
        exception is re-raised.

        Raises:
            ValueError: the config lacks ``command`` (stdio) or ``url``
                (sse), or names an unknown transport.
            Exception: whatever the transport or the ``initialize``
                handshake raises, re-raised after cleanup.
        """
        if self._connected:
            return
        try:
            if self.config.is_stdio:
                if not self.config.command:
                    raise ValueError("stdio transport requires 'command'")
                params = StdioServerParameters(
                    command=self.config.command,
                    args=self.config.args or [],
                    env=self.config.env,
                    cwd=self.config.cwd,
                )
                transport = await self._exit_stack.enter_async_context(
                    stdio_client(params)
                )
            elif self.config.is_sse:
                if not self.config.url:
                    raise ValueError("sse transport requires 'url'")
                transport = await self._exit_stack.enter_async_context(
                    sse_client(
                        self.config.url,
                        headers=self.config.headers,
                    )
                )
            else:
                raise ValueError(f"unknown transport: {self.config.transport}")
            read_stream, write_stream = transport
            session = await self._exit_stack.enter_async_context(
                ClientSession(read_stream, write_stream)
            )
            init_result = await session.initialize()
            # Older SDK versions' initialize result may lack this field.
            self._instructions = getattr(init_result, "instructions", None)
            self._session = session
            self._connected = True
            logger.info(
                "MCP server %r connected (%s)",
                self.name,
                self.config.transport,
            )
        except BaseException:
            # BaseException (not just Exception) so that a cancelled connect
            # -- asyncio.CancelledError is a BaseException -- still closes
            # the contexts already entered on the exit stack instead of
            # leaking them.
            await self._cleanup()
            raise

    async def disconnect(self) -> None:
        """Close transport and reset state. No-op if not connected."""
        if not self._connected:
            return
        try:
            await self._cleanup()
        except (asyncio.CancelledError, RuntimeError):
            # Graceful shutdown during app teardown — SSE transport teardown
            # throws CancelledError / RuntimeError from anyio task scopes.
            pass

    async def _cleanup(self) -> None:
        """Close the exit stack and reset to the disconnected state.

        Ordinary exceptions raised while closing are logged at DEBUG and
        suppressed (teardown is best-effort); BaseExceptions such as
        cancellation propagate. State is reset either way, and a fresh
        exit stack is installed so the client can connect again.
        """
        try:
            await self._exit_stack.aclose()
        except Exception:
            logger.debug(
                "Error while closing MCP server %r", self.name, exc_info=True
            )
        finally:
            self._session = None
            self._instructions = None
            self._connected = False
            self._exit_stack = AsyncExitStack()

    async def list_tools(self) -> list[Tool]:
        """Return the tools exposed by the remote MCP server.

        Follows ``nextCursor`` pagination so that servers which split their
        tool list across several pages are returned in full.

        Raises:
            RuntimeError: if not connected.
        """
        if self._session is None:
            raise RuntimeError("Not connected")
        result = await self._session.list_tools()
        tools = list(result.tools)
        seen_cursors: set[str] = set()
        cursor = getattr(result, "nextCursor", None)
        # Stop on a repeated cursor rather than looping forever on a
        # misbehaving server.
        while cursor and cursor not in seen_cursors:
            seen_cursors.add(cursor)
            result = await self._session.list_tools(cursor=cursor)
            tools.extend(result.tools)
            cursor = getattr(result, "nextCursor", None)
        return tools

    async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str:
        """Execute a remote tool and return its output as a string.

        Text content is concatenated (newline-separated); images are
        reported as a placeholder with their MIME type and decoded size;
        any other content type becomes ``[<type>]``.

        NOTE(review): the result's error flag (``isError`` in the MCP SDK)
        is not inspected, so tool errors come back as ordinary text —
        confirm callers expect this.

        Raises:
            RuntimeError: if not connected.
        """
        if self._session is None:
            raise RuntimeError("Not connected")
        result = await self._session.call_tool(tool_name, arguments or {})
        parts: list[str] = []
        for item in result.content:
            if item.type == "text":
                parts.append(item.text)
            elif item.type == "image":
                # MCP image payloads are base64 text; report the size of the
                # decoded image, not the length of the encoded string.
                size = self._b64_decoded_len(item.data)
                parts.append(f"[image: {item.mimeType} ({size} bytes)]")
            else:
                parts.append(f"[{item.type}]")
        return "\n".join(parts)

    @staticmethod
    def _b64_decoded_len(data: str) -> int:
        """Return how many bytes base64 *data* (no whitespace) decodes to."""
        return len(data.rstrip("=")) * 3 // 4