# navi/core/ai_helper.py
"""
AIHelper — reusable LLM utility for AI-enhanced tools.

Provides a simple ask() / ask_json() interface over an LLMBackend.
Model selection: reads current_model ContextVar (set by run_stream/run_ephemeral
before each tool turn), falls back to default_model.

Usage in any tool:
    class MyTool(Tool):
        def __init__(self, ai_helper: AIHelper) -> None:
            self._ai = ai_helper

        async def execute(self, params: dict) -> ToolResult:
            answer = await self._ai.ask("You are ...", "Question: ...")
            data   = await self._ai.ask_json("...", "Return JSON: ...")
"""

import asyncio
import json
import re
import structlog

# Module-level structlog logger; used below to report LLM call timeouts
# and JSON parse failures.
log = structlog.get_logger()


class AIHelper:
    """
    Thin, reusable wrapper over LLMBackend for single-turn AI calls.

    Parameters
    ----------
    backend : LLMBackend
        The LLM backend to use (e.g. OllamaBackend).
    default_model : str
        Fallback model name when current_model ContextVar is not set.
    temperature : float
        Sampling temperature for all calls (default 0.1 for determinism).
    """

    def __init__(self, backend, default_model: str, temperature: float = 0.1) -> None:
        self._backend = backend
        self._default_model = default_model
        self._temperature = temperature

    def _active_model(self) -> str:
        """Return current session model or fall back to default."""
        from navi.tools._internal.base import current_model
        return current_model.get() or self._default_model

    async def ask(self, system: str, prompt: str) -> str:
        """Single non-streaming LLM call. Returns the response text."""
        from navi.llm.base import Message
        from navi.tools._internal.base import current_event_sink
        from navi.core.events import AIHelperTokensUsed

        messages = [
            Message(role="system", content=system),
            Message(role="user",   content=prompt),
        ]
        try:
            response = await asyncio.wait_for(
                self._backend.complete(
                    messages,
                    tools=None,
                    temperature=self._temperature,
                    model=self._active_model(),
                    think=False,
                ),
                timeout=120,
            )
        except asyncio.TimeoutError:
            log.error("ai_helper.ask_timeout", timeout=120)
            return "[AIHelper error: LLM call timed out after 120s]"

        # Emit token usage so run_stream can account for AIHelper calls in session metrics
        if response.prompt_tokens or response.completion_tokens:
            sink = current_event_sink.get()
            if sink is not None:
                await sink.put(AIHelperTokensUsed(
                    prompt_tokens=response.prompt_tokens or 0,
                    completion_tokens=response.completion_tokens or 0,
                ))

        return (response.content or "").strip()

    async def ask_json(self, system: str, prompt: str) -> list | dict | None:
        """
        Single LLM call expecting JSON output.
        Returns parsed list/dict, or None if the response cannot be parsed.
        Handles markdown code fences automatically.
        """
        raw = await self.ask(system, prompt)
        result = _extract_json(raw)
        if result is None:
            log.warning("ai_helper.json_parse_failed", raw_preview=raw[:300])
        return result


# ─── JSON extraction ───────────────────────────────────────────────────────

def _extract_json(text: str) -> list | dict | None:
    """
    Extract the first valid JSON array or object from text.
    Handles markdown code fences (```json ... ```) and inline JSON.
    Uses bracket-matching to find the outermost structure.
    """
    # Strip markdown code fences
    cleaned = re.sub(r"```(?:json)?\s*", "", text)
    cleaned = re.sub(r"```", "", cleaned).strip()

    # Try direct parse first
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # Bracket-match: find outermost [ ] or { }
    for open_c, close_c in (("[", "]"), ("{", "}")):
        start = cleaned.find(open_c)
        if start == -1:
            continue
        depth = 0
        for i, c in enumerate(cleaned[start:], start):
            if c == open_c:
                depth += 1
            elif c == close_c:
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(cleaned[start : i + 1])
                    except json.JSONDecodeError:
                        break

    return None