"""Ollama LLM backend."""
import uuid
from typing import AsyncGenerator
import ollama as ollama_client
from navi.config import settings
from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError
from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema
def _clean_base64_image(img: str) -> str | None:
"""Strip data URI prefix and validate that result is non-empty base64."""
if not img:
return None
s = img.strip()
if s.startswith("data:"):
if "," in s:
s = s.split(",", 1)[1]
else:
return None
return s if s else None
def _to_ollama_messages(messages: list[Message]) -> list[dict]:
    """Convert internal ``Message`` objects into Ollama chat-message dicts.

    Every output dict has ``role`` and ``content`` (a falsy content becomes
    ``""``). Images are passed through ``_clean_base64_image`` and only the
    non-``None`` results are kept; the ``images`` key is present only when at
    least one image survives. Tool calls are reshaped into Ollama's
    ``{"function": {"name": ..., "arguments": ...}}`` form.
    """
    converted: list[dict] = []
    for message in messages:
        payload: dict = {"role": message.role, "content": message.content or ""}

        if message.images:
            valid_images = []
            for raw in message.images:
                cleaned = _clean_base64_image(raw)
                if cleaned is not None:
                    valid_images.append(cleaned)
            if valid_images:
                payload["images"] = valid_images

        if message.tool_calls:
            payload["tool_calls"] = [
                {"function": {"name": call.name, "arguments": call.arguments}}
                for call in message.tool_calls
            ]

        converted.append(payload)
    return converted
def _to_ollama_tools(tools: list[ToolSchema]) -> list[dict]:
    """Serialise tool schemas for the Ollama ``tools`` request field.

    Each schema is converted with its own ``model_dump()``; input order is kept.
    """
    return [schema.model_dump() for schema in tools]
def _base_options(
    temperature: float,
    max_tokens: int | None = None,
    top_k: int | None = None,
    top_p: float | None = None,
    num_thread: int | None = None,
) -> dict:
    """Build the Ollama ``options`` dict for a chat request.

    ``temperature`` and ``num_ctx`` (taken from ``settings.ollama_num_ctx``)
    are always present. Each optional knob is added only when the caller
    passed a non-``None`` value; ``max_tokens`` is sent under Ollama's
    ``num_predict`` key.
    """
    optional = (
        ("num_predict", max_tokens),
        ("top_k", top_k),
        ("top_p", top_p),
        ("num_thread", num_thread),
    )
    opts: dict = {"temperature": temperature, "num_ctx": settings.ollama_num_ctx}
    opts.update({key: value for key, value in optional if value is not None})
    return opts
def _resolve_think(think: bool | None) -> bool | None:
    """Decide the ``think`` flag sent to Ollama.

    An explicit ``True``/``False`` from the caller always wins, so ``False``
    turns thinking off even when the global setting enables it. ``None``
    defers to ``settings.ollama_think``.
    """
    if think is not None:
        return think
    return settings.ollama_think
def _resolve_model(model: "list[str] | str | None", default: str) -> str:
    """Normalise the ``model`` argument accepted by the backend methods.

    - list: only its first element is used; remaining entries are ignored.
    - str: used as given.
    - ``None`` / empty: ``default`` is returned.

    An empty list, or a list whose first element is empty/``None``, also
    falls back to ``default`` — the same treatment an empty string gets, so
    an empty model name is never sent to the server.
    """
    if isinstance(model, list):
        model = model[0] if model else None
    return model or default
def _classify_error(e: Exception) -> Exception:
    """Map a raw Ollama/network exception onto the typed LLM exceptions.

    The returned exception is meant to be raised by the caller
    (``raise _classify_error(e) from e``).

    Mapping:
    - ``ollama.ResponseError``: ``LLMModelNotFoundError`` when the server
      reports a missing model (HTTP 404, or "not found" / "does not exist" in
      the error text), otherwise ``LLMBackendError``.
    - ``ollama.RequestError``: ``LLMBackendError``. ollama-python raises this
      for requests it rejects before sending (e.g. no model given), which is a
      caller-side problem rather than a connectivity failure.
      NOTE(review): this was previously reported as ``LLMConnectionError``;
      confirm no caller relied on that.
    - builtin ``ConnectionError`` / ``TimeoutError``, exception types whose
      name contains "timeout" or "connect", or messages mentioning common
      connectivity failures: ``LLMConnectionError``.
    - anything else: ``LLMBackendError``.

    The message falls back to the exception type name so the typed error is
    never blank.
    """
    detail = str(e) or type(e).__name__

    if isinstance(e, ollama_client.ResponseError):
        # Guard against a missing/None ``error`` so classification itself never crashes.
        msg = (getattr(e, "error", None) or "").lower()
        if "not found" in msg or "does not exist" in msg or e.status_code == 404:
            return LLMModelNotFoundError(detail)
        return LLMBackendError(detail)

    if isinstance(e, ollama_client.RequestError):
        return LLMBackendError(detail)

    # ollama-python re-raises connection failures as the builtin ConnectionError.
    if isinstance(e, (ConnectionError, TimeoutError)):
        return LLMConnectionError(detail)

    # Catch httpx / socket failures that surface as other types: by type name
    # first (e.g. ReadTimeout, ConnectError), then by message.
    err_type = type(e).__name__.lower()
    if "timeout" in err_type or "connect" in err_type:
        return LLMConnectionError(detail)

    err_str = str(e).lower()
    connection_keywords = (
        "connect",
        "connection refused",
        "name or service not known",
        "network",
        "timeout",
        "unreachable",
        "nodename",
    )
    if any(kw in err_str for kw in connection_keywords):
        return LLMConnectionError(detail)

    return LLMBackendError(detail)
class OllamaBackend(LLMBackend):
    """LLM backend that talks to an Ollama server through ``ollama.AsyncClient``.

    Every public method resolves the target model with ``_resolve_model``
    (falling back to the model given at construction) and re-raises
    client/network failures as the typed exceptions built by
    ``_classify_error``.
    """

    def __init__(
        self,
        model: str,
        host: str = "http://localhost:11434",
        api_key: str = "",
        timeout: int = 30,
    ):
        """Create the backend and its async client.

        Args:
            model: Default model name, used when a call does not pass ``model``.
            host: Base URL of the Ollama server.
            api_key: Optional token; when non-empty it is sent as an
                ``Authorization: Bearer <api_key>`` header.
            timeout: Forwarded to ``ollama.AsyncClient`` (presumably seconds,
                httpx-style — confirm against the client docs).
        """
        self.model = model
        self._host = host
        headers = {"Authorization": f"Bearer {api_key}"} if api_key else None
        self._client = ollama_client.AsyncClient(host=host, headers=headers, timeout=timeout)

    @staticmethod
    def _chat_kwargs(
        resolved: str,
        messages: list[Message],
        tools: list[ToolSchema] | None,
        options: dict,
        think: bool | None,
        stream: bool,
    ) -> dict:
        """Assemble keyword arguments for ``AsyncClient.chat``; ``tools`` only when non-empty."""
        kwargs: dict = {
            "model": resolved,
            "messages": _to_ollama_messages(messages),
            "options": options,
            "stream": stream,
            "think": _resolve_think(think),
        }
        if tools:
            kwargs["tools"] = _to_ollama_tools(tools)
        return kwargs

    @staticmethod
    def _convert_tool_calls(raw_calls) -> "list[ToolCallRequest] | None":
        """Convert Ollama tool calls into ``ToolCallRequest``s, or ``None`` if there are none.

        Ids are generated locally as random UUIDs; no id is read from the response.
        """
        if not raw_calls:
            return None
        return [
            ToolCallRequest(
                id=str(uuid.uuid4()),
                name=tc.function.name,
                arguments=dict(tc.function.arguments),
            )
            for tc in raw_calls
        ]

    @staticmethod
    def _log_chat_error(
        resolved: str,
        messages: list[Message],
        tools: list[ToolSchema] | None,
        error: Exception,
    ) -> None:
        """Emit the ``llm.ollama.chat_error`` warning shared by both chat methods."""
        import structlog  # imported lazily; only needed on the error path

        structlog.get_logger().warning(
            "llm.ollama.chat_error",
            model=resolved,
            message_count=len(messages),
            tools_count=len(tools) if tools else 0,
            error=str(error),
        )

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> LLMResponse:
        """Run a single non-streaming chat completion.

        Args:
            messages: Conversation to send.
            tools: Tool schemas the model may call; omitted from the request when empty.
            temperature: Sampling temperature.
            model: Model override (see ``_resolve_model``); defaults to ``self.model``.
            think: Per-call thinking override; ``None`` defers to ``settings.ollama_think``.
            max_tokens: Sent as Ollama's ``num_predict`` when given.
            top_k: Optional Ollama option, sent only when given.
            top_p: Optional Ollama option, sent only when given.
            num_thread: Optional Ollama option, sent only when given.

        Returns:
            ``LLMResponse`` with content, tool calls, optional thinking text and
            token counts. ``finish_reason`` is ``"tool_calls"`` when any tool
            calls were returned, else ``"stop"``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            mapped by ``_classify_error``; untyped failures are also logged.
        """
        resolved = _resolve_model(model, self.model)
        try:
            options = _base_options(
                temperature,
                max_tokens=max_tokens,
                top_k=top_k,
                top_p=top_p,
                num_thread=num_thread,
            )
            kwargs = self._chat_kwargs(resolved, messages, tools, options, think, stream=False)
            response = await self._client.chat(**kwargs)
            msg = response.message
            tool_calls = self._convert_tool_calls(msg.tool_calls)
            return LLMResponse(
                content=msg.content or None,
                tool_calls=tool_calls,
                finish_reason="tool_calls" if tool_calls else "stop",
                thinking=getattr(msg, "thinking", None) or None,
                prompt_tokens=getattr(response, "prompt_eval_count", None),
                completion_tokens=getattr(response, "eval_count", None),
            )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            # Already typed: propagate unchanged instead of re-wrapping.
            raise
        except Exception as e:
            self._log_chat_error(resolved, messages, tools, e)
            raise _classify_error(e) from e

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
        max_tokens: int | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat completion as ``LLMChunk``s.

        Parameters mirror ``complete``. ``max_tokens`` is accepted here too; it
        is the trailing parameter so existing positional callers are unaffected.

        ``finish_reason`` and the token counts are set only on the final
        (``done``) chunk. ``finish_reason`` is ``"tool_calls"`` if ANY chunk in
        the stream carried tool calls. Ollama can deliver tool calls in a chunk
        before the final one, so inspecting only the final chunk would
        misreport ``"stop"``.

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            mapped by ``_classify_error``; untyped failures are also logged.
        """
        resolved = _resolve_model(model, self.model)
        try:
            options = _base_options(
                temperature,
                max_tokens=max_tokens,
                top_k=top_k,
                top_p=top_p,
                num_thread=num_thread,
            )
            kwargs = self._chat_kwargs(resolved, messages, tools, options, think, stream=True)
            saw_tool_calls = False
            async for chunk in await self._client.chat(**kwargs):
                message = chunk.message
                tool_calls = self._convert_tool_calls(message.tool_calls)
                if tool_calls:
                    saw_tool_calls = True
                finish_reason = None
                if chunk.done:
                    finish_reason = "tool_calls" if saw_tool_calls else "stop"
                yield LLMChunk(
                    delta=message.content or None,
                    thinking=getattr(message, "thinking", None) or None,
                    finish_reason=finish_reason,
                    tool_calls=tool_calls,
                    prompt_tokens=chunk.prompt_eval_count if chunk.done else None,
                    completion_tokens=chunk.eval_count if chunk.done else None,
                )
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            # Already typed: propagate unchanged instead of re-wrapping.
            raise
        except Exception as e:
            self._log_chat_error(resolved, messages, tools, e)
            raise _classify_error(e) from e

    async def embed(
        self,
        texts: list[str],
        model: "list[str] | str | None" = None,
    ) -> list[list[float]]:
        """Embed ``texts`` with the resolved model.

        Returns the client's ``response.embeddings``, presumably one vector per
        input text, in order (confirm against the Ollama embed API).

        Raises:
            LLMConnectionError, LLMModelNotFoundError, LLMBackendError: as
            mapped by ``_classify_error`` (not logged, unlike the chat methods).
        """
        resolved = _resolve_model(model, self.model)
        try:
            response = await self._client.embed(model=resolved, input=texts)
            return response.embeddings
        except (LLMConnectionError, LLMModelNotFoundError, LLMBackendError):
            raise
        except Exception as e:
            raise _classify_error(e) from e