Newer
Older
navi-1 / navi / core / stream_guard.py
"""Stream safety wrapper for LLM generators."""

from __future__ import annotations

import asyncio
from typing import AsyncGenerator

from navi.exceptions import LLMBackendError
from navi.llm.base import LLMChunk


async def _cancel_and_drain(task: "asyncio.Future[object]") -> None:
    """Cancel *task* and wait until it has actually finished.

    ``asyncio.wait`` reports completion without re-raising the task's own
    outcome, so a ``CancelledError`` surfacing here can only mean the *calling*
    task was cancelled -- and it propagates instead of being swallowed (which is
    what ``await task`` inside ``except CancelledError: pass`` would do).
    The task's own exception, if any, is retrieved so asyncio does not log
    "Task exception was never retrieved".
    """
    task.cancel()
    await asyncio.wait({task})
    if not task.cancelled():
        task.exception()  # mark as retrieved; the outcome is deliberately ignored


async def _iter_stream_guarded(
    stream_gen: "AsyncGenerator[LLMChunk, None]",
    stop_event: "asyncio.Event | None",
    first_chunk_timeout: float,
    chunk_timeout: float,
    *,
    poll_interval: float = 1.0,
) -> "AsyncGenerator[LLMChunk, None]":
    """
    Wraps a streaming LLM generator with two safety mechanisms:

    1. Stop-event responsiveness during prefill.
       Normally, the agent only checks stop_event *between* chunks. During the
       prefill phase (processing input tokens) Ollama emits no chunks at all —
       the first await can block for minutes on large contexts. This wrapper polls
       stop_event every ``poll_interval`` seconds (default 1s) so the user's Stop
       button works even then. A stop requested before the next chunk is
       requested ends the stream without calling ``__anext__`` again.

    2. Timeouts as a last-resort safety net.
       first_chunk_timeout: how long (seconds) to wait for the first token (prefill).
       chunk_timeout: max gap (seconds) between subsequent tokens.
       Timeouts are measured against the event-loop clock, so fractional and
       sub-second values are honoured. On timeout the generator is closed,
       which terminates the HTTP connection to Ollama → Ollama halts
       generation → GPU load drops to idle.

    Args:
        stream_gen: The underlying chunk generator; it is always ``aclose()``d
            when this wrapper finishes, for whatever reason.
        stop_event: Optional event; once set, the stream ends quietly.
        first_chunk_timeout: Seconds to wait for the first chunk.
        chunk_timeout: Seconds to wait for each later chunk.
        poll_interval: Seconds between stop_event checks while waiting (> 0).

    Raises:
        LLMBackendError: A chunk did not arrive within its timeout.
        ValueError: ``poll_interval`` is not positive.
        Any exception raised by ``stream_gen`` propagates unchanged, and
        cancellation of the consuming task propagates rather than being swallowed.
    """
    first = True
    chunk_task: "asyncio.Task[LLMChunk] | None" = None
    try:
        if poll_interval <= 0:
            raise ValueError("poll_interval must be positive")
        loop = asyncio.get_running_loop()

        while True:
            # Check before requesting another chunk: covers a Stop issued before
            # the stream started as well as one issued while a chunk was consumed.
            if stop_event is not None and stop_event.is_set():
                return

            timeout = first_chunk_timeout if first else chunk_timeout
            # Create one task per chunk; reuse across poll iterations so we
            # don't accidentally start multiple concurrent __anext__ calls.
            chunk_task = asyncio.ensure_future(stream_gen.__anext__())
            started = loop.time()
            deadline = started + timeout

            while True:
                # Wake at least every poll_interval seconds, but never sleep past
                # the deadline.
                remaining = deadline - loop.time()
                done, _ = await asyncio.wait(
                    {chunk_task}, timeout=max(0.0, min(poll_interval, remaining))
                )
                if done:
                    break
                if stop_event is not None and stop_event.is_set():
                    await _cancel_and_drain(chunk_task)
                    chunk_task = None
                    return
                elapsed = loop.time() - started
                if elapsed >= timeout:
                    await _cancel_and_drain(chunk_task)
                    chunk_task = None
                    label = "first token (context may be too large for this model)" if first else "next token"
                    raise LLMBackendError(
                        f"LLM timed out after {elapsed:.0f}s waiting for {label}."
                    )

            try:
                chunk = chunk_task.result()
            except StopAsyncIteration:
                chunk_task = None
                return

            chunk_task = None
            first = False
            yield chunk

    finally:
        try:
            # Cancel any in-flight __anext__ task so we don't leave a zombie
            # coroutine holding an open HTTP connection to Ollama.
            if chunk_task is not None and not chunk_task.done():
                await _cancel_and_drain(chunk_task)
        finally:
            # Closing the generator terminates the HTTP connection to Ollama,
            # which signals it to stop generating (GPU returns to idle).
            # Best-effort: a failure here must not mask the original outcome.
            try:
                await stream_gen.aclose()
            except Exception:
                pass