Newer
Older
navi-1 / navi / core / tool_executor.py
"""Tool execution helpers — extracted from agent.py."""

import asyncio
from datetime import datetime, timezone
from typing import TYPE_CHECKING

import structlog

from navi.llm.base import Message, ToolCallRequest
from navi.tools.base import Tool

if TYPE_CHECKING:
    from navi.core.events import ToolEvent

# Module-level structlog logger; used below to emit one "tool.execute" event per executed call.
log = structlog.get_logger()


def _resolve_tool(tool_map: dict[str, Tool], name: str) -> tuple[str, Tool | None]:
    """Resolve exact tool names plus common MCP alias mistakes.

    Resolution is attempted in the order below. Every fallback step only
    considers ``mcp:``-prefixed tools and only succeeds when it finds exactly
    one match; ambiguous or empty results fall through to the next step.

    1. Exact key in *tool_map*.
    2. Bare name: a tool whose registered name ends with ``":" + name``
       (the caller dropped the ``mcp:...:`` prefix).
    3. Hyphen/underscore-insensitive full name
       (e.g. ``mcp:my_srv:x`` finds ``mcp:my-srv:x``).
    4. Hyphen/underscore-insensitive bare name
       (e.g. ``read-file`` finds ``mcp:fs:read_file``).

    Args:
        tool_map: Registered tools keyed by their exact name.
        name: Tool name as requested by the model.

    Returns:
        ``(resolved_name, tool)`` on success, otherwise ``(name, None)``.
    """
    tool = tool_map.get(name)
    if tool is not None:
        return name, tool

    # Aliasing is only applied to MCP tools; filter them once for all passes.
    mcp_items = [
        (candidate_name, candidate)
        for candidate_name, candidate in tool_map.items()
        if candidate_name.startswith("mcp:")
    ]

    normalized = name.replace("-", "_")
    bare_suffix = f":{name}"
    normalized_suffix = f":{normalized}"

    # Ordered from strictest to loosest so earlier (pre-existing) rules keep
    # their precedence; the last pass is a superset of the bare-name pass, so
    # anything ambiguous there is still ambiguous here.
    passes = (
        lambda candidate_name: candidate_name.endswith(bare_suffix),
        lambda candidate_name: candidate_name.replace("-", "_") == normalized,
        lambda candidate_name: candidate_name.replace("-", "_").endswith(normalized_suffix),
    )
    for is_match in passes:
        matches = [item for item in mcp_items if is_match(item[0])]
        if len(matches) == 1:
            return matches[0]

    return name, None


class ToolExecutor:
    """Runs tool calls and builds ToolEvent / Message results.

    All entry points share a single implementation (``_run_single_tool``) so
    name resolution, middleware hooks, logging and image follow-up handling
    cannot drift apart between the streaming and non-streaming paths.
    """

    def __init__(self, tool_registry) -> None:
        # Within this class the registry is only read for an optional
        # ``_middlewares`` attribute; the tools themselves are passed per call.
        self._tools = tool_registry

    def _get_middlewares(self) -> list:
        """Return the registry's middlewares, or ``[]`` if it defines none."""
        return getattr(self._tools, "_middlewares", [])

    @staticmethod
    def _image_message(tool_name: str, result) -> "Message | None":
        """Build a follow-up user message for an image returned by a tool.

        Returns ``None`` unless the call succeeded and its metadata has a
        truthy ``is_image`` flag and a non-empty ``base64`` payload.
        """
        if not (result.success and result.metadata and result.metadata.get("is_image")):
            return None
        b64 = result.metadata.get("base64")
        if not b64:
            return None
        return Message(
            role="user",
            content=f"[Image loaded via {tool_name} — analyse it]",
            images=[b64],
        )

    async def _run_single_tool(
        self,
        tc: ToolCallRequest,
        tool_map: dict[str, Tool],
    ) -> tuple["ToolEvent", Message, "Message | None"]:
        """Execute one tool call and return (ToolEvent, tool_msg, optional_image_msg).

        Called via asyncio.create_task() from run_stream() so that the parent
        generator can drain the event sink queue concurrently. It is also the
        shared worker behind ``_execute_tool_calls`` and
        ``_execute_tool_calls_streaming``.

        An unknown tool name does not raise: it yields a failed ``ToolEvent``
        and an error tool message. Otherwise every middleware's
        ``before_execute`` hook is awaited in order, then ``tool.execute``,
        then every ``after_execute`` hook.

        Returns:
            event: ``ToolEvent`` describing the call and its outcome.
            tool_msg: ``role="tool"`` message answering ``tc.id``.
            image_msg: ``role="user"`` message carrying an image returned by
                the tool, or ``None``.
        """
        from navi.core.events import ToolEvent

        resolved_name, tool = _resolve_tool(tool_map, tc.name)

        if tool is None:
            content = f"Error: tool '{tc.name}' not found."
            event = ToolEvent(tool_name=tc.name, arguments=tc.arguments,
                              result=content, success=False)
            msg = Message(role="tool", content=content, tool_call_id=tc.id,
                          name=tc.name, metadata={})
            return event, msg, None

        log.info("tool.execute", tool=resolved_name, requested_tool=tc.name, args=tc.arguments)
        middlewares = self._get_middlewares()
        for mw in middlewares:
            await mw.before_execute(resolved_name, tc.arguments)
        result = await tool.execute(tc.arguments)
        for mw in middlewares:
            await mw.after_execute(resolved_name, tc.arguments, result)

        content = result.to_message_content()
        metadata = result.metadata or {}
        event = ToolEvent(tool_name=resolved_name, arguments=tc.arguments,
                          result=content, success=result.success,
                          metadata=metadata)
        image_msg = self._image_message(resolved_name, result)
        msg = Message(role="tool", content=content, tool_call_id=tc.id,
                      name=resolved_name, metadata=metadata)
        return event, msg, image_msg

    async def _run_all(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> list[tuple["ToolEvent", Message, "Message | None"]]:
        """Run every call concurrently; results are in the order of *tool_calls*."""
        tool_map = {t.name: t for t in tools}
        # NOTE(review): exceptions raised by middleware or ``tool.execute`` are
        # not caught, so the first one propagates out of ``gather`` and the
        # remaining calls are not cancelled. Confirm callers handle this.
        return await asyncio.gather(
            *(self._run_single_tool(tc, tool_map) for tc in tool_calls)
        )

    async def _execute_tool_calls(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[Message], list[Message]]:
        """Run *tool_calls* concurrently and return ``(tool_msgs, image_msgs)``.

        ``tool_msgs`` has one entry per call, in the order of *tool_calls*;
        ``image_msgs`` contains only the image follow-ups that were produced.
        """
        results = await self._run_all(tool_calls, tools)
        tool_msgs = [msg for _event, msg, _image in results]
        image_msgs = [image for _event, _msg, image in results if image is not None]
        return tool_msgs, image_msgs

    async def _execute_tool_calls_streaming(
        self, tool_calls: list[ToolCallRequest], tools: list[Tool]
    ) -> tuple[list[tuple["ToolEvent", Message]], list[Message]]:
        """Run *tool_calls* concurrently, keeping the ``ToolEvent`` for each.

        Returns ``([(event, tool_msg), ...], image_msgs)``. The pairs follow
        the order of *tool_calls*; ``image_msgs`` holds only the image
        follow-ups that were produced.
        """
        results = await self._run_all(tool_calls, tools)
        pairs = [(event, msg) for event, msg, _image in results]
        image_msgs = [image for _event, _msg, image in results if image is not None]
        return pairs, image_msgs