"""Unit tests for context compressor."""

import pytest

from navi.core.compressor import (
    _format_for_summary,
    compress_context,
    partition_messages,
    should_compress,
)
from navi.llm.base import Message, ToolCallRequest
from tests.conftest_factory import FakeLLMBackend


class TestShouldCompress:
    """Checks the ``should_compress`` decision around its threshold.

    Every case passes 1000 as the second argument and 0.7 as the third, and
    only varies the first one. NOTE(review): the arguments look like
    (used, limit, ratio), which would put the boundary at 700 -- the parameter
    names are not visible from this file, so confirm against the compressor.
    """

    def test_below_threshold(self):
        """A first argument of 100 must not trigger compression."""
        decision = should_compress(100, 1000, 0.7)
        assert decision is False

    def test_at_threshold(self):
        """A first argument of exactly 700 already triggers compression."""
        decision = should_compress(700, 1000, 0.7)
        assert decision is True

    def test_above_threshold(self):
        """A first argument of 800 triggers compression."""
        decision = should_compress(800, 1000, 0.7)
        assert decision is True


class TestPartitionMessages:
    """Checks how ``partition_messages`` splits a history into old and recent parts."""

    def test_empty(self):
        """With no messages at all, both partitions come back empty."""
        older, kept = partition_messages([], keep_recent=2)
        assert older == []
        assert kept == []

    def test_fewer_turns_than_keep(self):
        """One user/assistant exchange with ``keep_recent=5`` stays entirely recent."""
        greeting = Message(role="user", content="hi")
        reply = Message(role="assistant", content="hello")
        history = [greeting, reply]

        older, kept = partition_messages(history, keep_recent=5)

        assert older == []
        assert kept == history

    def test_splits_into_old_and_recent(self):
        """Three user/assistant exchanges with ``keep_recent=2`` split 2 / 4."""
        history = []
        for label in ("1", "2", "3"):
            history.append(Message(role="user", content=label))
            history.append(Message(role="assistant", content=f"a{label}"))

        older, kept = partition_messages(history, keep_recent=2)

        assert len(older) == 2  # turn 1
        assert len(kept) == 4  # turns 2+3

    def test_system_messages_ignored(self):
        """The system message must not show up in the recent partition."""
        history = [
            Message(role="system", content="sys"),
            Message(role="user", content="1"),
            Message(role="assistant", content="a1"),
        ]

        _, kept = partition_messages(history, keep_recent=5)

        # Computed after the call, exactly like the comparison it replaces.
        without_system = [msg for msg in history if msg.role != "system"]
        assert kept == without_system

    def test_tool_calls_stay_with_assistant(self):
        """A tool call and its result are partitioned together with their turn."""
        first_question = Message(role="user", content="1")
        tool_request = Message(
            role="assistant",
            content="",
            tool_calls=[ToolCallRequest(id="1", name="fs", arguments={})],
        )
        tool_reply = Message(role="tool", content="result", name="fs", tool_call_id="1")
        second_question = Message(role="user", content="2")
        final_answer = Message(role="assistant", content="ok")
        history = [first_question, tool_request, tool_reply, second_question, final_answer]

        older, kept = partition_messages(history, keep_recent=1)

        # First turn (user + assistant tool call + tool result) is old;
        # the second user/assistant exchange is kept.
        assert len(older) == 3
        assert len(kept) == 2

    def test_mid_turn_fallback_keeps_user_and_recent_messages(self):
        """A single long turn is still split when ``keep_recent_messages`` is given."""
        history = [Message(role="user", content="build a model")]
        for i in range(5):
            call_id = str(i)
            history += [
                Message(
                    role="assistant",
                    tool_calls=[ToolCallRequest(id=call_id, name="fs", arguments={})],
                ),
                Message(role="tool", content=f"result {i}", name="fs", tool_call_id=call_id),
            ]

        older, kept = partition_messages(history, keep_recent=8, keep_recent_messages=4)

        assert older
        opening = kept[0]
        assert opening.role == "user"
        assert opening.content == "build a model"
        assert len(kept) == 5  # original user message + 4 newest in-flight messages


class TestFormatForSummary:
    """Checks the text and image list produced by ``_format_for_summary``."""

    def test_user_message(self):
        """A plain user message is rendered with a ``User:`` prefix and no images."""
        rendered, collected = _format_for_summary([Message(role="user", content="hello")])
        assert "User: hello" in rendered
        assert collected == []

    def test_assistant_message(self):
        """An assistant message is rendered with an ``Assistant:`` prefix."""
        rendered, _ = _format_for_summary([Message(role="assistant", content="world")])
        assert "Assistant: world" in rendered

    def test_tool_call_block(self):
        """Both the tool call and its result leave a marker in the text."""
        request = Message(
            role="assistant",
            tool_calls=[ToolCallRequest(id="1", name="fs", arguments={"path": "/tmp"})],
        )
        reply = Message(role="tool", content="file data", name="fs", tool_call_id="1")

        rendered, _ = _format_for_summary([request, reply])

        for marker in ("[Tool call: fs", "[Tool result: fs"):
            assert marker in rendered

    def test_images_collected(self):
        """Attached images are returned and counted in the text."""
        with_image = Message(role="user", content="look", images=["base64img"])

        rendered, collected = _format_for_summary([with_image])

        assert collected == ["base64img"]
        assert "[+ 1 image(s)]" in rendered

    def test_summary_message_folded(self):
        """The content of an earlier summary message is carried into the text."""
        earlier_summary = Message(role="user", content="old summary", is_summary=True)
        rendered, _ = _format_for_summary([earlier_summary])
        assert "old summary" in rendered


class TestCompressContext:
    """Checks ``compress_context`` end to end against a fake LLM backend.

    NOTE(review): these coroutine tests carry no explicit asyncio marker;
    presumably the project's pytest configuration runs them (for example an
    automatic asyncio mode) -- confirm before moving them elsewhere.
    """

    async def test_nothing_to_compress(self):
        """A lone user message yields ``None`` instead of a new context."""
        lone_message = [Message(role="user", content="hi")]
        outcome = await compress_context(
            context=lone_message,
            llm=FakeLLMBackend(),
            model="test",
            temperature=0.3,
            keep_recent=8,
        )
        assert outcome is None

    async def test_compresses_old_turns(self):
        """The oldest of three turns is replaced by the backend's summary."""
        fake_llm = FakeLLMBackend(responses=["Summary of old stuff"])
        history = [Message(role="system", content="sys")]
        for label in ("1", "2", "3"):
            history.append(Message(role="user", content=label))
            history.append(Message(role="assistant", content=f"a{label}"))

        compressed, summary_text = await compress_context(
            context=history,
            llm=fake_llm,
            model="test",
            temperature=0.3,
            keep_recent=2,
        )

        assert summary_text == "Summary of old stuff"
        # system + summary + 2 recent turns (user+assistant × 2) = 6
        assert len(compressed) == 6
        leading, folded = compressed[0], compressed[1]
        assert leading.role == "system"
        assert folded.is_summary is True

    async def test_preserves_system_messages(self):
        """Both system messages are still present after compression."""
        fake_llm = FakeLLMBackend(responses=["sum"])
        history = [Message(role="system", content=text) for text in ("s1", "s2")]
        for label in ("1", "2"):
            history.append(Message(role="user", content=label))
            history.append(Message(role="assistant", content=f"a{label}"))

        compressed, _ = await compress_context(
            context=history,
            llm=fake_llm,
            model="test",
            temperature=0.3,
            keep_recent=1,
        )

        system_count = sum(1 for msg in compressed if msg.role == "system")
        assert system_count == 2

    async def test_compresses_long_current_turn_when_requested(self):
        """With ``keep_recent_messages`` set, one long in-flight turn is compressed."""
        fake_llm = FakeLLMBackend(responses=["mid-turn summary"])
        history = [Message(role="user", content="build a model")]
        for i in range(5):
            call_id = str(i)
            history += [
                Message(
                    role="assistant",
                    tool_calls=[ToolCallRequest(id=call_id, name="fs", arguments={})],
                ),
                Message(
                    role="tool",
                    content=f"large result {i}",
                    name="fs",
                    tool_call_id=call_id,
                ),
            ]

        compressed, summary_text = await compress_context(
            context=history,
            llm=fake_llm,
            model="test",
            temperature=0.3,
            keep_recent=8,
            keep_recent_messages=4,
        )

        assert summary_text == "mid-turn summary"
        folded, request = compressed[0], compressed[1]
        assert folded.is_summary is True
        assert request.role == "user"
        assert request.content == "build a model"
        assert len(compressed) == 6  # summary + user + 4 recent messages
