diff --git a/navi/config.py b/navi/config.py index 98ae75a..9ca6ffa 100644 --- a/navi/config.py +++ b/navi/config.py @@ -101,8 +101,8 @@ # complete() is non-streaming (planning, compression) — blocked until full response llm_complete_timeout: int = 120 # stream_complete(): how long to wait for the FIRST token (prefill phase) - # Large contexts can take 60-90s to prefill; 180s is a safe upper bound - llm_stream_first_chunk_timeout: int = 180 + # Large contexts can take 60-90s to prefill; 90s matches user expectation + llm_stream_first_chunk_timeout: int = 90 # stream_complete(): max gap between any two subsequent tokens llm_stream_chunk_timeout: int = 60 diff --git a/navi/llm/fallback.py b/navi/llm/fallback.py index f3197c5..49e0b29 100644 --- a/navi/llm/fallback.py +++ b/navi/llm/fallback.py @@ -14,6 +14,7 @@ Blacklists live in module-level sets (reset on server restart). """ +import asyncio import json import time import structlog @@ -21,6 +22,7 @@ from pathlib import Path from typing import AsyncGenerator +from navi.config import settings from navi.exceptions import LLMBackendError, LLMConnectionError, LLMModelNotFoundError from .base import LLMBackend, LLMChunk, LLMResponse, Message, ToolSchema @@ -222,32 +224,54 @@ for m in models: if _is_dead_model(server.host, m): continue - try: - gen = self._get_client(server).stream_complete( - messages, tools=tools, temperature=temperature, model=m, think=think, - top_k=top_k, top_p=top_p, num_thread=num_thread, - ) - first = await gen.__anext__() - except StopAsyncIteration: - last_err = LLMModelNotFoundError("Empty stream from server") - continue - except LLMConnectionError as e: - log.warning("fallback.server_dead", host=server.host, error=str(e)) - last_err = e - if single_server: + # Try up to 2 times on the same server+model before falling back + for attempt in range(2): + try: + gen = self._get_client(server).stream_complete( + messages, tools=tools, temperature=temperature, model=m, think=think, + top_k=top_k, top_p=top_p, num_thread=num_thread, + ) + first = await asyncio.wait_for( + gen.__anext__(), timeout=settings.llm_stream_first_chunk_timeout + ) + except StopAsyncIteration: + last_err = LLMModelNotFoundError("Empty stream from server") break - _dead_servers[server.host] = time.monotonic() - break - except LLMModelNotFoundError as e: - log.warning("fallback.model_dead", host=server.host, model=m, error=str(e)) - if not single_server: - _dead_models[(server.host, m)] = time.monotonic() - last_err = e - continue - else: - yield first - async for chunk in gen: - yield chunk - return + except asyncio.TimeoutError as e: + log.warning( + "fallback.first_chunk_timeout", + host=server.host, model=m, attempt=attempt + 1, + ) + last_err = LLMConnectionError(f"First-chunk timeout after {settings.llm_stream_first_chunk_timeout}s") + if attempt == 0: + await asyncio.sleep(2.0) + continue + if single_server: + break + _dead_servers[server.host] = time.monotonic() + break + except LLMConnectionError as e: + log.warning("fallback.server_dead", host=server.host, error=str(e)) + last_err = e + if attempt == 0: + await asyncio.sleep(2.0) + continue + if single_server: + break + _dead_servers[server.host] = time.monotonic() + break + except LLMModelNotFoundError as e: + log.warning("fallback.model_dead", host=server.host, model=m, error=str(e)) + if not single_server: + _dead_models[(server.host, m)] = time.monotonic() + last_err = e + break + else: + yield first + async for chunk in gen: + yield chunk + return + # If we fell through the attempt loop without returning, + # continue to the next model (outer loops handle last_err) raise LLMBackendError(f"All backends exhausted: {last_err}") from last_err