# File: navi-1/tests/unit/core/test_compressor.py
# Last change: Eugene Sukhodolskiy, 29 Apr (5 KB) — "Bootstrap test suite — Phase 1 unit tests"
"""Unit tests for context compressor."""

import pytest

from navi.core.compressor import (
    _format_for_summary,
    compress_context,
    partition_messages,
    should_compress,
)
from navi.llm.base import Message, ToolCallRequest
from tests.conftest_factory import FakeLLMBackend


class TestShouldCompress:
    """Threshold decisions of ``should_compress`` for a 1000-token budget at ratio 0.7."""

    # Shared inputs so each test varies only the current token count.
    BUDGET = 1000
    RATIO = 0.7

    def test_below_threshold(self):
        # 100 of 1000 tokens is far below the trigger point.
        assert should_compress(100, self.BUDGET, self.RATIO) is False

    def test_at_threshold(self):
        # The exact boundary value (700 of 1000 at ratio 0.7) must already trigger.
        assert should_compress(700, self.BUDGET, self.RATIO) is True

    def test_above_threshold(self):
        assert should_compress(800, self.BUDGET, self.RATIO) is True


class TestPartitionMessages:
    """Splitting a conversation into older turns and the last ``keep_recent`` turns."""

    def test_empty(self):
        old, recent = partition_messages([], keep_recent=2)
        assert old == []
        assert recent == []

    def test_fewer_turns_than_keep(self):
        msgs = [
            Message(role="user", content="hi"),
            Message(role="assistant", content="hello"),
        ]
        old, recent = partition_messages(msgs, keep_recent=5)
        assert old == []
        assert recent == msgs

    def test_splits_into_old_and_recent(self):
        msgs = [
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
            Message(role="user", content="3"),
            Message(role="assistant", content="a3"),
        ]
        old, recent = partition_messages(msgs, keep_recent=2)
        # Pin *which* messages land in each half (and their order), not just
        # how many -- a length-only check passes even if the split is wrong.
        assert [m.content for m in old] == ["1", "a1"]  # turn 1
        assert [m.content for m in recent] == ["2", "a2", "3", "a3"]  # turns 2+3

    def test_system_messages_ignored(self):
        msgs = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
        ]
        old, recent = partition_messages(msgs, keep_recent=5)
        # Only one turn exists and keep_recent=5, so nothing is old; the
        # system message must not leak into either half.
        assert old == []
        assert recent == [m for m in msgs if m.role != "system"]

    def test_tool_calls_stay_with_assistant(self):
        msgs = [
            Message(role="user", content="1"),
            Message(
                role="assistant",
                content="",
                tool_calls=[ToolCallRequest(id="1", name="fs", arguments={})],
            ),
            Message(role="tool", content="result", name="fs", tool_call_id="1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="ok"),
        ]
        old, recent = partition_messages(msgs, keep_recent=1)
        # The tool result must travel with the assistant message that issued
        # the call, i.e. the whole first turn (user, assistant, tool) is old.
        assert [m.role for m in old] == ["user", "assistant", "tool"]
        assert [m.content for m in recent] == ["2", "ok"]


class TestFormatForSummary:
    """Rendering messages into summary-prompt text plus the images they carry."""

    def test_user_message(self):
        msgs = [Message(role="user", content="hello")]
        text, images = _format_for_summary(msgs)
        assert "User: hello" in text
        assert images == []

    def test_assistant_message(self):
        msgs = [Message(role="assistant", content="world")]
        text, images = _format_for_summary(msgs)
        assert "Assistant: world" in text
        # No images were attached, so none should be collected.
        assert images == []

    def test_tool_call_block(self):
        msgs = [
            Message(
                role="assistant",
                tool_calls=[ToolCallRequest(id="1", name="fs", arguments={"path": "/tmp"})],
            ),
            Message(role="tool", content="file data", name="fs", tool_call_id="1"),
        ]
        text, images = _format_for_summary(msgs)
        assert "[Tool call: fs" in text
        assert "[Tool result: fs" in text
        # Tool calls/results without images must not produce any.
        assert images == []

    def test_images_collected(self):
        msgs = [Message(role="user", content="look", images=["base64img"])]
        text, images = _format_for_summary(msgs)
        assert images == ["base64img"]
        assert "[+ 1 image(s)]" in text

    def test_summary_message_folded(self):
        msgs = [Message(role="user", content="old summary", is_summary=True)]
        text, images = _format_for_summary(msgs)
        assert "old summary" in text
        assert images == []


class TestCompressContext:
    """``compress_context`` runs against a ``FakeLLMBackend`` with canned responses."""

    # NOTE(review): these coroutine tests carry no asyncio marker; presumably the
    # project enables pytest-asyncio auto mode (or an equivalent plugin) in its
    # config -- confirm, otherwise pytest will not run them as async tests.

    async def test_nothing_to_compress(self):
        # A single turn fits within keep_recent=8, so there is nothing to summarize.
        backend = FakeLLMBackend()
        result = await compress_context(
            context=[Message(role="user", content="hi")],
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=8,
        )
        assert result is None

    async def test_compresses_old_turns(self):
        backend = FakeLLMBackend(responses=["Summary of old stuff"])
        context = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
            Message(role="user", content="3"),
            Message(role="assistant", content="a3"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )
        assert summary == "Summary of old stuff"
        # system + summary + 2 recent turns (user+assistant × 2) = 6
        assert len(new_context) == 6
        assert new_context[0].role == "system"
        assert new_context[1].is_summary is True
        # The kept messages must be the two most recent turns, in order --
        # not merely any four messages.
        assert [m.content for m in new_context[2:]] == ["2", "a2", "3", "a3"]

    async def test_preserves_system_messages(self):
        backend = FakeLLMBackend(responses=["sum"])
        context = [
            Message(role="system", content="s1"),
            Message(role="system", content="s2"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
            Message(role="user", content="2"),
            Message(role="assistant", content="a2"),
        ]
        new_context, summary = await compress_context(
            context=context,
            llm=backend,
            model="test",
            temperature=0.3,
            keep_recent=1,
        )
        # One of two turns is old, so a summary must actually have been produced.
        assert summary == "sum"
        system_msgs = [m for m in new_context if m.role == "system"]
        # Both system messages survive, unchanged and in their original order.
        assert [m.content for m in system_msgs] == ["s1", "s2"]