from __future__ import annotations

import asyncio
import logging
from pathlib import Path
from typing import Any

from .client import McpClient
from .config import McpServerConfig, load_mcp_servers

# Module-level logger named after this module.  McpManager reports per-server
# connect, disconnect and list_tools failures through it as warnings.
logger = logging.getLogger(__name__)


class McpManager:
    """Holds a pool of :class:`McpClient` instances and manages their lifecycle.

    The pool is a dict keyed by server name.  Only clients whose ``connect()``
    call succeeded are stored in it; failures are logged and skipped.

    Typical usage at application startup::

        manager = McpManager()
        await manager.load_all()
        # … later, on reload_tools built-in invocation …
        await manager.reload_all()
    """

    def __init__(self, config_path: str | Path | None = None) -> None:
        """Create a manager with an empty client pool.

        *config_path* is stored as-is and handed to ``load_mcp_servers``
        every time the configuration is read.
        """
        self.config_path = config_path
        self._clients: dict[str, McpClient] = {}

    @property
    def clients(self) -> dict[str, McpClient]:
        """The live ``name -> client`` pool (the internal dict, not a copy)."""
        return self._clients

    async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None:
        """Connect to every server in *configs* (or load from disk).

        Existing clients are disconnected first so that a reload is clean.
        Servers are connected one after another; a server that fails to
        connect is logged as a warning and left out of the pool.
        """
        if configs is None:
            configs = load_mcp_servers(self.config_path)

        # Drop the old pool before building the new one.
        await self.disconnect_all()

        for name, cfg in configs.items():
            client = McpClient(name, cfg)
            try:
                await client.connect()
            except Exception as exc:
                logger.warning("MCP server %r failed to connect: %s", name, exc)
            else:
                self._clients[name] = client

    async def reload_all(self) -> None:
        """Re-read the config file and reconnect every server."""
        configs = load_mcp_servers(self.config_path)
        await self.load_all(configs)

    async def disconnect_all(self) -> None:
        """Close every open connection and clear the client pool.

        Disconnects run concurrently and are best-effort: an error from one
        server is logged and does not stop the others from being closed.
        """
        if not self._clients:
            return

        # Detach the pool *before* awaiting anything.  The snapshot keeps the
        # name/result pairing below stable even if the pool changes while the
        # disconnects are in flight, and guarantees the pool is empty no
        # matter how this method exits.  The dict object itself is reused so
        # references obtained through ``clients`` stay valid.
        closing = dict(self._clients)
        self._clients.clear()

        try:
            results = await asyncio.gather(
                *(client.disconnect() for client in closing.values()),
                return_exceptions=True,
            )
        except (asyncio.CancelledError, RuntimeError):
            # Shutdown-time cancellation from anyio task scopes — safe to ignore
            return

        for name, outcome in zip(closing, results):
            if isinstance(outcome, Exception):
                logger.warning("MCP server %r disconnect error: %s", name, outcome)
            elif isinstance(outcome, BaseException):
                # gather() hands back a cancelled child as a CancelledError
                # result, which is not an ``Exception`` subclass.  Keep a
                # trace of it without raising the noise level at shutdown.
                logger.debug("MCP server %r disconnect interrupted: %r", name, outcome)

    async def get_all_tools(self) -> list[tuple[str, Any]]:
        """Return ``(server_name, mcp_tool)`` for every tool on every server.

        Servers are queried in pool order.  Servers that fail to list tools
        are skipped gracefully (the failure is logged as a warning).
        """
        out: list[tuple[str, Any]] = []
        for name, client in self._clients.items():
            try:
                tools = await client.list_tools()
                out.extend((name, t) for t in tools)
            except Exception as exc:
                logger.warning("MCP server %r list_tools failed: %s", name, exc)
        return out

    def resolve_group(self, server_name: str, group_name: str) -> list[str]:
        """Return the list of tool names in a server group.

        Reads from the static config (``mcp_servers.json``), not from the live
        server, so it works even when the server is temporarily disconnected.
        An unknown server or group yields an empty list.  The result is a new
        list, so callers may mutate it freely.
        """
        configs = load_mcp_servers(self.config_path)
        cfg = configs.get(server_name)
        if cfg is None:
            return []
        return list(cfg.groups.get(group_name, []))

    def get_instructions(self, server_names: list[str] | set[str] | None = None) -> dict[str, str]:
        """Return combined instructions per server.

        With *server_names* omitted, every connected server is considered;
        otherwise only the named ones are.

        Server-provided instructions (from MCP initialize handshake) are merged
        with the overlay ``instructions`` field from ``mcp_servers.json``,
        separated by a blank line.  If a selected server is disconnected, only
        the config overlay is returned.  Servers with no instructions at all
        are omitted from the result.

        The result follows the order of *server_names* (duplicates collapsed
        to their first occurrence) or, when omitted, the order of the pool.
        """
        configs = load_mcp_servers(self.config_path)

        # De-duplicate while preserving order.  Going through a ``set`` would
        # make the key order of the result depend on string hashing and thus
        # vary from one process to the next.
        if server_names is None:
            names = list(self._clients)
        else:
            names = list(dict.fromkeys(server_names))

        out: dict[str, str] = {}
        for name in names:
            client = self._clients.get(name)
            parts: list[str] = []
            if client and client.instructions:
                parts.append(client.instructions)
            cfg = configs.get(name)
            if cfg and cfg.instructions:
                if parts:
                    # Empty element -> blank line between the two sources.
                    parts.append("")
                parts.append(cfg.instructions)
            if parts:
                out[name] = "\n".join(parts)
        return out

    async def call_tool(
        self,
        server_name: str,
        tool_name: str,
        arguments: dict[str, Any] | None = None,
    ) -> tuple[str, bool]:
        """Proxy a tool call to the named server.

        Returns (output_text, is_error) so the caller knows whether the MCP
        tool itself reported a failure.

        Raises ``RuntimeError`` when *server_name* is not in the pool.
        """
        client = self._clients.get(server_name)
        if client is None:
            raise RuntimeError(f"MCP server {server_name!r} is not connected")
        return await client.call_tool(tool_name, arguments)
