Newer
Older
navi-1 / tests / unit / core / test_compressor.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 25 May 18 KB Add meta-summary for multi-level compression
"""Unit tests for context compressor."""

import pytest

from navi.core.compressor import (
    ContextCompressor,
    _format_for_summary,
    compress_context,
    partition_messages,
    should_compress,
)
from navi.llm.base import Message, ToolCallRequest
from tests.conftest_factory import FakeLLMBackend


class TestShouldCompress:
    """Checks ``should_compress`` below, at, and above the 0.7 ratio of 1000."""

    def test_below_threshold(self):
        """100 against 1000 with ratio 0.7 must not request compression."""
        decision = should_compress(100, 1000, 0.7)
        assert decision is False

    def test_at_threshold(self):
        """700 against 1000 with ratio 0.7 (exactly on the boundary) compresses."""
        decision = should_compress(700, 1000, 0.7)
        assert decision is True

    def test_above_threshold(self):
        """800 against 1000 with ratio 0.7 compresses."""
        decision = should_compress(800, 1000, 0.7)
        assert decision is True


class TestPartitionMessages:
    """Checks how ``partition_messages`` splits a history into old and recent parts."""

    def test_empty(self):
        """An empty history produces two empty partitions."""
        older, newer = partition_messages([], keep_recent=2)
        assert older == []
        assert newer == []

    def test_fewer_turns_than_keep(self):
        """With fewer turns than ``keep_recent`` everything stays recent."""
        history = [
            Message(role="user", content="hi"),
            Message(role="assistant", content="hello"),
        ]
        older, newer = partition_messages(history, keep_recent=5)
        assert older == []
        assert newer == history

    def test_splits_into_old_and_recent(self):
        """Three user/assistant turns with ``keep_recent=2`` split 2 / 4 messages."""
        history = []
        for turn in ("1", "2", "3"):
            history.append(Message(role="user", content=turn))
            history.append(Message(role="assistant", content=f"a{turn}"))

        older, newer = partition_messages(history, keep_recent=2)
        assert len(older) == 2  # turn 1
        assert len(newer) == 4  # turns 2+3

    def test_system_messages_ignored(self):
        """System messages do not appear in the recent partition."""
        history = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
        ]
        _, newer = partition_messages(history, keep_recent=5)
        non_system = [msg for msg in history if msg.role != "system"]
        assert newer == non_system

    def test_tool_calls_stay_with_assistant(self):
        """A tool call and its result land in the same partition as their turn."""
        first_turn = [
            Message(role="user", content="1"),
            Message(
                role="assistant",
                content="",
                tool_calls=[ToolCallRequest(id="1", name="fs", arguments={})],
            ),
            Message(role="tool", content="result", name="fs", tool_call_id="1"),
        ]
        second_turn = [
            Message(role="user", content="2"),
            Message(role="assistant", content="ok"),
        ]
        older, newer = partition_messages(first_turn + second_turn, keep_recent=1)
        assert len(older) == 3
        assert len(newer) == 2

    def test_mid_turn_fallback_keeps_user_and_recent_messages(self):
        """A single long turn keeps its user message plus the newest messages."""
        history = [Message(role="user", content="build a model")]
        for call_id in map(str, range(5)):
            history += [
                Message(
                    role="assistant",
                    tool_calls=[ToolCallRequest(id=call_id, name="fs", arguments={})],
                ),
                Message(
                    role="tool",
                    content=f"result {call_id}",
                    name="fs",
                    tool_call_id=call_id,
                ),
            ]

        older, newer = partition_messages(history, keep_recent=8, keep_recent_messages=4)

        assert older
        opening = newer[0]
        assert opening.role == "user"
        assert opening.content == "build a model"
        assert len(newer) == 5  # original user message + 4 newest in-flight messages


class TestFormatForSummary:
    """Checks the text and image list produced by ``_format_for_summary``."""

    def test_user_message(self):
        """A user message is rendered with a ``User:`` prefix and no images."""
        rendered, attachments = _format_for_summary([Message(role="user", content="hello")])
        assert "User: hello" in rendered
        assert attachments == []

    def test_assistant_message(self):
        """An assistant message is rendered with an ``Assistant:`` prefix."""
        rendered, _ = _format_for_summary([Message(role="assistant", content="world")])
        assert "Assistant: world" in rendered

    def test_tool_call_block(self):
        """Both the tool call and its result show up as bracketed markers."""
        call = Message(
            role="assistant",
            tool_calls=[ToolCallRequest(id="1", name="fs", arguments={"path": "/tmp"})],
        )
        outcome = Message(role="tool", content="file data", name="fs", tool_call_id="1")
        rendered, _ = _format_for_summary([call, outcome])
        for marker in ("[Tool call: fs", "[Tool result: fs"):
            assert marker in rendered

    def test_images_collected(self):
        """Attached images are returned and counted in the rendered text."""
        with_image = Message(role="user", content="look", images=["base64img"])
        rendered, attachments = _format_for_summary([with_image])
        assert attachments == ["base64img"]
        assert "[+ 1 image(s)]" in rendered

    def test_summary_message_folded(self):
        """The content of an existing summary message is kept in the text."""
        earlier = Message(role="user", content="old summary", is_summary=True)
        rendered, _ = _format_for_summary([earlier])
        assert "old summary" in rendered


class TestCompressContext:
    """Async tests for ``compress_context``.

    Every coroutine test is explicitly marked with ``pytest.mark.asyncio`` so
    it is collected the same way as the async tests in ``TestContextCompressor``
    regardless of the configured pytest-asyncio mode.
    """

    @pytest.mark.asyncio
    async def test_nothing_to_compress(self):
        """A single user message with ``keep_recent=8`` yields ``None``."""
        backend = FakeLLMBackend()
        result = await compress_context(
            context=[Message(role="user", content="hi")],
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=8,
        )
        assert result is None

    @pytest.mark.asyncio
    async def test_compresses_old_turns(self):
        """The oldest turn is replaced by one hidden, in-context summary message."""
        backend = FakeLLMBackend(responses=["Summary of old stuff"])
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
            Message(role="user", content="3"),
            Message(role="assistant", content="a3"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert summary == "Summary of old stuff"
        # system + summary + 2 recent turns (user+assistant × 2) = 6
        assert len(new_context) == 6
        assert new_context[0].role == "system"
        assert new_context[1].is_summary is True
        assert new_context[1].is_display is False
        assert new_context[1].is_context is True

    @pytest.mark.asyncio
    async def test_preserves_system_messages(self):
        """Both system messages are still present after compression."""
        backend = FakeLLMBackend(responses=["sum"])
        context = [
            Message(role="system", content="s1"),
            Message(role="system", content="s2"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
        ]
        new_context, _ = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=1,
        )
        system_msgs = [m for m in new_context if m.role == "system"]
        assert len(system_msgs) == 2

    @pytest.mark.asyncio
    async def test_compresses_long_current_turn_when_requested(self):
        """With ``keep_recent_messages`` set, a single long turn is compressed."""
        backend = FakeLLMBackend(responses=["mid-turn summary"])
        context = [Message(role="user", content="build a model")]
        for i in range(5):
            context.append(Message(
                role="assistant",
                tool_calls=[ToolCallRequest(id=str(i), name="fs", arguments={})],
            ))
            context.append(Message(role="tool", content=f"large result {i}", name="fs", tool_call_id=str(i)))

        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=8,
            keep_recent_messages=4,
        )

        assert summary == "mid-turn summary"
        assert new_context[0].is_summary is True
        assert new_context[1].role == "user"
        assert new_context[1].content == "build a model"
        assert len(new_context) == 6  # summary + user + 4 recent messages

    @pytest.mark.asyncio
    async def test_meta_summary_consolidates_multiple_summaries(self):
        """When to_summarize contains multiple long existing summaries, a meta-summary
        pass runs first (consolidating them) before the main compression."""
        # First response = meta-summary, second = main compression
        backend = FakeLLMBackend(responses=["Meta summary", "Final summary"])
        # Build a context with two existing summaries (each > 4000 chars to cross threshold)
        big_summary_1 = "A" * 5000
        big_summary_2 = "B" * 5000
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content=big_summary_1, is_summary=True, is_display=False),
            Message(role="user", content=big_summary_2, is_summary=True, is_display=False),
            Message(role="user", content="recent question"),
            Message(role="assistant", content="recent answer"),
            Message(role="user", content="new question"),
            Message(role="assistant", content="new answer"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert summary == "Final summary"
        # system + 1 consolidated summary + 2 recent turns = 6
        assert len(new_context) == 6
        assert new_context[1].is_summary is True
        # Exactly one summary in final context (meta + raw folded into one)
        assert sum(1 for m in new_context if m.is_summary) == 1
        # Two LLM calls happened (meta + main)
        assert backend._call_idx == 2

    @pytest.mark.asyncio
    async def test_meta_summary_skipped_when_summaries_are_short(self):
        """Short existing summaries should not trigger an extra meta-summary pass."""
        backend = FakeLLMBackend(responses=["Final summary"])
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content="short summary 1", is_summary=True, is_display=False),
            Message(role="user", content="short summary 2", is_summary=True, is_display=False),
            Message(role="user", content="recent question"),
            Message(role="assistant", content="recent answer"),
            Message(role="user", content="new question"),
            Message(role="assistant", content="new answer"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert summary == "Final summary"
        assert len(new_context) == 6
        # Only one LLM call (main compression) because summaries were too short
        # to trigger meta-summary
        assert backend._call_idx == 1

    @pytest.mark.asyncio
    async def test_meta_summary_graceful_on_failure(self, monkeypatch):
        """If meta-summary fails, compression continues with raw summaries."""
        import navi.core.compressor as compressor_module

        backend = FakeLLMBackend(responses=["Final summary"])

        async def _failing_meta(*args, **kwargs):
            raise RuntimeError("meta boom")

        # monkeypatch restores the real _meta_summarize at teardown, even if
        # an assertion below fails, so no manual try/finally is needed.
        monkeypatch.setattr(compressor_module, "_meta_summarize", _failing_meta)

        big = "X" * 5000
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content=big, is_summary=True, is_display=False),
            Message(role="user", content=big, is_summary=True, is_display=False),
            Message(role="user", content="q1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="q2"),
            Message(role="assistant", content="a2"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert summary == "Final summary"
        assert len(new_context) == 6
        # Meta failed, so only one call to the backend
        assert backend._call_idx == 1

    @pytest.mark.asyncio
    async def test_intra_turn_fallback_aggressive(self):
        """When turn-based partition has nothing to compress but keep_recent_messages
        is set, an aggressive fallback (keep_recent_messages=2) should still find
        something to summarize.

        Here the current turn is exactly 5 messages (user + 2 assistant/tool pairs).
        With keep_recent_messages=4, partition_current_turn_messages returns None
        because len(turn) <= keep_recent_messages + 1. This forces the fallback to
        retry with keep_recent_messages=2, which successfully compresses."""
        backend = FakeLLMBackend(responses=["aggressive summary"])
        context = [Message(role="user", content="autonomous loop")]
        for i in range(2):
            context.append(Message(
                role="assistant",
                tool_calls=[ToolCallRequest(id=str(i), name="fs", arguments={})],
            ))
            context.append(Message(role="tool", content=f"result {i}", name="fs", tool_call_id=str(i)))

        # Single turn, keep_recent=1 → turn-based partition returns ([], non_system)
        # Fallback should kick in with keep_recent_messages=2
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=1,
            keep_recent_messages=4,
        )

        assert summary == "aggressive summary"
        assert new_context[0].is_summary is True
        assert new_context[1].role == "user"
        # Fallback with keep_recent_messages=2 keeps user + 2 newest messages
        assert len(new_context) == 4  # summary + user + 2 recent messages


class TestContextCompressor:
    """Tests for ``ContextCompressor``: token estimation and ``compress_session``.

    Tests that replace ``navi.core.compressor.compress_context`` use the
    ``monkeypatch`` fixture so the original function is restored at teardown.
    """

    def test_estimate_context_tokens_text_only(self):
        """A text-only message is estimated from its character count."""
        compressor = ContextCompressor()
        context = [Message(role="user", content="hello world")]
        assert compressor.estimate_context_tokens(context) == 3  # 11 chars // 3 = 3

    def test_estimate_context_tokens_with_images(self):
        """One attached image adds 500 to the estimate."""
        compressor = ContextCompressor()
        context = [Message(role="user", content="hi", images=["base64img"])]
        assert compressor.estimate_context_tokens(context) == 500  # 2 chars // 3 = 0 + 500

    @pytest.mark.asyncio
    async def test_compress_session_success(self):
        """A normal run returns the new context together with the summary text."""
        backend = FakeLLMBackend(responses=["Summary text"])
        compressor = ContextCompressor()
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
            Message(role="user", content="3"),
            Message(role="assistant", content="a3"),
        ]
        result = await compressor.compress_session(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert result is not None
        new_context, summary = result
        assert summary == "Summary text"
        assert len(new_context) == 6  # system + summary + 2 recent turns

    @pytest.mark.asyncio
    async def test_compress_session_retry_on_failure(self, monkeypatch):
        """First attempt fails, second succeeds with keep_recent + 4."""
        import navi.core.compressor as compressor_module

        backend = FakeLLMBackend(responses=["Summary text"])
        # Captured before patching so the wrapper can delegate to the real one.
        original_compress_context = compressor_module.compress_context
        call_count = 0

        async def _failing_once(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count == 1:
                raise RuntimeError("boom")
            return await original_compress_context(*args, **kwargs)

        # Restored automatically by monkeypatch at teardown.
        monkeypatch.setattr(compressor_module, "compress_context", _failing_once)

        compressor = ContextCompressor()
        context = [Message(role="system", content="sys")]
        for i in range(7):
            context.append(Message(role="user", content=str(i)))
            context.append(Message(role="assistant", content=f"a{i}"))
        result = await compressor.compress_session(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert result is not None
        assert call_count == 2

    @pytest.mark.asyncio
    async def test_compress_session_hard_truncate_on_double_failure(self, monkeypatch):
        """Both attempts fail → hard-truncate fallback."""
        import navi.core.compressor as compressor_module
        from unittest.mock import AsyncMock

        async def _always_fail(*args, **kwargs):
            raise RuntimeError("always fails")

        # Restored automatically by monkeypatch at teardown.
        monkeypatch.setattr(compressor_module, "compress_context", _always_fail)

        compressor = ContextCompressor()
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
            Message(role="user", content="3"),
            Message(role="assistant", content="a3"),
            Message(role="user", content="4"),
            Message(role="assistant", content="a4"),
        ]
        result = await compressor.compress_session(
            context=context,
            llm=AsyncMock(),
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert result is not None
        new_context, summary = result
        assert "truncated" in summary.lower()
        # system + 6 kept = 7
        assert len(new_context) == 7

    @pytest.mark.asyncio
    async def test_compress_session_returns_none_when_nothing_to_compress(self):
        """One short turn with ``keep_recent=5`` leaves nothing to compress."""
        backend = FakeLLMBackend()
        compressor = ContextCompressor()
        context = [
            Message(role="user", content="hi"),
            Message(role="assistant", content="hello"),
        ]
        result = await compressor.compress_session(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=5,
        )
        assert result is None