diff --git a/navi/api/deps.py b/navi/api/deps.py index bfc940c..f022719 100644 --- a/navi/api/deps.py +++ b/navi/api/deps.py @@ -25,6 +25,8 @@ ) from navi.memory import MemoryStore from navi.workers import Worker, build_default_workers +from navi.mcp import McpManager, load_mcp_servers +from navi.mcp.tools import McpTool def _make_session_store() -> SessionStore: @@ -41,6 +43,7 @@ _memory_store: MemoryStore | None = None _registries: tuple[ToolRegistry, ProfileRegistry, BackendRegistry, ContextProviderRegistry] | None = None +_mcp_manager: McpManager | None = None def get_memory_store() -> MemoryStore: @@ -93,6 +96,33 @@ return get_registries()[3] +async def get_mcp_manager() -> McpManager: + global _mcp_manager + if _mcp_manager is None: + _mcp_manager = McpManager() + await _mcp_manager.load_all() + return _mcp_manager + + +async def register_mcp_tools(registry: ToolRegistry, manager: McpManager) -> None: + """Discover tools from all connected MCP servers and register them as external.""" + # clear previous external MCP tools + for name in list(registry._external_names): + if name.startswith("mcp_"): + registry.unregister_external(name) + + tools = await manager.get_all_tools() + for server_name, tool in tools: + mcp_tool = McpTool( + server_name=server_name, + tool_name=tool.name, + description=tool.description or "", + parameters=tool.inputSchema, + manager=manager, + ) + registry.register_external(mcp_tool) + + _session_store: SessionStore | None = None _workers: list[Worker] | None = None diff --git a/navi/core/registry.py b/navi/core/registry.py index 6756e87..04b0d71 100644 --- a/navi/core/registry.py +++ b/navi/core/registry.py @@ -36,6 +36,7 @@ from navi.tools.content_publish import ContentPublishTool from navi.tools.model_3d import Model3DTool from navi.tools.render_3d import Render3DTool +from navi.tools.mcp_status import McpStatusTool from navi.tools.loader import LoadResult, load_tools_from_dir from navi.tools.logging_middleware import LoggingMiddleware from navi.context_providers._loader import ContextProviderRegistry @@ -45,6 +46,7 @@ def __init__(self) -> None: self._tools: dict[str, Tool] = {} self._builtin_names: set[str] = set() + self._external_names: set[str] = set() self._middlewares: list = [] def register(self, tool: Tool, builtin: bool = False) -> None: @@ -52,6 +54,19 @@ if builtin: self._builtin_names.add(tool.name) + def register_external(self, tool: Tool) -> None: + """Register a tool from an external source (e.g. MCP server). + + External tools survive ``reload_user_tools()`` just like builtins. + """ + self._tools[tool.name] = tool + self._external_names.add(tool.name) + + def unregister_external(self, name: str) -> None: + """Remove a previously registered external tool.""" + self._external_names.discard(name) + self._tools.pop(name, None) + def add_middleware(self, middleware) -> None: """Add a ToolMiddleware instance.""" self._middlewares.append(middleware) @@ -69,9 +84,9 @@ def reload_user_tools(self, tools_dir: str) -> LoadResult: """Remove all user tools and reload from disk. Safe: errors are isolated.""" - # Drop previously loaded user tools + # Drop previously loaded user tools (not builtin, not external) for name in list(self._tools): - if name not in self._builtin_names: + if name not in self._builtin_names and name not in self._external_names: del self._tools[name] result = load_tools_from_dir(tools_dir) @@ -183,12 +198,13 @@ list_tool = ListToolsTool(registry=tools) manual_tool = ToolManualTool(registry=tools) memory_tool = MemoryTool(memory_store) if memory_store else None + mcp_status_tool = McpStatusTool() builtins = [WebSearchTool(), FilesystemTool(ai_helper=ai_helper), HttpRequestTool(), WebViewTool(), CodeExecTool(), TerminalTool(), SshExecTool(), ImageViewTool(), ScadLintTool(), ShareFileTool(), ContentPublishTool(), TestToolTool(), Model3DTool(), Render3DTool(), TodoTool(), ScratchpadTool(), ReflectTool(ai_helper=ai_helper), - reload_tool, write_tool, delete_tool, list_tool, manual_tool] + reload_tool, write_tool, delete_tool, list_tool, manual_tool, mcp_status_tool] if memory_tool: builtins.append(memory_tool) for builtin in builtins: diff --git a/navi/main.py b/navi/main.py index 9e6434b..080a7fd 100644 --- a/navi/main.py +++ b/navi/main.py @@ -80,6 +80,18 @@ # Initialize registries before embed health check. The memory store gets its # embedding backend wired during registry construction. get_registries() + # Connect MCP servers and register their tools as external. + from navi.api.deps import get_mcp_manager, get_tool_registry, register_mcp_tools + try: + mcp_manager = await get_mcp_manager() + tool_registry = get_tool_registry() + await register_mcp_tools(tool_registry, mcp_manager) + for tool_name in ("reload_tools", "mcp_status"): + tool = tool_registry.get(tool_name) + if hasattr(tool, "_mcp_manager"): + tool._mcp_manager = mcp_manager + except Exception: + log.warning("startup.mcp_connect_failed", exc_info=True) # Apply persisted profile overrides (e.g. is_admin_only) to in-memory profiles. from navi.api.deps import get_profile_registry, get_session_store from navi.profiles._overrides import ensure_table, load_overrides @@ -115,8 +127,14 @@ @app.on_event("shutdown") async def _on_shutdown() -> None: from navi.tools.ssh_exec import close_all_connections + from navi.api.deps import _mcp_manager close_all_connections() + if _mcp_manager is not None: + try: + await _mcp_manager.disconnect_all() + except Exception: + pass @app.get("/", include_in_schema=False) diff --git a/navi/mcp/__init__.py b/navi/mcp/__init__.py new file mode 100644 index 0000000..6444326 --- /dev/null +++ b/navi/mcp/__init__.py @@ -0,0 +1,5 @@ +from .client import McpClient +from .config import McpServerConfig, load_mcp_servers +from .manager import McpManager + +__all__ = ["McpClient", "McpServerConfig", "load_mcp_servers", "McpManager"] diff --git a/navi/mcp/client.py b/navi/mcp/client.py new file mode 100644 index 0000000..02977b9 --- /dev/null +++ b/navi/mcp/client.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +import asyncio +import logging +from contextlib import AsyncExitStack +from typing import Any + +import anyio +from mcp import ClientSession +from mcp.client.sse import sse_client +from mcp.client.stdio import StdioServerParameters, stdio_client +from mcp.types import Tool + +from .config import McpServerConfig + +logger = logging.getLogger(__name__) + + +class McpClient: + """Lightweight wrapper around the official Python MCP client SDK. + + Manages a single server connection (stdio or SSE), exposes + ``list_tools()`` and ``call_tool()``, and handles lifecycle + (connect / disconnect). + """ + + def __init__(self, name: str, config: McpServerConfig) -> None: + self.name = name + self.config = config + self._session: ClientSession | None = None + self._exit_stack = AsyncExitStack() + self._connected = False + + @property + def connected(self) -> bool: + return self._connected and self._session is not None + + async def connect(self) -> None: + """Open transport, initialise session, and store it.""" + if self._connected: + return + + try: + if self.config.is_stdio: + if not self.config.command: + raise ValueError("stdio transport requires 'command'") + params = StdioServerParameters( + command=self.config.command, + args=self.config.args or [], + env=self.config.env, + cwd=self.config.cwd, + ) + transport = await self._exit_stack.enter_async_context( + stdio_client(params) + ) + elif self.config.is_sse: + if not self.config.url: + raise ValueError("sse transport requires 'url'") + transport = await self._exit_stack.enter_async_context( + sse_client( + self.config.url, + headers=self.config.headers, + ) + ) + else: + raise ValueError(f"unknown transport: {self.config.transport}") + + read_stream, write_stream = transport + session = await self._exit_stack.enter_async_context( + ClientSession(read_stream, write_stream) + ) + await session.initialize() + self._session = session + self._connected = True + logger.info("MCP server %r connected (%s)", self.name, self.config.transport) + except Exception: + await self._cleanup() + raise + + async def disconnect(self) -> None: + """Close transport and reset state.""" + if not self._connected: + return + await self._cleanup() + + async def _cleanup(self) -> None: + try: + await self._exit_stack.aclose() + except Exception: + pass + finally: + self._session = None + self._connected = False + self._exit_stack = AsyncExitStack() + + async def list_tools(self) -> list[Tool]: + """Return the tools exposed by the remote MCP server.""" + if not self._session: + raise RuntimeError("Not connected") + result = await self._session.list_tools() + return list(result.tools) + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None) -> str: + """Execute a remote tool and return its output as a string. + + Text content is concatenated; images are reported as a placeholder. + """ + if not self._session: + raise RuntimeError("Not connected") + + result = await self._session.call_tool(tool_name, arguments or {}) + + parts: list[str] = [] + for item in result.content: + if item.type == "text": + parts.append(item.text) + elif item.type == "image": + parts.append(f"[image: {item.mimeType} ({len(item.data)} bytes)]") + else: + parts.append(f"[{item.type}]") + + return "\n".join(parts) diff --git a/navi/mcp/config.py b/navi/mcp/config.py new file mode 100644 index 0000000..6aec728 --- /dev/null +++ b/navi/mcp/config.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +import json +from pathlib import Path +from typing import Literal + +from pydantic import BaseModel, Field + + +class McpServerConfig(BaseModel): + """Configuration for a single MCP server.""" + + transport: Literal["stdio", "sse"] = "stdio" + + # stdio fields + command: str | None = None + args: list[str] = Field(default_factory=list) + env: dict[str, str] | None = None + cwd: str | None = None + + # sse fields + url: str | None = None + headers: dict[str, str] | None = None + + @property + def is_stdio(self) -> bool: + return self.transport == "stdio" + + @property + def is_sse(self) -> bool: + return self.transport == "sse" + + +def load_mcp_servers(path: str | Path | None = None) -> dict[str, McpServerConfig]: + """Load MCP server configurations from a JSON file. + + Default path is ``mcp_servers.json`` in the current working directory. + Returns an empty dict if the file does not exist. + """ + if path is None: + path = Path("mcp_servers.json") + else: + path = Path(path) + + if not path.exists(): + return {} + + raw = json.loads(path.read_text(encoding="utf-8")) + return {name: McpServerConfig.model_validate(cfg) for name, cfg in raw.items()} diff --git a/navi/mcp/manager.py b/navi/mcp/manager.py new file mode 100644 index 0000000..3c2ff4c --- /dev/null +++ b/navi/mcp/manager.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +import asyncio +import logging +from pathlib import Path +from typing import Any + +from .client import McpClient +from .config import McpServerConfig, load_mcp_servers + +logger = logging.getLogger(__name__) + + +class McpManager: + """Holds a pool of :class:`McpClient` instances and manages their lifecycle. + + Typical usage at application startup:: + + manager = McpManager() + await manager.load_all() + # … later, on reload_tools built-in invocation … + await manager.reload_all() + """ + + def __init__(self, config_path: str | Path | None = None) -> None: + self.config_path = config_path + self._clients: dict[str, McpClient] = {} + + @property + def clients(self) -> dict[str, McpClient]: + return self._clients + + async def load_all(self, configs: dict[str, McpServerConfig] | None = None) -> None: + """Connect to every server in *configs* (or load from disk). + + Existing clients are disconnected first so that a reload is clean. + """ + if configs is None: + configs = load_mcp_servers(self.config_path) + + # disconnect old + await self.disconnect_all() + + # connect new + for name, cfg in configs.items(): + client = McpClient(name, cfg) + try: + await client.connect() + self._clients[name] = client + except Exception as exc: + logger.warning("MCP server %r failed to connect: %s", name, exc) + + async def reload_all(self) -> None: + """Re-read the config file and reconnect every server.""" + configs = load_mcp_servers(self.config_path) + await self.load_all(configs) + + async def disconnect_all(self) -> None: + """Close every open connection and clear the client pool.""" + if not self._clients: + return + results = await asyncio.gather( + *[c.disconnect() for c in self._clients.values()], + return_exceptions=True, + ) + for name, exc in zip(self._clients, results): + if isinstance(exc, Exception): + logger.warning("MCP server %r disconnect error: %s", name, exc) + self._clients.clear() + + async def get_all_tools(self) -> list[tuple[str, Any]]: + """Return ``(server_name, mcp_tool)`` for every tool on every server. + + Servers that fail to list tools are skipped gracefully. + """ + out: list[tuple[str, Any]] = [] + for name, client in self._clients.items(): + try: + tools = await client.list_tools() + out.extend((name, t) for t in tools) + except Exception as exc: + logger.warning("MCP server %r list_tools failed: %s", name, exc) + return out + + async def call_tool(self, server_name: str, tool_name: str, arguments: dict[str, Any] | None = None) -> str: + """Proxy a tool call to the named server.""" + client = self._clients.get(server_name) + if client is None: + raise RuntimeError(f"MCP server {server_name!r} is not connected") + return await client.call_tool(tool_name, arguments) diff --git a/navi/mcp/tools.py b/navi/mcp/tools.py new file mode 100644 index 0000000..46ca1b3 --- /dev/null +++ b/navi/mcp/tools.py @@ -0,0 +1,39 @@ +from __future__ import annotations + +from typing import Any + +from navi.tools.base import Tool, ToolResult + +from .manager import McpManager + + +class McpTool(Tool): + """A :class:`Tool` proxy that forwards execution to an MCP server. + + The name is ``mcp__`` to avoid collisions with built-in + and user-defined tools. + """ + + def __init__( + self, + server_name: str, + tool_name: str, + description: str, + parameters: dict[str, Any], + manager: McpManager, + ) -> None: + self.server_name = server_name + self.tool_name = tool_name + self.description = description + self.parameters = parameters + self._manager = manager + self.name = f"mcp_{server_name}_{tool_name}" + + async def execute(self, params: dict[str, Any]) -> ToolResult: + try: + output = await self._manager.call_tool( + self.server_name, self.tool_name, params + ) + return ToolResult(success=True, output=output) + except Exception as exc: + return ToolResult(success=False, output="", error=str(exc)) diff --git a/navi/tools/mcp_status.py b/navi/tools/mcp_status.py new file mode 100644 index 0000000..168825e --- /dev/null +++ b/navi/tools/mcp_status.py @@ -0,0 +1,45 @@ +"""Built-in tool to list connected MCP servers and their exposed tools.""" + +from navi.mcp import McpManager + +from .base import Tool, ToolResult + + +class McpStatusTool(Tool): + name = "mcp_status" + description = ( + "Show the status of all configured MCP servers and the tools they expose. " + "Use this to discover what external tools are currently available." + ) + parameters = { + "type": "object", + "properties": {}, + "required": [], + } + + def __init__(self, manager: McpManager | None = None) -> None: + self._manager = manager + + async def execute(self, params: dict) -> ToolResult: + if self._manager is None: + return ToolResult( + success=False, + output="", + error="MCP manager not available.", + ) + + lines: list[str] = [] + for name, client in self._manager.clients.items(): + status = "connected" if client.connected else "disconnected" + lines.append(f"Server: {name} ({status})") + try: + tools = await client.list_tools() + for t in tools: + lines.append(f" - {t.name}: {t.description or 'no description'}") + except Exception as exc: + lines.append(f" (failed to list tools: {exc})") + + if not lines: + return ToolResult(success=True, output="No MCP servers configured.") + + return ToolResult(success=True, output="\n".join(lines)) diff --git a/navi/tools/memory.py b/navi/tools/memory.py index a0d5a15..b951625 100644 --- a/navi/tools/memory.py +++ b/navi/tools/memory.py @@ -3,7 +3,7 @@ from datetime import datetime, timedelta, timezone from navi.memory.store import MemoryStore -from navi.tools.base import current_session_id +from navi.tools.base import current_session_id, current_user_id from .base import Tool, ToolResult @@ -136,10 +136,12 @@ pass session_id = current_session_id.get(None) + user_id = current_user_id.get(None) await self._store.upsert_fact( category=category, key=key, value=value, + user_id=user_id, source_session_id=session_id, source=source, confidence=confidence, @@ -153,7 +155,8 @@ if not query: return ToolResult(success=False, output="query is required for search.", error="missing query") - facts = await self._store.search_facts(query, limit=15) + user_id = current_user_id.get(None) + facts = await self._store.search_facts(query, user_id=user_id, limit=15) if not facts: return ToolResult(success=True, output="No matching facts found in memory.") @@ -177,7 +180,8 @@ if not key: return ToolResult(success=False, output="key is required for forget.", error="missing key") - deleted = await self._store.delete_fact(key, category) + user_id = current_user_id.get(None) + deleted = await self._store.delete_fact(key, category, user_id=user_id) if deleted == 0: return ToolResult(success=False, output=f"No fact found with key '{key}'.", error="not found") @@ -185,7 +189,8 @@ return ToolResult(success=True, output=f"Deleted {deleted} {noun} with key '{key}'.") async def _list(self) -> ToolResult: - facts = await self._store.get_all_facts() + user_id = current_user_id.get(None) + facts = await self._store.get_all_facts(user_id=user_id) if not facts: return ToolResult(success=True, output="Memory is empty.") diff --git a/navi/tools/memory_forget.py b/navi/tools/memory_forget.py index bba7e9e..96545dc 100644 --- a/navi/tools/memory_forget.py +++ b/navi/tools/memory_forget.py @@ -1,6 +1,7 @@ """Memory forget tool — delete a fact from long-term memory.""" from navi.memory.store import MemoryStore +from navi.tools.base import current_user_id from .base import Tool, ToolResult @@ -37,7 +38,8 @@ if not key: return ToolResult(success=False, output="Key is required.", error="missing key") - deleted = await self._store.delete_fact(key, category) + user_id = current_user_id.get(None) + deleted = await self._store.delete_fact(key, category, user_id=user_id) if deleted == 0: return ToolResult(success=False, output=f"No fact found with key '{key}'.", error="not found") diff --git a/navi/tools/memory_save.py b/navi/tools/memory_save.py index 2eed1f2..6ea9190 100644 --- a/navi/tools/memory_save.py +++ b/navi/tools/memory_save.py @@ -1,7 +1,7 @@ """Memory save tool — persist a fact about the user to long-term memory.""" from navi.memory.store import MemoryStore -from navi.tools.base import current_session_id +from navi.tools.base import current_session_id, current_user_id from .base import Tool, ToolResult @@ -67,5 +67,6 @@ return ToolResult(success=False, output="value is required.", error="missing value") session_id = current_session_id.get(None) - await self._store.upsert_fact(category, key, value, session_id) + user_id = current_user_id.get(None) + await self._store.upsert_fact(category, key, value, user_id=user_id, source_session_id=session_id) return ToolResult(success=True, output=f"Saved [{category}] {key}: {value}") diff --git a/navi/tools/memory_search.py b/navi/tools/memory_search.py index 0812aa0..be86ff8 100644 --- a/navi/tools/memory_search.py +++ b/navi/tools/memory_search.py @@ -1,6 +1,7 @@ """Memory search tool — query facts about the user from long-term memory.""" from navi.memory.store import MemoryStore +from navi.tools.base import current_user_id from .base import Tool, ToolResult @@ -35,7 +36,8 @@ if not query: return ToolResult(success=False, output="Query is required.", error="missing query") - facts = await self._store.search_facts(query, limit=15) + user_id = current_user_id.get(None) + facts = await self._store.search_facts(query, user_id=user_id, limit=15) if not facts: return ToolResult(success=True, output="No matching facts found in memory.") diff --git a/navi/tools/reload_tools.py b/navi/tools/reload_tools.py index 97ad470..63792f3 100644 --- a/navi/tools/reload_tools.py +++ b/navi/tools/reload_tools.py @@ -8,10 +8,11 @@ class ReloadToolsTool(Tool): name = "reload_tools" description = ( - "Hot-reload all tools from the tools/ directory and context providers from " - "context_providers/ without restarting the server. " - "Call this after writing or editing a tool or context provider file. " - "Returns a report of what was loaded and any errors per file." + "Hot-reload all tools from the tools/ directory, context providers from " + "context_providers/, and reconnect all configured MCP servers without " + "restarting the server. Call this after writing or editing a tool, context " + "provider, or MCP server configuration. Returns a report of what was loaded " + "and any errors per file or server." ) parameters = { "type": "object", @@ -19,9 +20,10 @@ "required": [], } - def __init__(self, registry=None, cp_registry=None) -> None: + def __init__(self, registry=None, cp_registry=None, mcp_manager=None) -> None: self._registry = registry self._cp_registry = cp_registry + self._mcp_manager = mcp_manager async def execute(self, params: dict) -> ToolResult: if self._registry is None: @@ -53,4 +55,16 @@ for filename, error in cp_result.errors.items(): lines.append(f" {filename}: {error}") + # Reconnect MCP servers + if self._mcp_manager is not None: + try: + await self._mcp_manager.reload_all() + from navi.api.deps import register_mcp_tools + await register_mcp_tools(self._registry, self._mcp_manager) + mcp_tools = [t.name for t in self._registry.all() if t.name.startswith("mcp_")] + lines.append(f"MCP tools ({len(mcp_tools)}): {', '.join(mcp_tools) or 'none'}") + except Exception as exc: + has_errors = True + lines.append(f"MCP reload error: {exc}") + return ToolResult(success=not has_errors, output="\n".join(lines)) diff --git a/pyproject.toml b/pyproject.toml index 9208514..19b73f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -31,6 +31,9 @@ "gnexus-auth-client-py @ git+https://git.gnexus.space/root/gnexus-auth-client-py.git", "cryptography>=42", + # MCP (Model Context Protocol) + "mcp>=1.27", + # Config "pydantic>=2.7", "pydantic-settings>=2.3", diff --git a/tests/_mcp_test_server.py b/tests/_mcp_test_server.py new file mode 100644 index 0000000..ef877f7 --- /dev/null +++ b/tests/_mcp_test_server.py @@ -0,0 +1,21 @@ +"""Minimal MCP server used only by integration tests.""" + +from mcp.server.fastmcp import FastMCP + +mcp = FastMCP("test-server") + + +@mcp.tool() +def hello(name: str) -> str: + """Say hello to someone.""" + return f"Hello, {name}!" + + +@mcp.tool() +def add(a: int, b: int) -> int: + """Add two numbers.""" + return a + b + + +if __name__ == "__main__": + mcp.run(transport="stdio") diff --git a/tests/integration/test_mcp_integration.py b/tests/integration/test_mcp_integration.py new file mode 100644 index 0000000..e19d76e --- /dev/null +++ b/tests/integration/test_mcp_integration.py @@ -0,0 +1,34 @@ +"""Integration test for navi/mcp/ against a real stdio MCP server.""" + +import sys + +import pytest + +from navi.mcp.client import McpClient +from navi.mcp.config import McpServerConfig + + +@pytest.mark.anyio +async def test_stdio_client_connects_lists_and_calls_tools(): + cfg = McpServerConfig( + transport="stdio", + command=sys.executable, + args=["-m", "tests._mcp_test_server"], + ) + client = McpClient("test", cfg) + await client.connect() + assert client.connected + + tools = await client.list_tools() + names = {t.name for t in tools} + assert "hello" in names + assert "add" in names + + result = await client.call_tool("hello", {"name": "Navi"}) + assert "Hello, Navi!" in result + + result = await client.call_tool("add", {"a": 2, "b": 3}) + assert "5" in result + + await client.disconnect() + assert not client.connected diff --git a/tests/unit/test_mcp.py b/tests/unit/test_mcp.py new file mode 100644 index 0000000..1bbd37d --- /dev/null +++ b/tests/unit/test_mcp.py @@ -0,0 +1,115 @@ +"""Unit tests for navi/mcp/ infrastructure.""" + +from unittest.mock import AsyncMock + +import pytest + +from navi.mcp.client import McpClient +from navi.mcp.config import McpServerConfig, load_mcp_servers +from navi.mcp.manager import McpManager +from navi.mcp.tools import McpTool + + +class TestMcpServerConfig: + def test_stdio_config(self): + cfg = McpServerConfig( + transport="stdio", + command="python", + args=["-m", "app.mcp_server"], + env={"FOO": "bar"}, + ) + assert cfg.is_stdio + assert not cfg.is_sse + assert cfg.command == "python" + + def test_sse_config(self): + cfg = McpServerConfig( + transport="sse", + url="http://localhost:3001/sse", + headers={"Authorization": "Bearer token"}, + ) + assert cfg.is_sse + assert not cfg.is_stdio + assert cfg.url == "http://localhost:3001/sse" + + def test_default_transport_is_stdio(self): + cfg = McpServerConfig(command="npx") + assert cfg.transport == "stdio" + + +class TestLoadMcpServers: + def test_missing_file_returns_empty(self, tmp_path): + result = load_mcp_servers(tmp_path / "nonexistent.json") + assert result == {} + + def test_loads_valid_json(self, tmp_path): + path = tmp_path / "mcp_servers.json" + path.write_text( + '{"book": {"transport": "stdio", "command": "python", "args": ["server.py"]}}' + ) + result = load_mcp_servers(path) + assert "book" in result + assert result["book"].command == "python" + + +class TestMcpManager: + async def test_load_all_with_empty_config(self): + manager = McpManager() + await manager.load_all({}) + assert manager.clients == {} + + async def test_disconnect_all_when_empty(self): + manager = McpManager() + await manager.disconnect_all() # should not raise + + async def test_get_all_tools_skips_broken_server(self): + manager = McpManager() + client = AsyncMock(spec=McpClient) + client.connected = True + client.list_tools.side_effect = RuntimeError("boom") + manager._clients = {"bad": client} + + tools = await manager.get_all_tools() + assert tools == [] + + +class TestMcpTool: + def test_name_prefix(self): + mock_manager = AsyncMock(spec=McpManager) + tool = McpTool( + server_name="gnexus-book", + tool_name="search_docs", + description="Search docs", + parameters={"type": "object", "properties": {}}, + manager=mock_manager, + ) + assert tool.name == "mcp_gnexus-book_search_docs" + + async def test_execute_success(self): + mock_manager = AsyncMock(spec=McpManager) + mock_manager.call_tool.return_value = "found 3 results" + tool = McpTool( + server_name="book", + tool_name="search", + description="", + parameters={}, + manager=mock_manager, + ) + result = await tool.execute({"query": "foo"}) + assert result.success + assert result.output == "found 3 results" + mock_manager.call_tool.assert_awaited_once_with("book", "search", {"query": "foo"}) + + async def test_execute_failure(self): + mock_manager = AsyncMock(spec=McpManager) + mock_manager.call_tool.side_effect = RuntimeError("server down") + tool = McpTool( + server_name="book", + tool_name="search", + description="", + parameters={}, + manager=mock_manager, + ) + result = await tool.execute({}) + assert not result.success + assert "server down" in result.error diff --git a/tests/unit/test_startup.py b/tests/unit/test_startup.py index 88ddaf8..64be317 100644 --- a/tests/unit/test_startup.py +++ b/tests/unit/test_startup.py @@ -17,6 +17,17 @@ order.append("get_registries") return None + def fake_get_tool_registry(): + order.append("get_tool_registry") + return None + + async def fake_get_mcp_manager(): + order.append("get_mcp_manager") + return None + + async def fake_register_mcp_tools(registry, manager): + order.append("register_mcp_tools") + async def fake_check_embed(): order.append("check_embed") return {"ok": True, "backend": "fake", "error": None} @@ -31,6 +42,9 @@ monkeypatch.setattr(content_store, "ensure_tables", fake_ensure_tables) monkeypatch.setattr(deps, "get_registries", fake_get_registries) + monkeypatch.setattr(deps, "get_tool_registry", fake_get_tool_registry) + monkeypatch.setattr(deps, "get_mcp_manager", fake_get_mcp_manager) + monkeypatch.setattr(deps, "register_mcp_tools", fake_register_mcp_tools) monkeypatch.setattr(health_mod, "_check_embed", fake_check_embed) monkeypatch.setattr(session_files, "cleanup_loop", fake_cleanup_loop) monkeypatch.setattr(deps, "get_session_store", lambda: object()) @@ -43,5 +57,8 @@ await main_mod._on_startup() - assert order[:3] == ["ensure_tables", "get_registries", "check_embed"] + assert order[:2] == ["ensure_tables", "get_registries"] + assert "get_mcp_manager" in order + assert "register_mcp_tools" in order + assert "check_embed" in order assert order[-1] == "create_task"