"""Tool middleware — pre/post execute hooks.
Middleware runs around every tool call. Useful for logging, metrics,
rate limiting, and authorization without modifying individual tools.
"""
from abc import ABC, abstractmethod
from typing import Awaitable, Callable
from .base import ToolResult
# Async hook run ahead of a tool call: takes (tool name, params dict) and
# resolves to None -- the same shape as ToolMiddleware.before_execute.
MiddlewareFunc = Callable[
    [str, dict],
    Awaitable[None],
]
# Async hook run after a tool call: takes (tool name, params dict, result)
# and resolves to None -- the same shape as ToolMiddleware.after_execute.
PostExecuteFunc = Callable[
    [str, dict, ToolResult],
    Awaitable[None],
]
class ToolMiddleware(ABC):
    """Hook interface wrapped around each tool invocation.

    Both hooks default to doing nothing, so a subclass implements only the
    one(s) it needs: ``before_execute``, ``after_execute``, or both. No method
    is declared abstract, so the base class itself can be instantiated.
    """

    async def before_execute(self, tool_name: str, params: dict) -> None:
        """Hook invoked ahead of the tool call; the default is a no-op."""
        return None

    async def after_execute(self, tool_name: str, params: dict, result: ToolResult) -> None:
        """Hook invoked once the tool has produced ``result``; the default is a no-op."""
        return None
class MiddlewareChain:
    """Chains multiple middleware instances around a tool execution.

    On each ``run``, every middleware's ``before_execute`` hook is awaited in
    list order, then the tool itself. If the tool succeeds, every
    ``after_execute`` hook is awaited, again in list order.
    """

    def __init__(self, middlewares: list[ToolMiddleware]) -> None:
        """Store the middleware sequence applied on every ``run``.

        Args:
            middlewares: Middleware instances, applied in list order. The list
                is kept by reference (not copied), so later changes the caller
                makes to it are seen by subsequent runs.
        """
        self._middlewares = middlewares

    async def run(self, tool_name: str, params: dict, execute: Callable[[], Awaitable[ToolResult]]) -> ToolResult:
        """Execute one tool call wrapped by all middleware hooks.

        Args:
            tool_name: Name of the tool being invoked; forwarded to every hook.
            params: The call's parameters; the same dict object is forwarded
                to every hook.
            execute: Zero-argument callable returning an awaitable that
                performs the actual tool call.

        Returns:
            The ``ToolResult`` produced by awaiting ``execute()``.

        Raises:
            Whatever a ``before_execute`` hook, ``execute``, or an
            ``after_execute`` hook raises propagates unchanged, and the
            remaining hooks are skipped. In particular, ``after_execute``
            hooks run only when ``execute`` completes successfully.
        """
        for mw in self._middlewares:
            await mw.before_execute(tool_name, params)

        # after_execute is success-only by design: if execute() raises there is
        # no result to hand to the hooks, so the exception just propagates.
        # TODO: consider invoking after_execute with a synthetic failure result.
        result = await execute()

        # NOTE(review): after-hooks run in the same (registration) order as the
        # before-hooks, not reversed "onion" order; confirm this is intended.
        for mw in self._middlewares:
            await mw.after_execute(tool_name, params, result)
        return result