Newer
Older
navi-1 / navi / llm / fallback.py
"""FallbackOllamaBackend — tries multiple Ollama servers with per-server model fallback.

Servers and models are independent priority lists.

Algorithm (server-first):
  For each server in priority order (skip blacklisted servers):
    For each model in priority order (skip blacklisted model+server pairs):
      Try the request.
      LLMConnectionError  → blacklist the server, break to next server.
      LLMModelNotFoundError → blacklist (server, model), continue to next model.
      Success → use this result.
  If every combination has failed or is blacklisted → raise LLMBackendError.

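Blacklists live in module-level sets shared by all instances; they are only cleared when this
process restarts, so a blacklisted server or model is not retried before then.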
"""

import json
from contextlib import aclosing
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncGenerator, Callable, Iterator

import structlog

from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolSchema
from .ollama import OllamaBackend

log = structlog.get_logger()


@dataclass
class ServerEntry:
    """Connection details for a single Ollama server in the fallback list."""

    # Address handed to OllamaBackend as its ``host`` argument.
    host: str
    # Optional credential handed to OllamaBackend; defaults to the empty string.
    api_key: str = ""


# Process-wide failure blacklists, shared by every FallbackOllamaBackend instance.
# Nothing in this module removes entries, so they persist until the process exits.
# Hosts that raised LLMConnectionError:
_dead_servers: set[str] = set()
# (host, model_name) pairs that raised LLMModelNotFoundError:
_dead_models: set[tuple[str, str]] = set()


def load_servers_from_file(path: str) -> list[ServerEntry]:
    """Load server list from a JSON file: [{host, api_key?}, ...]"""
    data = json.loads(Path(path).read_text(encoding="utf-8"))
    return [ServerEntry(host=e["host"], api_key=e.get("api_key", "")) for e in data]


class FallbackOllamaBackend(LLMBackend):
    """Ollama backend with automatic server and model fallback.

    Every request walks (server, model) pairs in server-first priority order,
    skipping anything already in the module-level blacklists:

    * ``LLMConnectionError`` blacklists the whole server and moves to the next one.
    * ``LLMModelNotFoundError`` blacklists only that (server, model) pair and
      tries the next model on the same server.
    * Any other exception propagates to the caller unchanged.

    The blacklists are shared by every instance and are never cleared here.
    """

    def __init__(self, servers: list[ServerEntry]) -> None:
        self._servers = servers
        # Cache OllamaBackend instances by host to reuse AsyncClient
        self._clients: dict[str, OllamaBackend] = {}

    def _get_client(self, server: ServerEntry) -> OllamaBackend:
        """Return the cached OllamaBackend for ``server.host``, creating it on first use.

        The cache is keyed by host only, so the api_key of the first entry seen
        for a host is the one that client uses.
        """
        if server.host not in self._clients:
            # model="" because the concrete model is passed on every call below.
            self._clients[server.host] = OllamaBackend(
                model="", host=server.host, api_key=server.api_key
            )
        return self._clients[server.host]

    @staticmethod
    def _model_list(model: "list[str] | str | None") -> list[str]:
        """Normalise ``model`` to a non-empty list of model names.

        ``None``, ``""`` and ``[]`` all become ``[""]`` (one attempt with an
        empty model name).
        """
        if isinstance(model, list):
            return model if model else [""]
        return [model] if model else [""]

    def _candidates(self, models: list[str]) -> Iterator[tuple[ServerEntry, str]]:
        """Lazily yield (server, model) pairs in server-first priority order.

        The blacklists are re-checked before every pair. Because this is lazy, a
        server the caller blacklists while consuming its pairs is abandoned at
        once, and the next server is tried.
        """
        for server in self._servers:
            for m in models:
                if server.host in _dead_servers:
                    break
                if (server.host, m) in _dead_models:
                    continue
                yield server, m

    @staticmethod
    def _mark_server_dead(server: ServerEntry, err: Exception) -> None:
        """Log and blacklist ``server`` after a connection failure."""
        log.warning("fallback.server_dead", host=server.host, error=str(err))
        _dead_servers.add(server.host)

    @staticmethod
    def _mark_model_dead(server: ServerEntry, model: str, err: Exception) -> None:
        """Log and blacklist the (``server``, ``model``) pair after a model-not-found error."""
        log.warning("fallback.model_dead", host=server.host, model=model, error=str(err))
        _dead_models.add((server.host, model))

    def _exhausted(self, last_err: Exception | None) -> LLMBackendError:
        """Build the error raised once no (server, model) candidate is left to try."""
        if last_err is not None:
            reason = str(last_err)
        elif not self._servers:
            reason = "No backends configured"
        else:
            # Nothing was attempted: every candidate was blacklisted by an earlier
            # call, and nothing in this module clears the blacklists.
            reason = "every configured server/model pair is blacklisted"
        return LLMBackendError(f"All backends exhausted: {reason}")

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Return the first successful completion across the candidate servers/models.

        Raises:
            LLMBackendError: if every candidate failed or was already blacklisted.
        """
        last_err: Exception | None = None
        for server, m in self._candidates(self._model_list(model)):
            try:
                return await self._get_client(server).complete(
                    messages, tools=tools, temperature=temperature,
                    model=m, think=think, max_tokens=max_tokens,
                )
            except LLMConnectionError as e:
                self._mark_server_dead(server, e)
                last_err = e
            except LLMModelNotFoundError as e:
                self._mark_model_dead(server, m, e)
                last_err = e
        raise self._exhausted(last_err) from last_err

    async def _stream_with_fallback(
        self,
        models: list[str],
        open_stream: Callable[[OllamaBackend, str], AsyncGenerator[LLMChunk, None]],
    ) -> AsyncGenerator[LLMChunk, None]:
        """Yield chunks from the first candidate whose stream starts successfully.

        ``open_stream(client, model)`` starts one attempt. Its first chunk is
        pulled before anything is yielded, so connection / missing-model errors
        can still trigger fallback. Once that chunk is yielded, later errors
        propagate to the caller. An attempt that ends without producing any chunk
        ends the whole stream without trying further candidates.

        Raises:
            LLMBackendError: if every candidate failed or was already blacklisted.
        """
        last_err: Exception | None = None
        for server, m in self._candidates(models):
            try:
                gen = open_stream(self._get_client(server), m)
                first = await anext(gen)
            except StopAsyncIteration:
                return
            except LLMConnectionError as e:
                self._mark_server_dead(server, e)
                last_err = e
                continue
            except LLMModelNotFoundError as e:
                self._mark_model_dead(server, m, e)
                last_err = e
                continue
            # Close the underlying stream deterministically, even if our consumer
            # stops iterating early, instead of leaving it to GC finalization.
            async with aclosing(gen):
                yield first
                async for chunk in gen:
                    yield chunk
            return
        raise self._exhausted(last_err) from last_err

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a completion, falling back across servers/models before the first chunk.

        Raises:
            LLMBackendError: if every candidate failed or was already blacklisted.
        """
        chunks = self._stream_with_fallback(
            self._model_list(model),
            lambda client, m: client.stream(messages, temperature=temperature, model=m),
        )
        # aclosing propagates an early exit by our consumer down to the inner stream.
        async with aclosing(chunks):
            async for chunk in chunks:
                yield chunk

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a tool-aware completion with the same fallback rules as ``stream``.

        Raises:
            LLMBackendError: if every candidate failed or was already blacklisted.
        """
        chunks = self._stream_with_fallback(
            self._model_list(model),
            lambda client, m: client.stream_complete(
                messages, tools=tools, temperature=temperature, model=m, think=think,
            ),
        )
        # aclosing propagates an early exit by our consumer down to the inner stream.
        async with aclosing(chunks):
            async for chunk in chunks:
                yield chunk