"""
AIHelper — reusable LLM utility for AI-enhanced tools.

Provides a simple ask() / ask_json() interface over an LLMBackend.
Model selection: reads the current_model ContextVar (set by run_stream/run_ephemeral
before each tool turn) and falls back to default_model when it holds no value.

Usage in any tool:
    class MyTool(Tool):
        def __init__(self, ai_helper: AIHelper) -> None:
            self._ai = ai_helper

        async def execute(self, params: dict) -> ToolResult:
            answer = await self._ai.ask("You are ...", "Question: ...")
            data   = await self._ai.ask_json("...", "Return JSON: ...")
"""

import json
import re
import structlog

log = structlog.get_logger()


class AIHelper:
    """
    Thin, reusable wrapper over LLMBackend for single-turn AI calls.

    Parameters
    ----------
    backend : LLMBackend
        The LLM backend to use (e.g. OllamaBackend).
    default_model : str
        Fallback model name when current_model ContextVar is not set.
    temperature : float
        Sampling temperature for all calls (default 0.1 for determinism).
    """

    def __init__(self, backend, default_model: str, temperature: float = 0.1) -> None:
        self._backend = backend
        self._default_model = default_model
        self._temperature = temperature

    def _active_model(self) -> str:
        """Return current session model or fall back to default."""
        from navi.tools.base import current_model
        return current_model.get() or self._default_model

    async def ask(self, system: str, prompt: str) -> str:
        """Single non-streaming LLM call. Returns the response text."""
        from navi.llm.base import Message
        messages = [
            Message(role="system", content=system),
            Message(role="user",   content=prompt),
        ]
        response = await self._backend.complete(
            messages,
            tools=None,
            temperature=self._temperature,
            model=self._active_model(),
            think=False,
        )
        return (response.content or "").strip()

    async def ask_json(self, system: str, prompt: str) -> list | dict | None:
        """
        Single LLM call expecting JSON output.
        Returns parsed list/dict, or None if the response cannot be parsed.
        Handles markdown code fences automatically.
        """
        raw = await self.ask(system, prompt)
        result = _extract_json(raw)
        if result is None:
            log.warning("ai_helper.json_parse_failed", raw_preview=raw[:300])
        return result


# ─── JSON extraction ───────────────────────────────────────────────────────

def _extract_json(text: str) -> list | dict | None:
    """
    Extract the first valid JSON array or object from text.
    Handles markdown code fences (```json ... ```) and inline JSON.
    Uses bracket-matching to find the outermost structure.
    """
    # Strip markdown code fences
    cleaned = re.sub(r"```(?:json)?\s*", "", text)
    cleaned = re.sub(r"```", "", cleaned).strip()

    # Try direct parse first
    try:
        return json.loads(cleaned)
    except json.JSONDecodeError:
        pass

    # Bracket-match: find outermost [ ] or { }
    for open_c, close_c in (("[", "]"), ("{", "}")):
        start = cleaned.find(open_c)
        if start == -1:
            continue
        depth = 0
        for i, c in enumerate(cleaned[start:], start):
            if c == open_c:
                depth += 1
            elif c == close_c:
                depth -= 1
                if depth == 0:
                    try:
                        return json.loads(cleaned[start : i + 1])
                    except json.JSONDecodeError:
                        break

    return None
