Newer
Older
navi-1 / navi / llm / fallback.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 12 May 8 KB Remove dead LLMBackend.stream() method
"""FallbackOllamaBackend — tries multiple Ollama servers with per-server model fallback.

Servers and models are independent priority lists.

Algorithm (server-first):
  For each server in priority order (skip blacklisted servers):
    For each model in priority order (skip blacklisted model+server pairs):
      Try the request.
      LLMConnectionError  → blacklist the server, break to next server.
      LLMModelNotFoundError → blacklist (server, model), continue to next model.
      Success → use this result.
  If all combinations exhausted → raise LLMBackendError.

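Blacklists live in module-level dicts that map each failed host, or (host, model) pair, to the
time.monotonic() timestamp of the failure. Entries expire after _TTL seconds (5 minutes). Both
dicts are also emptied on process restart or by calling clear_blacklists().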
"""

import json
import time
import structlog
from dataclasses import dataclass
from pathlib import Path
from typing import AsyncGenerator

from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError

from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolSchema
from .ollama import OllamaBackend

# Module-level structlog logger used by the helpers and the backend class below.
log = structlog.get_logger()


@dataclass
class ServerEntry:
    """One Ollama server in the fallback priority list."""

    host: str  # passed to OllamaBackend as host=...
    api_key: str = ""  # optional; defaults to the empty string


# Module-level blacklists, shared by every FallbackOllamaBackend in this process.
# Each maps a key to the time.monotonic() timestamp at which it failed; a key
# counts as dead for _TTL seconds and is pruned lazily when next looked up.
_dead_servers: dict[str, float] = {}  # host -> failure timestamp
_dead_models: dict[tuple[str, str], float] = {}  # (host, model_name) -> failure timestamp
_TTL = 5 * 60  # seconds a blacklist entry stays in effect


def _is_blacklisted(table: dict, key: object) -> bool:
    """Return True if *key* was recorded in *table* no more than _TTL seconds ago.

    An expired entry is deleted from *table* as a side effect, so blacklists
    shrink lazily without a background sweeper.
    """
    failed_at = table.get(key)
    if failed_at is None:
        return False
    if time.monotonic() - failed_at > _TTL:
        del table[key]
        return False
    return True


def _is_dead_server(host: str) -> bool:
    """Return True if *host* is currently blacklisted as unreachable."""
    return _is_blacklisted(_dead_servers, host)


def _is_dead_model(host: str, model: str) -> bool:
    """Return True if *model* is currently blacklisted as missing on *host*."""
    return _is_blacklisted(_dead_models, (host, model))


def clear_blacklists() -> None:
    """Drop every dead-server and dead-model entry at once, regardless of TTL."""
    for table in (_dead_servers, _dead_models):
        table.clear()
    log.info("fallback.blacklists_cleared")


def load_servers_from_file(path: str) -> list[ServerEntry]:
    """Load the server priority list from a JSON file.

    The file is expected to hold a JSON array of objects shaped like
    ``{"host": "...", "api_key": "..."}``, where ``api_key`` is optional.

    Args:
        path: Filesystem path of the JSON file (read as UTF-8).

    Returns:
        The servers in file order. An empty list is returned when the file
        is missing, is not valid JSON, or its top level is not an array.
        Individual entries that are not objects or have no non-empty
        ``host`` are skipped with a warning.
    """
    try:
        data = json.loads(Path(path).read_text(encoding="utf-8"))
    except FileNotFoundError:
        log.warning("fallback.servers_file_missing", path=path)
        return []
    except json.JSONDecodeError:
        log.warning("fallback.servers_file_bad_json", path=path)
        return []

    # Valid JSON of the wrong shape (e.g. a top-level object) would otherwise
    # crash below when iterating entries and calling .get() on them.
    if not isinstance(data, list):
        log.warning("fallback.servers_file_not_a_list", path=path, got=type(data).__name__)
        return []

    servers: list[ServerEntry] = []
    for i, entry in enumerate(data):
        host = entry.get("host") if isinstance(entry, dict) else None
        if not host:
            log.warning("fallback.servers_file_missing_host", index=i, entry=entry)
            continue
        # `"api_key": null` would otherwise put None into a str field.
        servers.append(ServerEntry(host=host, api_key=entry.get("api_key") or ""))
    return servers


class FallbackOllamaBackend(LLMBackend):
    """Ollama backend with automatic server and model fallback.

    Candidates are tried server-first: for each server in ``servers`` order,
    every requested model is tried in order. Failures are recorded in the
    module-level blacklists, so later calls (from any instance) skip them
    until their TTL expires:

    * ``LLMConnectionError`` blacklists the whole server and moves on to the
      next server.
    * ``LLMModelNotFoundError`` blacklists only the (server, model) pair and
      moves on to the next model on the same server.

    Any other exception propagates unchanged. When no candidate succeeds,
    ``LLMBackendError`` is raised, chained to the last recorded failure.
    """

    def __init__(self, servers: list[ServerEntry]) -> None:
        self._servers = servers
        # One OllamaBackend per (host, api_key): its AsyncClient is reused
        # across calls, and two entries sharing a host keep their own keys.
        self._clients: dict[tuple[str, str], OllamaBackend] = {}

    def _get_client(self, server: ServerEntry) -> OllamaBackend:
        """Return the cached OllamaBackend for *server*, creating it on first use."""
        key = (server.host, server.api_key)
        client = self._clients.get(key)
        if client is None:
            # Imported lazily (presumably to avoid an import cycle at module load).
            from navi.config import settings

            # Use the largest configured timeout so the HTTP layer presumably
            # never fires before the higher-level LLM timeouts do.
            ollama_http_timeout = max(
                settings.ollama_request_timeout,
                settings.llm_complete_timeout,
                settings.llm_stream_first_chunk_timeout,
            )
            client = OllamaBackend(
                model="", host=server.host, api_key=server.api_key,
                timeout=ollama_http_timeout,
            )
            self._clients[key] = client
        return client

    @staticmethod
    def _model_list(model: "list[str] | str | None") -> list[str]:
        """Normalise the ``model`` argument into a non-empty list of names.

        ``None``, ``""`` and ``[]`` become ``[""]``: a single attempt whose
        empty model name is passed through to OllamaBackend unchanged.
        """
        if not model:
            return [""]
        if isinstance(model, list):
            return list(model)  # copy: the caller's list may change across awaits
        return [model]

    def _candidates(self, models):
        """Yield ``(server, model)`` pairs to try, in priority order.

        Blacklists are re-checked before every yield, so a failure the caller
        records between iterations takes effect immediately. Once a server is
        marked dead, its remaining models are skipped.
        """
        for server in self._servers:
            for m in models:
                if _is_dead_server(server.host):
                    break
                if _is_dead_model(server.host, m):
                    continue
                yield server, m

    @staticmethod
    def _record_failure(host: str, model: str, err: Exception) -> None:
        """Log *err* and blacklist either the server or the (server, model) pair.

        ``LLMConnectionError`` is checked first, so it takes precedence should
        an exception ever match both fallback types.
        """
        if isinstance(err, LLMConnectionError):
            log.warning("fallback.server_dead", host=host, error=str(err))
            _dead_servers[host] = time.monotonic()
        else:
            log.warning("fallback.model_dead", host=host, model=model, error=str(err))
            _dead_models[(host, model)] = time.monotonic()

    def _no_attempt_error(self) -> LLMBackendError:
        """Build the cause reported when no candidate was attempted at all."""
        if not self._servers:
            return LLMBackendError("No backends configured")
        return LLMBackendError("All configured servers/models are temporarily blacklisted")

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> LLMResponse:
        """Run a completion on the first server/model candidate that succeeds.

        Args:
            model: A single model name or a priority-ordered list of names.
            Other arguments are forwarded unchanged to ``OllamaBackend.complete``.

        Raises:
            LLMBackendError: Every candidate failed or was blacklisted.
        """
        last_err: Exception = self._no_attempt_error()
        for server, m in self._candidates(self._model_list(model)):
            try:
                return await self._get_client(server).complete(
                    messages, tools=tools, temperature=temperature,
                    model=m, think=think, max_tokens=max_tokens,
                    top_k=top_k, top_p=top_p, num_thread=num_thread,
                )
            except (LLMConnectionError, LLMModelNotFoundError) as e:
                self._record_failure(server.host, m, e)
                last_err = e

        raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err

    async def embed(
        self,
        texts: list[str],
        model: "list[str] | str | None" = None,
    ) -> list[list[float]]:
        """Embed *texts* on the first server/model candidate that succeeds.

        Args:
            texts: Forwarded unchanged to ``OllamaBackend.embed``.
            model: A single model name or a priority-ordered list of names.

        Raises:
            LLMBackendError: Every candidate failed or was blacklisted.
        """
        last_err: Exception = self._no_attempt_error()
        for server, m in self._candidates(self._model_list(model)):
            try:
                return await self._get_client(server).embed(texts, model=m)
            except (LLMConnectionError, LLMModelNotFoundError) as e:
                self._record_failure(server.host, m, e)
                last_err = e

        raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: "list[str] | str | None" = None,
        think: bool | None = None,
        top_k: int | None = None,
        top_p: float | None = None,
        num_thread: int | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream a completion from the first candidate that produces a chunk.

        Fallback happens only before the first chunk arrives. Once a candidate
        has yielded, it is committed to, and errors raised mid-stream propagate
        to the caller unchanged (nothing is blacklisted for them). A candidate
        whose stream ends without yielding anything is skipped without being
        blacklisted.

        Args:
            model: A single model name or a priority-ordered list of names.
            Other arguments are forwarded unchanged to
            ``OllamaBackend.stream_complete``.

        Raises:
            LLMBackendError: No candidate produced a first chunk.
        """
        last_err: Exception = self._no_attempt_error()
        for server, m in self._candidates(self._model_list(model)):
            try:
                gen = self._get_client(server).stream_complete(
                    messages, tools=tools, temperature=temperature, model=m, think=think,
                    top_k=top_k, top_p=top_p, num_thread=num_thread,
                )
                first = await gen.__anext__()
            except StopAsyncIteration:
                last_err = LLMModelNotFoundError("Empty stream from server")
                continue
            except (LLMConnectionError, LLMModelNotFoundError) as e:
                self._record_failure(server.host, m, e)
                last_err = e
                continue

            # Committed to this candidate. Close the upstream generator
            # deterministically, even if our consumer stops early or is
            # cancelled, instead of leaving its (presumably HTTP-backed)
            # stream open until garbage collection.
            try:
                yield first
                async for chunk in gen:
                    yield chunk
            finally:
                aclose = getattr(gen, "aclose", None)
                if aclose is not None:
                    await aclose()
            return

        raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err