# navi-1 / navi / llm / ollama.py
"""Ollama LLM backend."""

import uuid
from typing import AsyncGenerator

import ollama as ollama_client

from navi.exceptions import LLMBackendError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema


def _to_ollama_messages(messages: list[Message]) -> list[dict]:
    """Convert messages into the list of dicts passed to the Ollama chat call.

    Every output dict has ``role`` and ``content`` keys; falsy content
    (``None`` or ``""``) is sent as an empty string. A ``tool_calls`` key is
    added only for messages that carry tool calls, with each call reduced to
    ``{"function": {"name": ..., "arguments": ...}}``.
    """
    payloads: list[dict] = []
    for message in messages:
        payload: dict = {"role": message.role, "content": message.content or ""}
        # NOTE: ``tool_call_id`` is not forwarded; a tool result goes out as a
        # plain role/content message.
        if not message.tool_calls:
            payloads.append(payload)
            continue

        function_calls = []
        for call in message.tool_calls:
            function_calls.append(
                {"function": {"name": call.name, "arguments": call.arguments}}
            )
        payload["tool_calls"] = function_calls
        payloads.append(payload)
    return payloads


def _to_ollama_tools(tools: list[ToolSchema]) -> list[dict]:
    """Serialize each tool schema with ``model_dump()``, preserving order."""
    dumped: list[dict] = []
    for schema in tools:
        dumped.append(schema.model_dump())
    return dumped


class OllamaBackend(LLMBackend):
    """LLM backend that sends chat requests through ``ollama.AsyncClient``."""

    def __init__(self, model: str, host: str = "http://localhost:11434"):
        """Store the model name and create an async client for *host*."""
        self.model = model
        self._client = ollama_client.AsyncClient(host=host)

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
    ) -> LLMResponse:
        """Run one non-streaming chat request and return the full reply.

        Args:
            messages: Conversation to send, converted with
                ``_to_ollama_messages``.
            tools: Optional tool schemas; included in the request only when
                the list is non-empty.
            temperature: Sampling temperature, passed in the request
                ``options``.

        Returns:
            An ``LLMResponse`` whose ``finish_reason`` is ``"tool_calls"``
            when the reply contains tool calls and ``"stop"`` otherwise.
            Empty reply content is reported as ``None``.

        Raises:
            LLMBackendError: Wraps any ``Exception`` raised while building
                the request, calling the client, or converting the reply.
        """
        try:
            request: dict = dict(
                model=self.model,
                messages=_to_ollama_messages(messages),
                options={"temperature": temperature},
                stream=False,
            )
            if tools:
                request["tools"] = _to_ollama_tools(tools)

            reply = (await self._client.chat(**request)).message

            requested = None
            if reply.tool_calls:
                requested = []
                for call in reply.tool_calls:
                    # The reply's tool calls are given a freshly generated
                    # UUID as their id.
                    requested.append(
                        ToolCallRequest(
                            id=str(uuid.uuid4()),
                            name=call.function.name,
                            arguments=dict(call.function.arguments),
                        )
                    )

            return LLMResponse(
                content=reply.content or None,
                tool_calls=requested,
                finish_reason="tool_calls" if requested else "stop",
            )
        except Exception as e:
            raise LLMBackendError(str(e)) from e

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat reply as ``LLMChunk`` objects.

        Each chunk carries the new text in ``delta`` (``None`` when the chunk
        has no content). ``finish_reason`` is ``"stop"`` on a chunk flagged
        ``done`` and ``None`` on every other chunk.

        Raises:
            LLMBackendError: Wraps any ``Exception`` raised while starting or
                consuming the stream.
        """
        try:
            # With stream=True the awaited call yields an async iterator.
            parts = await self._client.chat(
                model=self.model,
                messages=_to_ollama_messages(messages),
                options={"temperature": temperature},
                stream=True,
            )
            async for part in parts:
                text = part.message.content or None
                yield LLMChunk(
                    delta=text,
                    finish_reason="stop" if part.done else None,
                )
        except Exception as e:
            raise LLMBackendError(str(e)) from e