"""Ollama LLM backend."""
import uuid
from typing import AsyncGenerator
import ollama as ollama_client
from navi.config import settings
from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError
from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema
def _to_ollama_messages(messages: list[Message]) -> list[dict]:
    """Translate internal ``Message`` objects into Ollama chat message dicts.

    Each output dict always has ``role`` and ``content`` (falsy content is
    sent as ``""``). ``images`` and ``tool_calls`` are only included when the
    source message has a truthy value for them.

    Args:
        messages: Conversation messages, in order.

    Returns:
        One dict per input message, in the same order.
    """

    def _convert(message) -> dict:
        payload: dict = {"role": message.role, "content": message.content or ""}

        attachments = message.images
        if attachments:
            # Forwarded untouched: a list of base64 strings (Ollama format).
            payload["images"] = attachments

        requested_calls = message.tool_calls
        if requested_calls:
            payload["tool_calls"] = [
                {"function": {"name": call.name, "arguments": call.arguments}}
                for call in requested_calls
            ]
        return payload

    return [_convert(message) for message in messages]
def _to_ollama_tools(tools: list[ToolSchema]) -> list[dict]:
    """Serialize tool schemas into plain dicts.

    Args:
        tools: Tool schemas; each must provide ``model_dump()``.

    Returns:
        The ``model_dump()`` result of every schema, in input order.
    """
    dumped: list[dict] = []
    for schema in tools:
        dumped.append(schema.model_dump())
    return dumped
def _base_options(
    temperature: float,
    think: bool | None = None,
    max_tokens: int | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
) -> dict:
    """Assemble the ``options`` dict passed to the Ollama chat call.

    ``temperature`` and ``num_ctx`` (from ``settings.ollama_num_ctx``) are
    always present. Every other key is added only when it applies.

    Args:
        temperature: Sampling temperature.
        think: ``None`` defers to ``settings.ollama_think``; an explicit
            ``True``/``False`` overrides the global setting.
        max_tokens: Emitted as ``num_predict`` when not ``None``.
        top_k: Emitted as ``top_k`` when not ``None``.
        top_p: Emitted as ``top_p`` when not ``None``.

    Returns:
        A new options dict.
    """
    options: dict = {"temperature": temperature, "num_ctx": settings.ollama_num_ctx}

    # An explicit caller choice wins; only None falls back to the global flag.
    think_enabled = think if think is not None else settings.ollama_think
    if think_enabled:
        # NOTE(review): Ollama's chat API documents `think` as a top-level
        # request field rather than an `options` entry - confirm the server
        # actually honours it here, and that omitting it disables thinking.
        options["think"] = True

    optional_overrides = (
        ("num_predict", max_tokens),
        ("top_k", top_k),
        ("top_p", top_p),
    )
    options.update(
        {key: value for key, value in optional_overrides if value is not None}
    )
    return options
def _resolve_model(model: "list[str] | str | None", default: str) -> str:
    """Normalize the ``model`` argument to a single model name.

    Args:
        model: A model name, a list (or tuple) of names of which only the
            first is used, or ``None``.
        default: Name returned when ``model`` yields nothing usable.

    Returns:
        The first element of a non-empty list/tuple, or ``model`` itself when
        it is a non-empty string. ``default`` is returned for ``None``, an
        empty string, an empty sequence, or a sequence whose first element is
        empty/``None``.
    """
    if isinstance(model, (list, tuple)):
        # Only the first entry is honoured; treat a blank first entry the same
        # way a blank string is treated below instead of returning it.
        candidate = model[0] if model else None
    else:
        candidate = model
    return candidate or default
def _classify_error(e: Exception) -> Exception:
    """Wrap raw Ollama/network exceptions into typed LLM exceptions.

    The new exception is returned, not raised, so the caller can chain it
    with ``raise ... from e``.

    Classification order:
      1. ``ollama.RequestError``  -> ``LLMConnectionError``
      2. ``ollama.ResponseError`` -> ``LLMModelNotFoundError`` for a 404 status
         or a "not found"/"does not exist" message, else ``LLMBackendError``
      3. built-in ``ConnectionError``/``TimeoutError`` -> ``LLMConnectionError``
      4. any other exception whose class name or message contains a
         connection-related keyword -> ``LLMConnectionError``
      5. everything else -> ``LLMBackendError``

    Args:
        e: The exception caught around an Ollama client call.

    Returns:
        A typed LLM exception describing ``e``.
    """
    if isinstance(e, ollama_client.RequestError):
        return LLMConnectionError(str(e))

    if isinstance(e, ollama_client.ResponseError):
        msg = (e.error or "").lower()
        if e.status_code == 404 or "not found" in msg or "does not exist" in msg:
            return LLMModelNotFoundError(str(e))
        return LLMBackendError(str(e))

    # Some transport errors (timeouts in particular) carry an empty message;
    # fall back to repr() so the wrapped error still says what happened.
    detail = str(e) or repr(e)

    # Type-based check first: it does not depend on the message wording.
    if isinstance(e, (ConnectionError, TimeoutError)):
        return LLMConnectionError(detail)

    # Catch httpx / socket connection failures by class name or message. The
    # class name is included so errors with an empty message are still matched.
    haystack = f"{type(e).__name__} {e}".lower()
    connection_keywords = (
        "connect",
        "connection refused",
        "name or service not known",
        "network",
        "timeout",
        "unreachable",
        "nodename",
    )
    if any(kw in haystack for kw in connection_keywords):
        return LLMConnectionError(detail)
    return LLMBackendError(detail)
class OllamaBackend(LLMBackend):
    """``LLMBackend`` implementation backed by ``ollama.AsyncClient``.

    Exceptions raised while talking to the client are converted with
    ``_classify_error``. Exceptions that are already ``LLMConnectionError``,
    ``LLMModelNotFoundError`` or ``LLMBackendError`` are re-raised unchanged.
    """

    def __init__(self, model: str, host: str = "http://localhost:11434", api_key: str = ""):
        """Create the async Ollama client.

        Args:
            model: Default model name, used when a call does not supply one.
            host: Base URL handed to ``ollama.AsyncClient``.
            api_key: When non-empty, sent as an ``Authorization: Bearer`` header.
        """
        self.model = model
        self._host = host
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        self._client = ollama_client.AsyncClient(host=host, headers=headers)

    @staticmethod
    def _convert_tool_calls(raw_tool_calls) -> list[ToolCallRequest] | None:
        """Convert Ollama tool-call objects into ``ToolCallRequest`` items.

        Returns ``None`` when there are no tool calls. Each request gets a
        freshly generated UUID as its ``id``.
        """
        if not raw_tool_calls:
            return None
        return [
            ToolCallRequest(
                id=str(uuid.uuid4()),
                name=tc.function.name,
                arguments=dict(tc.function.arguments),
            )
            for tc in raw_tool_calls
        ]

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> LLMResponse:
        """Run a single non-streaming chat request.

        Args:
            messages: Conversation to send.
            tools: Optional tool schemas; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Per-call model override (see ``_resolve_model``).
            think: Per-call thinking override (see ``_base_options``).
            max_tokens: Optional generation limit.
            top_k: Optional top-k sampling value.
            top_p: Optional top-p sampling value.

        Returns:
            An ``LLMResponse`` whose ``finish_reason`` is ``"tool_calls"`` when
            the reply contains tool calls and ``"stop"`` otherwise. Empty
            content, thinking and token counts are reported as ``None``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
                decided by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(
                    temperature, think=think, max_tokens=max_tokens, top_k=top_k, top_p=top_p
                ),
                "stream": False,
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)
            response = await self._client.chat(**kwargs)
            msg = response.message
            tool_calls = self._convert_tool_calls(msg.tool_calls)
            return LLMResponse(
                content=msg.content or None,
                tool_calls=tool_calls,
                finish_reason="tool_calls" if tool_calls else "stop",
                # getattr: the attribute may be absent on the response object.
                thinking=getattr(msg, "thinking", None) or None,
                prompt_tokens=getattr(response, "prompt_eval_count", None) or None,
                completion_tokens=getattr(response, "eval_count", None) or None,
            )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat reply without tool support.

        Args:
            messages: Conversation to send.
            temperature: Sampling temperature.
            model: Per-call model override (see ``_resolve_model``).
            top_k: Optional top-k sampling value.
            top_p: Optional top-p sampling value.

        Yields:
            One ``LLMChunk`` per chunk received. ``finish_reason`` is ``"stop"``
            and the token counts are filled in only on the chunk marked
            ``done``; they are ``None`` on every other chunk.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
                decided by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            async for chunk in await self._client.chat(
                model=resolved,
                messages=_to_ollama_messages(messages),
                options=_base_options(temperature, top_k=top_k, top_p=top_p),
                stream=True,
            ):
                yield LLMChunk(
                    delta=chunk.message.content or None,
                    thinking=getattr(chunk.message, "thinking", None) or None,
                    finish_reason="stop" if chunk.done else None,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat reply, including any tool calls.

        Args:
            messages: Conversation to send.
            tools: Optional tool schemas; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Per-call model override (see ``_resolve_model``).
            think: Per-call thinking override (see ``_base_options``).
            top_k: Optional top-k sampling value.
            top_p: Optional top-p sampling value.

        Yields:
            One ``LLMChunk`` per chunk received. Tool calls are attached to the
            chunk they arrived on. On the chunk marked ``done``,
            ``finish_reason`` is ``"tool_calls"`` if any chunk of this stream
            carried tool calls and ``"stop"`` otherwise; token counts are also
            only set on that chunk.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
                decided by ``_classify_error``.
        """
        resolved = _resolve_model(model, self.model)
        try:
            kwargs: dict = {
                "model": resolved,
                "messages": _to_ollama_messages(messages),
                "options": _base_options(temperature, think=think, top_k=top_k, top_p=top_p),
                "stream": True,
            }
            if tools:
                kwargs["tools"] = _to_ollama_tools(tools)

            # Tool calls are not guaranteed to sit on the final chunk, so
            # remember whether any chunk carried them; otherwise the stream
            # would end with "stop" even though tools were requested.
            saw_tool_calls = False
            async for chunk in await self._client.chat(**kwargs):
                tool_calls = self._convert_tool_calls(chunk.message.tool_calls)
                if tool_calls:
                    saw_tool_calls = True
                finish_reason = None
                if chunk.done:
                    finish_reason = "tool_calls" if saw_tool_calls else "stop"
                yield LLMChunk(
                    delta=chunk.message.content or None,
                    thinking=getattr(chunk.message, "thinking", None) or None,
                    finish_reason=finish_reason,
                    tool_calls=tool_calls,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e