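"""FallbackOllamaBackend — tries multiple Ollama servers with per-server model fallback.

Servers and models are independent priority lists.

Algorithm (server-first):
    For each server in priority order (skipping blacklisted servers):
        For each model in priority order (skipping blacklisted (server, model) pairs):
            Try the request.
            LLMConnectionError    → blacklist the server, break to the next server.
            LLMModelNotFoundError → blacklist (server, model), continue to the next model.
            Success               → use this result.
    If all combinations are exhausted → raise LLMBackendError.

Streaming requests additionally retry the same server+model once (after a
2-second pause) on a first-chunk timeout or connection error before falling back.

Blacklists live in module-level dicts that map a host or a (host, model) pair to
the time it was blacklisted. Entries expire after _TTL seconds (5 minutes), can
be wiped early with clear_blacklists(), and are lost on process restart. When
only one server is configured nothing is blacklisted, so every request retries it.
"""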
import asyncio
import json
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, AsyncGenerator, Awaitable, Callable

import structlog

from navi.config import settings
from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolSchema
from .ollama import OllamaBackend
# Module-level structlog logger; every event emitted in this module uses a
# "fallback."-prefixed event name.
log = structlog.get_logger()
@dataclass
class ServerEntry:
    """One Ollama server in the fallback priority list.

    Attributes:
        host: Server address; passed to OllamaBackend as ``host`` and used as
            the key for client caching and server blacklisting.
        api_key: Credential passed to OllamaBackend as ``api_key``; defaults
            to an empty string.
    """

    host: str
    api_key: str = ""
# Module-level blacklists shared by every FallbackOllamaBackend in this process.
# Each value is the time.monotonic() timestamp at which the entry was added; an
# entry older than _TTL seconds is treated as expired and purged on its next
# lookup. Both dicts are lost on process restart and can be emptied early with
# clear_blacklists().
_dead_servers: dict[str, float] = {}  # host -> time.monotonic() when blacklisted
_dead_models: dict[tuple[str, str], float] = {}  # (host, model_name) -> time.monotonic() when blacklisted
_TTL = 300  # blacklist entry lifetime in seconds (5 minutes)
def _is_dead_server(host: str) -> bool:
    """Report whether *host* is still inside its blacklist window.

    A stale entry (older than ``_TTL`` seconds) is deleted as a side effect,
    and the host then counts as available again.
    """
    marked_at = _dead_servers.get(host)
    if marked_at is not None and time.monotonic() - marked_at > _TTL:
        del _dead_servers[host]
        marked_at = None
    return marked_at is not None
def _is_dead_model(host: str, model: str) -> bool:
    """Report whether the (*host*, *model*) pair is still blacklisted.

    A stale entry (older than ``_TTL`` seconds) is deleted as a side effect,
    and the pair then counts as available again.
    """
    pair = (host, model)
    marked_at = _dead_models.get(pair)
    if marked_at is not None and time.monotonic() - marked_at > _TTL:
        del _dead_models[pair]
        marked_at = None
    return marked_at is not None
def clear_blacklists() -> None:
    """Drop every blacklisted server and (server, model) pair right away.

    Entries are removed regardless of how much of their TTL remains.
    """
    for blacklist in (_dead_servers, _dead_models):
        blacklist.clear()
    log.info("fallback.blacklists_cleared")
def load_servers_from_file(path: str) -> list[ServerEntry]:
    """Load the server priority list from a JSON file.

    The file must contain a JSON array of objects shaped like
    ``{"host": "...", "api_key": "..."}``; ``api_key`` is optional.

    Args:
        path: Path of the JSON file to read (decoded as UTF-8).

    Returns:
        The servers in file order. An empty list is returned when the file is
        missing, is not valid UTF-8 JSON, or its top level is not an array.
        Array entries that are not objects, or lack a non-empty string
        ``host``, are skipped with a warning. A missing or ``null``
        ``api_key`` becomes ``""``.
    """
    try:
        data = json.loads(Path(path).read_text(encoding="utf-8"))
    except FileNotFoundError:
        log.warning("fallback.servers_file_missing", path=path)
        return []
    except (json.JSONDecodeError, UnicodeDecodeError):
        # A file that is not valid UTF-8 is just as unusable as malformed JSON.
        log.warning("fallback.servers_file_bad_json", path=path)
        return []

    if not isinstance(data, list):
        # e.g. a top-level object: iterating it would walk its keys (strings)
        # and crash on ``entry.get``.
        log.warning("fallback.servers_file_not_a_list", path=path, got=type(data).__name__)
        return []

    servers: list[ServerEntry] = []
    for i, entry in enumerate(data):
        if not isinstance(entry, dict):
            log.warning("fallback.servers_file_bad_entry", index=i, entry=entry)
            continue
        host = entry.get("host")
        if not host or not isinstance(host, str):
            log.warning("fallback.servers_file_missing_host", index=i, entry=entry)
            continue
        # Treat an explicit ``"api_key": null`` the same as an absent key.
        servers.append(ServerEntry(host=host, api_key=entry.get("api_key") or ""))
    return servers
class FallbackOllamaBackend(LLMBackend):
    """Ollama backend with automatic server and model fallback.

    Servers are tried in priority order and, on each server, models in
    priority order. A connection failure moves on to the next server; a
    missing model moves on to the next model on the same server. When more
    than one server is configured, failures are also recorded in the
    module-level blacklists so later requests skip them for ``_TTL`` seconds.
    """

    def __init__(self, servers: list[ServerEntry]) -> None:
        self._servers = servers
        # Cache OllamaBackend instances by host to reuse AsyncClient
        self._clients: dict[str, OllamaBackend] = {}

    def _get_client(self, server: ServerEntry) -> OllamaBackend:
        """Return the cached OllamaBackend for ``server.host``, creating it on first use.

        The HTTP timeout is the largest of the configured request, complete and
        first-chunk timeouts (presumably so the HTTP layer never fires before
        the higher-level timeouts — confirm).
        """
        from navi.config import settings

        if server.host not in self._clients:
            ollama_http_timeout = max(
                settings.ollama_request_timeout,
                settings.llm_complete_timeout,
                settings.llm_stream_first_chunk_timeout,
            )
            self._clients[server.host] = OllamaBackend(
                model="", host=server.host, api_key=server.api_key,
                timeout=ollama_http_timeout,
            )
        return self._clients[server.host]

    @staticmethod
    def _model_list(model: "list[str] | str | None") -> list[str]:
        """Normalize *model* into a non-empty priority list of model names.

        ``None``, ``""`` and an empty list/tuple all become ``[""]``: a single
        attempt with the empty model name, passed through to OllamaBackend
        unchanged (presumably meaning "server default" — confirm).
        """
        if isinstance(model, (list, tuple)):
            return list(model) or [""]
        return [model] if model else [""]

    def _nothing_tried_error(self) -> LLMBackendError:
        """Build the error reported when no (server, model) pair was attempted.

        This distinguishes "no servers configured" from "every configured
        server or model is currently blacklisted".
        """
        if not self._servers:
            return LLMBackendError("No backends configured")
        return LLMBackendError("every configured server/model is currently blacklisted")

    async def _run_with_fallback(
        self,
        models: list[str],
        call: Callable[[OllamaBackend, str], Awaitable[Any]],
    ) -> Any:
        """Await ``call(client, model)`` over servers × models until one succeeds.

        Shared by complete() and embed().

        Raises:
            LLMBackendError: every combination failed or was skipped; chained
                from the last underlying error.
        """
        last_err: Exception = self._nothing_tried_error()
        single_server = len(self._servers) <= 1
        for server in self._servers:
            if _is_dead_server(server.host):
                continue
            for m in models:
                if _is_dead_model(server.host, m):
                    continue
                try:
                    return await call(self._get_client(server), m)
                except LLMConnectionError as e:
                    log.warning("fallback.server_dead", host=server.host, error=str(e))
                    last_err = e
                    # Do not blacklist the only server — the next request
                    # should retry immediately instead of being blocked
                    # for _TTL seconds.
                    if not single_server:
                        _dead_servers[server.host] = time.monotonic()
                    break  # Skip remaining models — server is gone
                except LLMModelNotFoundError as e:
                    log.warning("fallback.model_dead", host=server.host, model=m, error=str(e))
                    if not single_server:
                        _dead_models[(server.host, m)] = time.monotonic()
                    last_err = e
                    # Continue to next model on the same server
        raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> LLMResponse:
        """Run a non-streaming completion with server/model fallback.

        *model* may be one name or a priority list; every other argument is
        forwarded unchanged to OllamaBackend.complete.

        Raises:
            LLMBackendError: every server/model combination failed or was
                skipped as blacklisted.
        """
        return await self._run_with_fallback(
            self._model_list(model),
            lambda client, m: client.complete(
                messages, tools=tools, temperature=temperature,
                model=m, think=think, max_tokens=max_tokens,
                top_k=top_k, top_p=top_p, num_thread=num_thread,
            ),
        )

    async def embed(
        self,
        texts: list[str],
        model: "list[str] | str | None" = None,
    ) -> list[list[float]]:
        """Embed *texts* with server/model fallback.

        Raises:
            LLMBackendError: every server/model combination failed or was
                skipped as blacklisted.
        """
        return await self._run_with_fallback(
            self._model_list(model),
            lambda client, m: client.embed(texts, model=m),
        )

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a chat completion, falling back across servers and models.

        Fallback is only possible until the first chunk arrives. Each
        (server, model) pair gets up to two attempts, 2 s apart, to produce
        its first chunk within ``settings.llm_stream_first_chunk_timeout``.
        A pair that fails both attempts with a connection error or timeout
        makes the whole server count as failed, and its remaining models are
        skipped. After the first chunk, chunks are relayed as-is and any
        mid-stream error propagates to the caller.

        Yields:
            LLMChunk objects from the first pair that started streaming.

        Raises:
            LLMBackendError: no pair produced a first chunk; chained from the
                last underlying error.
        """
        models = self._model_list(model)
        last_err: Exception = self._nothing_tried_error()
        single_server = len(self._servers) <= 1
        first_chunk_timeout = settings.llm_stream_first_chunk_timeout
        for server in self._servers:
            if _is_dead_server(server.host):
                continue
            server_failed = False
            for m in models:
                if _is_dead_model(server.host, m):
                    continue
                # Try up to 2 times on the same server+model before falling back
                for attempt in range(2):
                    try:
                        gen = self._get_client(server).stream_complete(
                            messages, tools=tools, temperature=temperature, model=m, think=think,
                            top_k=top_k, top_p=top_p, num_thread=num_thread,
                        )
                        first = await asyncio.wait_for(gen.__anext__(), timeout=first_chunk_timeout)
                    except StopAsyncIteration:
                        last_err = LLMModelNotFoundError("Empty stream from server")
                        break  # next model
                    except asyncio.TimeoutError:
                        log.warning(
                            "fallback.first_chunk_timeout",
                            host=server.host, model=m, attempt=attempt + 1,
                        )
                        last_err = LLMConnectionError(f"First-chunk timeout after {first_chunk_timeout}s")
                    except LLMConnectionError as e:
                        log.warning("fallback.server_dead", host=server.host, error=str(e))
                        last_err = e
                    except LLMModelNotFoundError as e:
                        log.warning("fallback.model_dead", host=server.host, model=m, error=str(e))
                        if not single_server:
                            _dead_models[(server.host, m)] = time.monotonic()
                        last_err = e
                        break  # next model
                    else:
                        # Streaming has started: no fallback from here on.
                        try:
                            yield first
                            async for chunk in gen:
                                yield chunk
                        finally:
                            # Close the inner stream even if our consumer stops
                            # early, instead of leaving it for GC to finalize.
                            await gen.aclose()
                        return
                    # Connection error or first-chunk timeout: retry once, then
                    # treat the whole server as failed.
                    if attempt == 0:
                        await asyncio.sleep(2.0)
                        continue
                    if not single_server:
                        _dead_servers[server.host] = time.monotonic()
                    server_failed = True
                if server_failed:
                    break  # Skip remaining models — server is gone
        raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err