Newer
Older
navi-1 / navi / core / tool_executor.py
"""Tool execution helpers — extracted from agent.py."""

import asyncio
from datetime import datetime, timezone
from typing import TYPE_CHECKING

import structlog

from navi.llm.base import Message, ToolCallRequest
from navi.tools._internal.base import Tool

if TYPE_CHECKING:
    from navi.core.events import ToolEvent

# Module-level structlog logger; ToolExecutor._execute_one emits a
# "tool.execute" event through it for every resolved tool call.
log = structlog.get_logger()


def _resolve_tool(tool_map: dict[str, Tool], name: str) -> tuple[str, Tool | None]:
    """Resolve a requested tool name, tolerating common MCP alias mistakes.

    Resolution order (first success wins):

    1. Exact key in ``tool_map``.
    2. Bare tool name that is the ``__``-suffix of exactly one MCP tool,
       e.g. ``"web_search"`` -> ``"mcp__navi_web__web_search"``.
    3. Dash/underscore-insensitive equality with exactly one MCP tool name.
    4. Old single-underscore format ``mcp_<server>_<tool>`` matching exactly
       one MCP tool (server/tool as split by ``parse_mcp_name``).
    5. Legacy colon format ``mcp:<server>:<tool>`` (from old sessions),
       rewritten to ``mcp__<server>__<tool>``.

    Alias steps never guess between several candidates: an ambiguous alias
    falls through to the next step and may ultimately resolve to nothing.

    Args:
        tool_map: Registered tools keyed by their canonical name.
        name: The tool name as requested by the model.

    Returns:
        ``(canonical_name, tool)`` on success, otherwise ``(name, None)``.
    """
    from navi.mcp.tools import is_mcp_tool, parse_mcp_name

    tool = tool_map.get(name)
    if tool is not None:
        return name, tool

    # Only MCP tools take part in alias resolution; filter them once instead
    # of re-running is_mcp_tool in every step below.
    mcp_items = [
        (candidate_name, candidate)
        for candidate_name, candidate in tool_map.items()
        if is_mcp_tool(candidate_name)
    ]

    # 2. Bare tool name, e.g. "web_search" -> "mcp__navi_web__web_search".
    suffix = f"__{name}"
    bare_matches = [item for item in mcp_items if item[0].endswith(suffix)]
    if len(bare_matches) == 1:
        return bare_matches[0]

    # 3. Dash vs underscore confusion, in either direction.
    normalized = name.replace("-", "_")
    normalized_matches = [
        item for item in mcp_items if item[0].replace("-", "_") == normalized
    ]
    if len(normalized_matches) == 1:
        return normalized_matches[0]

    # 4. Old underscore format, e.g. "mcp_navi_web_search". Collapsing "__"
    # to "_" loses the server/tool boundary (server "a_b" + tool "c" and
    # server "a" + tool "b_c" both give "mcp_a_b_c"), so require a unique
    # match rather than returning whichever candidate comes first.
    if name.startswith("mcp_"):
        old_format_matches = []
        for candidate_name, candidate in mcp_items:
            parsed = parse_mcp_name(candidate_name)
            if parsed is None:
                continue
            server_name, tool_name = parsed
            if f"mcp_{server_name}_{tool_name}" == name:
                old_format_matches.append((candidate_name, candidate))
        if len(old_format_matches) == 1:
            return old_format_matches[0]

    # 5. Legacy colon format "mcp:server:tool" -> "mcp__server__tool".
    # Dict keys are unique, so a direct lookup replaces the linear scan.
    legacy_name = name.replace(":", "__")
    if legacy_name in tool_map and is_mcp_tool(legacy_name):
        return legacy_name, tool_map[legacy_name]

    return name, None


class ToolExecutor:
    """Run tool calls and build ``ToolEvent`` / ``Message`` results.

    Every entry point funnels through :meth:`_execute_one`, so tool
    resolution, middleware hooks, execution and message construction live in
    exactly one place.
    """

    def __init__(self, tool_registry) -> None:
        # Within this class the registry is only read for its optional
        # ``_middlewares`` attribute (see ``_execute_one``).
        self._tools = tool_registry

    async def _execute_one(
        self,
        tc: ToolCallRequest,
        tool_map: dict[str, Tool],
        ctx=None,
    ) -> tuple["ToolEvent", Message, "Message | None"]:
        """Execute a single tool call.

        This is the single canonical path for tool resolution, middleware,
        execution, and message construction. All batch methods delegate here.

        Args:
            tc: The tool call requested by the model.
            tool_map: Available tools keyed by canonical name; the requested
                name is resolved (including MCP aliases) via ``_resolve_tool``.
            ctx: Context object passed through to ``tool.execute``.

        Returns:
            ``(event, tool_msg, image_msg)``. ``image_msg`` is a follow-up
            user message carrying the image when the tool produced one,
            otherwise ``None``.

        An unknown tool does not raise; it yields an error event/message with
        ``success=False``. Exceptions raised by middleware hooks or by
        ``tool.execute`` are not caught here and propagate to the caller.
        """
        from navi.core.events import ToolEvent

        resolved_name, tool = _resolve_tool(tool_map, tc.name)

        if tool is None:
            # On a miss _resolve_tool hands back the requested name unchanged.
            content = f"Error: tool '{tc.name}' not found."
            event = ToolEvent(tool_name=tc.name, arguments=tc.arguments,
                              result=content, success=False,
                              tool_call_id=tc.id)
            msg = Message(role="tool", content=content, tool_call_id=tc.id,
                          name=tc.name, metadata={})
            return event, msg, None

        log.info("tool.execute", tool=resolved_name, requested_tool=tc.name, args=tc.arguments)
        middlewares = getattr(self._tools, "_middlewares", [])
        for mw in middlewares:
            await mw.before_execute(resolved_name, tc.arguments)
        result = await tool.execute(tc.arguments, ctx=ctx)
        for mw in middlewares:
            await mw.after_execute(resolved_name, tc.arguments, result)

        content = result.to_message_content()
        metadata = result.metadata or {}
        event = ToolEvent(tool_name=resolved_name, arguments=tc.arguments,
                          result=content, success=result.success,
                          metadata=metadata,
                          tool_call_id=tc.id)
        image_msg = self._image_message(resolved_name, result)
        msg = Message(role="tool", content=content, tool_call_id=tc.id,
                      name=resolved_name, metadata=metadata)
        return event, msg, image_msg

    @staticmethod
    def _image_message(tool_name: str, result) -> "Message | None":
        """Build the follow-up user message for an image-producing result.

        Returns ``None`` unless the result succeeded, its metadata has a
        truthy ``is_image`` flag, and it carries a non-empty ``base64`` value.
        """
        metadata = result.metadata
        if not (result.success and metadata and metadata.get("is_image")):
            return None
        b64 = metadata.get("base64")
        if not b64:
            return None
        return Message(
            role="user",
            content=f"[Image loaded via {tool_name} — analyse it]",
            images=[b64],
        )

    async def _run_single_tool(
        self,
        tc: ToolCallRequest,
        tool_map: dict[str, Tool],
        ctx=None,
    ) -> tuple["ToolEvent", Message, "Message | None"]:
        """Execute one tool call and return (ToolEvent, tool_msg, optional_image_msg).

        Called via asyncio.create_task() from run_stream() so that the parent
        generator can drain the event sink queue concurrently.
        """
        return await self._execute_one(tc, tool_map, ctx=ctx)

    async def _gather_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool], ctx=None
    ) -> list[tuple["ToolEvent", Message, "Message | None"]]:
        """Run every call in ``tool_calls`` concurrently; results keep input order.

        ``asyncio.gather`` is used without ``return_exceptions``, so the first
        exception raised by any call propagates to the caller (the remaining
        calls are not cancelled).
        """
        tool_map = {t.name: t for t in tools}
        return await asyncio.gather(
            *(self._execute_one(tc, tool_map, ctx=ctx) for tc in tool_calls)
        )

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool], ctx=None
    ) -> tuple[list[Message], list[Message]]:
        """Execute all calls concurrently.

        Returns:
            ``(tool_msgs, image_msgs)``: one tool message per call, in call
            order, plus the image follow-up messages that were produced.
        """
        triples = await self._gather_calls(tool_calls, tools, ctx=ctx)
        tool_msgs = [msg for _, msg, _ in triples]
        image_msgs = [img for _, _, img in triples if img is not None]
        return tool_msgs, image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool], ctx=None
    ) -> tuple[list[tuple["ToolEvent", Message]], list[Message]]:
        """Execute all calls concurrently, keeping each call's ToolEvent.

        Returns:
            ``(pairs, image_msgs)`` where ``pairs`` holds ``(event, tool_msg)``
            per call in call order, plus the image follow-up messages.
        """
        triples = await self._gather_calls(tool_calls, tools, ctx=ctx)
        pairs = [(event, msg) for event, msg, _ in triples]
        image_msgs = [img for _, _, img in triples if img is not None]
        return pairs, image_msgs