"""Unit tests for Agent._iter_stream_guarded."""

import asyncio

import pytest

from navi.core.agent import _iter_stream_guarded
from navi.exceptions import LLMBackendError
from navi.llm.base import LLMChunk


async def _yield_chunks(chunks, delay: float = 0):
    for c in chunks:
        if delay:
            await asyncio.sleep(delay)
        yield c


class TestIterStreamGuarded:
    """Tests for ``_iter_stream_guarded``: pass-through, stop event, and timeouts."""

    async def test_yields_all_chunks(self):
        """Every source chunk is passed through in order, including an empty delta."""
        chunks = [LLMChunk(delta="a"), LLMChunk(delta="b"), LLMChunk(delta="")]
        result = [
            c
            async for c in _iter_stream_guarded(
                _yield_chunks(chunks), stop_event=None, first_chunk_timeout=5, chunk_timeout=5
            )
        ]
        assert len(result) == 3
        assert [c.delta for c in result] == ["a", "b", ""]

    async def test_respects_stop_event(self):
        """With the stop event already set, iteration ends after the first chunk."""

        async def _slow():
            yield LLMChunk(delta="a")
            # Long gap: the guard must stop here rather than wait for "b".
            await asyncio.sleep(10)
            yield LLMChunk(delta="b")

        stop = asyncio.Event()
        stop.set()
        result = [
            c
            async for c in _iter_stream_guarded(
                _slow(), stop_event=stop, first_chunk_timeout=5, chunk_timeout=5
            )
        ]
        # Should stop before 'b' because stop_event is set
        assert len(result) == 1
        assert result[0].delta == "a"

    async def test_first_chunk_timeout(self):
        """No chunk within ``first_chunk_timeout`` raises a "timed out" LLMBackendError."""

        async def _very_slow():
            await asyncio.sleep(10)
            yield LLMChunk(delta="a")

        with pytest.raises(LLMBackendError) as exc_info:
            async for _ in _iter_stream_guarded(
                _very_slow(), stop_event=None, first_chunk_timeout=0.1, chunk_timeout=5
            ):
                pass
        assert "timed out" in str(exc_info.value).lower()

    async def test_chunk_timeout(self):
        """A gap longer than ``chunk_timeout`` after the first chunk raises LLMBackendError."""

        async def _slow_gap():
            yield LLMChunk(delta="a")
            await asyncio.sleep(10)
            yield LLMChunk(delta="b")

        with pytest.raises(LLMBackendError) as exc_info:
            async for _ in _iter_stream_guarded(
                _slow_gap(), stop_event=None, first_chunk_timeout=5, chunk_timeout=0.1
            ):
                pass
        assert "timed out" in str(exc_info.value).lower()

    async def test_empty_stream(self):
        """A source stream that yields nothing produces an empty result without error."""

        async def _empty():
            return
            yield  # unreachable; its presence makes this an async generator

        result = [
            c
            async for c in _iter_stream_guarded(
                _empty(), stop_event=None, first_chunk_timeout=1, chunk_timeout=1
            )
        ]
        assert result == []
