Newer
Older
navi-1 / navi / llm / ollama.py
"""Ollama LLM backend."""

import uuid
from typing import AsyncGenerator

import ollama as ollama_client

from navi.config import settings
from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema


def _to_ollama_messages(messages: list[Message]) -> list[dict]:
    """Convert backend-agnostic messages into Ollama chat message dicts.

    Every output dict carries ``role`` and ``content`` (``None``/empty content
    becomes ``""``). ``images`` and ``tool_calls`` are only included when the
    source message has a non-empty value for them.
    """

    def _encode(message) -> dict:
        payload: dict = {"role": message.role, "content": message.content or ""}
        if message.images:
            # list of base64 strings, Ollama format
            payload["images"] = message.images
        if message.tool_calls:
            payload["tool_calls"] = [
                {"function": {"name": call.name, "arguments": call.arguments}}
                for call in message.tool_calls
            ]
        return payload

    return [_encode(message) for message in messages]


def _to_ollama_tools(tools: list[ToolSchema]) -> list[dict]:
    """Serialize each tool schema to a plain dict via its ``model_dump()``."""
    dumped: list[dict] = []
    for schema in tools:
        dumped.append(schema.model_dump())
    return dumped


def _base_options(
    temperature: float,
    think: bool | None = None,
    max_tokens: int | None = None,
) -> dict:
    """Build the ``options`` mapping passed along with a chat request.

    The result always contains ``temperature`` and ``num_ctx`` (taken from
    ``settings.ollama_num_ctx``). ``think: True`` is added only when thinking
    is enabled; when it is disabled the key is omitted rather than set to
    ``False``. ``num_predict`` is added only when ``max_tokens`` is given.

    Args:
        temperature: Sampling temperature.
        think: ``None`` defers to ``settings.ollama_think``; an explicit
            ``True``/``False`` overrides the global setting.
        max_tokens: Optional cap, forwarded as ``num_predict``.

    NOTE(review): Ollama's chat API documents ``think`` as a top-level request
    field rather than an ``options`` entry -- confirm the server honors it
    when it is sent inside ``options``.
    """
    options: dict = dict(temperature=temperature, num_ctx=settings.ollama_num_ctx)

    # An explicit per-call value wins, even if it is False while the global is True.
    thinking_enabled = think if think is not None else settings.ollama_think
    if thinking_enabled:
        options.update(think=True)

    if max_tokens is not None:
        options.update(num_predict=max_tokens)

    return options


def _resolve_model(model: "list[str] | str | None", default: str) -> str:
    """Pick the model name to use for a request.

    A list yields its first element (or ``default`` when the list is empty);
    a falsy non-list value (``None`` or ``""``) yields ``default``; anything
    else is returned unchanged.
    """
    if not isinstance(model, list):
        return model if model else default
    if not model:
        return default
    return model[0]


def _classify_error(e: Exception) -> Exception:
    """Wrap raw Ollama/network exceptions into typed LLM exceptions.

    Classification order:

    1. ``ollama.RequestError`` -> ``LLMConnectionError``.
    2. ``ollama.ResponseError`` -> ``LLMModelNotFoundError`` when the status
       code is 404 or the error text says the model is missing, otherwise
       ``LLMBackendError``.
    3. Built-in ``ConnectionError`` / ``TimeoutError`` (and subclasses) ->
       ``LLMConnectionError``.
    4. Anything else is matched heuristically against the exception's class
       name and message; a hit yields ``LLMConnectionError``, a miss
       ``LLMBackendError``.

    Args:
        e: The exception caught around an Ollama client call.

    Returns:
        The typed exception instance. It is returned, not raised, so the
        caller can ``raise ... from e``.
    """
    # Some failures stringify to "" (e.g. a bare ``TimeoutError()``); fall back
    # to the class name so the wrapped error never has an empty message.
    detail = str(e) or type(e).__name__

    if isinstance(e, ollama_client.RequestError):
        return LLMConnectionError(detail)

    if isinstance(e, ollama_client.ResponseError):
        reason = e.error.lower()
        model_missing = (
            e.status_code == 404
            or "not found" in reason
            or "does not exist" in reason
        )
        if model_missing:
            return LLMModelNotFoundError(detail)
        return LLMBackendError(detail)

    # Type-based check first: it does not depend on the wording of the message.
    if isinstance(e, (ConnectionError, TimeoutError)):
        return LLMConnectionError(detail)

    # Catch httpx / socket connection failures by class name or message. The
    # class name is included because transport timeouts may carry an empty
    # message, and "timed out" is listed because it does not contain "timeout".
    haystack = f"{type(e).__name__} {e}".lower()
    connection_markers = (
        "connect",
        "connection refused",
        "name or service not known",
        "network",
        "timeout",
        "timed out",
        "unreachable",
        "nodename",
    )
    if any(marker in haystack for marker in connection_markers):
        return LLMConnectionError(detail)
    return LLMBackendError(detail)


class OllamaBackend(LLMBackend):
    """LLM backend that talks to an Ollama server through ``ollama.AsyncClient``.

    All three entry points (``complete``, ``stream``, ``stream_complete``)
    re-raise already-typed LLM exceptions unchanged and convert any other
    exception with ``_classify_error``.
    """

    def __init__(self, model: str, host: str = "http://localhost:11434", api_key: str = ""):
        """Create the async client.

        Args:
            model: Default model name, used when a call does not pass ``model``.
            host: Base URL of the Ollama server.
            api_key: When non-empty, sent as a ``Bearer`` Authorization header.
        """
        self.model = model
        self._host = host
        # Only send an Authorization header when a key was actually configured.
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        self._client = ollama_client.AsyncClient(host=host, headers=headers)

    @staticmethod
    def _parse_tool_calls(raw_calls) -> list[ToolCallRequest] | None:
        """Convert Ollama tool calls to ``ToolCallRequest`` objects.

        Returns ``None`` when ``raw_calls`` is empty or ``None``. Each request
        gets a freshly generated UUID as its id.
        """
        if not raw_calls:
            return None
        return [
            ToolCallRequest(
                id=str(uuid.uuid4()),
                name=tc.function.name,
                arguments=dict(tc.function.arguments),
            )
            for tc in raw_calls
        ]

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Run a single non-streaming chat request.

        Args:
            messages: Conversation so far.
            tools: Optional tool schemas; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Model override (a list uses its first element); falls back
                to ``self.model``.
            think: Per-call thinking override; ``None`` uses the global setting.
            max_tokens: Optional generation cap.

        Returns:
            An ``LLMResponse`` whose ``finish_reason`` is ``"tool_calls"`` when
            the model requested tools, otherwise ``"stop"``. Empty content,
            thinking and zero token counts are normalized to ``None``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            classified by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(temperature, think=think, max_tokens=max_tokens),
                "stream": False,
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)

            response = await self._client.chat(**kwargs)
            msg = response.message

            tool_calls = self._parse_tool_calls(msg.tool_calls)
            finish_reason = "tool_calls" if tool_calls else "stop"
            return LLMResponse(
                content=msg.content or None,
                tool_calls=tool_calls,
                finish_reason=finish_reason,
                # getattr: the attribute may be absent on the response object.
                thinking=getattr(msg, "thinking", None) or None,
                prompt_tokens=getattr(response, "prompt_eval_count", None) or None,
                completion_tokens=getattr(response, "eval_count", None) or None,
            )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a plain (tool-less) chat response chunk by chunk.

        Yields one ``LLMChunk`` per Ollama chunk. ``finish_reason`` is
        ``"stop"`` and token counts are populated only on the chunk marked
        ``done``; on all other chunks they are ``None``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            classified by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            async for chunk in await self._client.chat(
                model=resolved,
                messages=_to_ollama_messages(messages),
                options=_base_options(temperature),
                stream=True,
            ):
                thinking = getattr(chunk.message, "thinking", None) or None
                delta = chunk.message.content or None
                finish_reason = "stop" if chunk.done else None
                yield LLMChunk(
                    delta=delta,
                    thinking=thinking,
                    finish_reason=finish_reason,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat response that may include tool calls.

        Yields one ``LLMChunk`` per Ollama chunk, carrying that chunk's text
        delta, thinking text and tool calls (each ``None`` when absent). On
        the chunk marked ``done``, ``finish_reason`` is ``"tool_calls"`` if
        any chunk of this stream carried tool calls, otherwise ``"stop"``;
        token counts are populated on that chunk only.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            classified by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(temperature, think=think),
                "stream": True,
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)

            # Tool calls may be delivered in a chunk that precedes the final
            # ``done`` chunk, so remember them across the whole stream instead
            # of inspecting only the last chunk. This keeps ``finish_reason``
            # consistent with ``complete()``.
            saw_tool_calls = False

            async for chunk in await self._client.chat(**kwargs):
                thinking = getattr(chunk.message, "thinking", None) or None
                delta = chunk.message.content or None

                tool_calls = self._parse_tool_calls(chunk.message.tool_calls)
                if tool_calls:
                    saw_tool_calls = True

                finish_reason = None
                if chunk.done:
                    finish_reason = "tool_calls" if saw_tool_calls else "stop"

                yield LLMChunk(
                    delta=delta,
                    thinking=thinking,
                    finish_reason=finish_reason,
                    tool_calls=tool_calls,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e