Newer
Older
navi-1 / navi / llm / ollama.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 21 May 8 KB Add structured logging for Ollama chat errors
"""Ollama LLM backend."""

import uuid
from typing import AsyncGenerator

import ollama as ollama_client

from navi.config import settings
from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema


def _clean_base64_image(img: str) -> str | None:
    """Strip data URI prefix and validate that result is non-empty base64."""
    if not img:
        return None
    s = img.strip()
    if s.startswith("data:"):
        if "," in s:
            s = s.split(",", 1)[1]
        else:
            return None
    return s if s else None


def _to_ollama_messages(messages: list[Message]) -> list[dict]:
    """Convert internal messages into the dict payloads passed to the Ollama client.

    Every entry carries ``role`` and ``content`` (falsy content becomes ``""``).
    ``images`` is attached only when at least one image survives
    ``_clean_base64_image``; ``tool_calls`` is attached only when the message
    has tool calls, each reduced to its function name and arguments.
    """
    payloads: list[dict] = []
    for message in messages:
        entry: dict = {"role": message.role, "content": message.content or ""}

        if message.images:
            # Drop images that are empty or malformed data URIs.
            usable = [
                image
                for image in map(_clean_base64_image, message.images)
                if image is not None
            ]
            if usable:
                entry["images"] = usable

        if message.tool_calls:
            entry["tool_calls"] = [
                {"function": {"name": call.name, "arguments": call.arguments}}
                for call in message.tool_calls
            ]

        payloads.append(entry)
    return payloads


def _to_ollama_tools(tools: list[ToolSchema]) -> list[dict]:
    """Serialize each tool schema with its ``model_dump()`` for the Ollama request."""
    return [schema.model_dump() for schema in tools]


def _base_options(
    temperature: float,
    max_tokens: int | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
    num_thread: int | None = None,
) -> dict:
    """Build the ``options`` dict for an Ollama chat request.

    ``temperature`` and ``num_ctx`` (read from ``settings.ollama_num_ctx``)
    are always present. Each remaining parameter is included only when it is
    not ``None``; ``max_tokens`` is sent under the ``num_predict`` key.
    """
    # Option key -> caller-supplied value; ``None`` means "leave unset".
    optional = {
        "num_predict": max_tokens,
        "top_k": top_k,
        "top_p": top_p,
        "num_thread": num_thread,
    }
    options: dict = {"temperature": temperature, "num_ctx": settings.ollama_num_ctx}
    options.update(
        {key: value for key, value in optional.items() if value is not None}
    )
    return options


def _resolve_think(think: bool | None) -> bool | None:
    # think=None → use global setting; think=False → force off even if global is True
    return settings.ollama_think if think is None else think


def _resolve_model(model: "list[str] | str | None", default: str) -> str:
    """Normalize model param: list → first element, None → default."""
    if isinstance(model, list):
        return model[0] if model else default
    return model or default


def _classify_error(e: Exception) -> Exception:
    """Wrap a raw Ollama/network exception into a typed LLM exception.

    The returned exception is not raised here; callers raise it (chaining the
    original with ``from``).

    Classification rules, in order:

    * ``ollama.RequestError`` -> ``LLMConnectionError``.
    * ``ollama.ResponseError`` -> ``LLMModelNotFoundError`` when the error text
      says the model is missing or the status code is 404, otherwise
      ``LLMBackendError``.
    * Any other exception whose class name contains ``timeout`` or
      ``connect``, or whose message contains a network-related keyword
      -> ``LLMConnectionError``.
    * Everything else -> ``LLMBackendError``.

    Args:
        e: The exception caught while talking to Ollama.

    Returns:
        A new typed exception. Its message is ``str(e)``, or the class name
        of ``e`` when ``str(e)`` is empty.
    """
    # An exception can stringify to ""; fall back to the class name so the
    # typed error never carries a blank message.
    message = str(e) or type(e).__name__

    if isinstance(e, ollama_client.RequestError):
        return LLMConnectionError(message)
    if isinstance(e, ollama_client.ResponseError):
        detail = e.error.lower()
        if "not found" in detail or "does not exist" in detail or e.status_code == 404:
            return LLMModelNotFoundError(message)
        return LLMBackendError(message)

    # Anything else (httpx / socket failures) is recognised by class name or
    # by message keywords. The class-name check matters when the message is empty.
    err_type = type(e).__name__.lower()
    err_str = str(e).lower()
    if "timeout" in err_type or "connect" in err_type:
        return LLMConnectionError(message)
    if any(kw in err_str for kw in ("connect", "connection refused", "name or service not known",
                                     "network", "timeout", "unreachable", "nodename")):
        return LLMConnectionError(message)
    return LLMBackendError(message)


class OllamaBackend(LLMBackend):
    """LLM backend that talks to an Ollama server through ``ollama.AsyncClient``.

    Exceptions raised while calling the client are converted by
    ``_classify_error`` into ``LLMConnectionError``, ``LLMModelNotFoundError``
    or ``LLMBackendError``; exceptions that already have one of those types
    are re-raised unchanged.
    """

    def __init__(
        self,
        model: str,
        host: str = "http://localhost:11434",
        api_key: str = "",
        timeout: int = 30,
    ):
        """Create the async Ollama client.

        Args:
            model: Default model name, used when a call does not pass ``model``.
            host: Base URL of the Ollama server.
            api_key: When non-empty, sent as an ``Authorization: Bearer`` header.
            timeout: Forwarded unchanged to ``ollama.AsyncClient``.
        """
        self.model = model
        self._host = host
        # Only attach an Authorization header when a key is configured.
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        self._client = ollama_client.AsyncClient(host=host, headers=headers, timeout=timeout)

    @staticmethod
    def _to_tool_call_requests(raw_tool_calls) -> "list[ToolCallRequest] | None":
        """Convert Ollama tool calls to ``ToolCallRequest`` objects (``None`` if there are none)."""
        if not raw_tool_calls:
            return None
        return [
            ToolCallRequest(
                # The id is not read from the response; a fresh one is generated per call.
                id=str(uuid.uuid4()),
                name=tc.function.name,
                arguments=dict(tc.function.arguments),
            )
            for tc in raw_tool_calls
        ]

    @staticmethod
    def _log_chat_error(
        model: str,
        messages: list[Message],
        tools: list[ToolSchema] | None,
        error: Exception,
    ) -> None:
        """Emit the structured ``llm.ollama.chat_error`` warning for a failed chat call."""
        import structlog

        structlog.get_logger().warning(
            "llm.ollama.chat_error",
            model=model,
            message_count=len(messages),
            tools_count=len(tools) if tools else 0,
            error=str(error),
        )

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> LLMResponse:
        """Run a single non-streaming chat request.

        Args:
            messages: Conversation to send.
            tools: Optional tool schemas; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Model override (see ``_resolve_model``); defaults to ``self.model``.
            think: Thinking flag override (see ``_resolve_think``).
            max_tokens, top_k, top_p, num_thread: Optional Ollama options
                (see ``_base_options``).

        Returns:
            An ``LLMResponse`` whose ``finish_reason`` is ``"tool_calls"`` when
            the model requested tools and ``"stop"`` otherwise. Empty content
            and thinking text are normalised to ``None``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: As
                produced by ``_classify_error`` for any failure.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, num_thread=num_thread),
                "stream": False,
                "think": _resolve_think(think),
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)

            response = await self._client.chat(**kwargs)
            msg = response.message

            tool_calls = self._to_tool_call_requests(msg.tool_calls)
            finish_reason = "tool_calls" if tool_calls else "stop"
            return LLMResponse(
                content=msg.content or None,
                tool_calls=tool_calls,
                finish_reason=finish_reason,
                thinking=getattr(msg, "thinking", None) or None,
                prompt_tokens=getattr(response, "prompt_eval_count", None),
                completion_tokens=getattr(response, "eval_count", None),
            )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            # Already typed; do not re-wrap.
            raise
        except Exception as e:
            self._log_chat_error(resolved, messages, tools, e)
            raise _classify_error(e) from e

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
        max_tokens: int | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Run a streaming chat request, yielding one ``LLMChunk`` per Ollama chunk.

        Args:
            messages: Conversation to send.
            tools: Optional tool schemas; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Model override (see ``_resolve_model``); defaults to ``self.model``.
            think: Thinking flag override (see ``_resolve_think``).
            top_k, top_p, num_thread: Optional Ollama options (see ``_base_options``).
            max_tokens: Optional generation limit, mirroring ``complete``. It is
                the last parameter so existing positional calls keep working.

        Yields:
            ``LLMChunk`` objects. ``finish_reason`` and the token counts are
            set only on the chunk where Ollama reports ``done``.
            ``finish_reason`` is ``"tool_calls"`` if any chunk of the stream
            carried tool calls, otherwise ``"stop"``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: As
                produced by ``_classify_error`` for any failure.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(temperature, max_tokens=max_tokens, top_k=top_k, top_p=top_p, num_thread=num_thread),
                "stream": True,
                "think": _resolve_think(think),
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)

            # Tool calls may arrive in a chunk that precedes the final `done`
            # chunk, so remember whether any were seen during the stream.
            saw_tool_calls = False

            async for chunk in await self._client.chat(**kwargs):
                thinking = getattr(chunk.message, "thinking", None) or None
                delta = chunk.message.content or None

                tool_calls = self._to_tool_call_requests(chunk.message.tool_calls)
                if tool_calls:
                    saw_tool_calls = True

                finish_reason = None
                if chunk.done:
                    finish_reason = "tool_calls" if saw_tool_calls else "stop"

                yield LLMChunk(
                    delta=delta,
                    thinking=thinking,
                    finish_reason=finish_reason,
                    tool_calls=tool_calls,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            # Already typed; do not re-wrap.
            raise
        except Exception as e:
            self._log_chat_error(resolved, messages, tools, e)
            raise _classify_error(e) from e

    async def embed(
        self,
        texts: list[str],
        model: "list[str] | str | None" = None,
    ) -> list[list[float]]:
        """Embed ``texts`` with the resolved model and return the client's embeddings.

        Args:
            texts: Inputs passed to the client's ``embed`` call as ``input``.
            model: Model override (see ``_resolve_model``); defaults to ``self.model``.

        Returns:
            The ``embeddings`` attribute of the client response.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: As
                produced by ``_classify_error`` for any failure.
        """
        resolved = _resolve_model(model, self.model)
        try:
            response = await self._client.embed(model=resolved, input=texts)
            return response.embeddings
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            # Already typed; do not re-wrap.
            raise
        except Exception as e:
            raise _classify_error(e) from e