Newer
Older
navi-1 / navi / core / compressor.py
"""
Context compressor — summarizes old messages to stay within the token limit.

Flow:
1. Partition session messages into "to_summarize" (old turns) and "to_keep" (recent turns).
2. Call the LLM to produce a structured summary of the old messages.
3. Replace the old turns with a structured summary message (role=user, is_summary=True).

A "turn" is one user message plus all following assistant/tool messages up to the
next user message. Tool call groups (assistant + tool results) are never split.
Existing summary messages are always folded into the next compression pass.

Compression is profile-aware: AgentProfile can provide compression_keep_recent,
compression_max_tokens, and a compression_prompt_file to specialize summaries.
"""

import json
import re
from datetime import datetime, timezone
from pathlib import Path
from typing import TYPE_CHECKING

from navi.llm.base import LLMBackend, Message
from navi.config import settings
from .events import ContextCompressed

if TYPE_CHECKING:
    from navi.profiles.base import AgentProfile


# Sections of the structured summary as (title, description) pairs. Each pair is
# rendered into _SUMMARY_TEMPLATE_INSTRUCTIONS below as a "## {title}" header
# followed by its description, in this order.
_SUMMARY_SECTIONS = [
    ("Goal", "One clear sentence describing what the user is trying to accomplish in this session. Include deadlines or constraints if stated."),
    ("Active Files", "Every file or directory the assistant touched, with absolute or project-relative path and status: created / modified / read / deleted. For modified files, note the purpose of the change."),
    ("Decisions & User Preferences", "Explicit choices, architecture decisions, style preferences, or corrections stated by the user. Things the user said NOT to do."),
    ("Completed Work", "Concrete finished steps — include file/function names and verification outcome if available."),
    ("Pending Work / Todo", "Open tasks, in-progress items, or follow-ups that still need action."),
    ("Errors & Blockers", "Failures, exceptions, or unresolved issues. Include exact error snippets when short and diagnostic."),
    ("Key Values", "Exact constants the assistant should remember: ports, config keys, versions, dependency names, important paths, IDs."),
]

# Base system prompt for the main summarization call (profile-specific text may be
# appended by _build_summary_system_prompt). It consists of the instructions, the
# section template generated from _SUMMARY_SECTIONS, and the output rules.
# Precedence note: adjacent string literals are concatenated before `+`, so this is
# (intro literal) + (joined sections) + ("\n\n" "Output rules:..." literal).
_SUMMARY_TEMPLATE_INSTRUCTIONS = (
    "You are summarizing a conversation history to free up context space. "
    "The assistant will continue working using ONLY this summary — it will have no access "
    "to the original messages. Be thorough and precise. Prefer specifics over generalities. "
    "This summary is historical context, not a new user request.\n\n"
    "Use EXACTLY the Markdown structure below. Every section must be present. "
    "If a section has no relevant information, write its header and the literal word NONE. "
    "Keep bullet points tight and information-dense. "
    "Do not include greetings, filler, transitions, or meta-commentary.\n\n"
    + "\n\n".join(f"## {title}\n{desc}" for title, desc in _SUMMARY_SECTIONS)
    + "\n\n"
    "Output rules:\n"
    "- Preserve exact file paths, function names, config keys, and short error snippets verbatim.\n"
    "- Do not paraphrase values that must stay precise.\n"
    "- Do not write implementation code, patches, or long command output.\n"
    "- Use Markdown headers exactly as shown."
)


# Tools whose output is often needed again later. _format_for_summary passes their
# results through verbatim up to a larger size, and _turn_importance rewards turns
# that contain them.
_CRITICAL_TOOL_NAMES = frozenset(("filesystem", "code_exec", "terminal", "ssh_exec"))


# Case-insensitive content markers that make a turn worth preserving verbatim longer:
# errors/failures, user corrections and negatives, and reported file/code edits.
_CRITICAL_PATTERNS = [
    re.compile(source, re.IGNORECASE)
    for source in (
        r"\b(error|exception|traceback|failed|failure)\b",
        r"\b(user\s+said|no,\s+|don't\s+|do\s+not\s+|never\s+|instead\s+|wrong\s+|incorrect\s+|fix\s+|correct\s+)\b",
        r"\b(edited|modified|created|deleted|wrote|added)\s+(file|function|class|method)\b",
    )
]


def should_compress(context_tokens: int, max_context_tokens: int, threshold: float) -> bool:
    """Return True once the context has reached ``threshold`` of the context window.

    The trigger point is ``int(max_context_tokens * threshold)`` (truncated toward
    zero); landing exactly on it already counts as "reached".
    """
    trigger_at = int(max_context_tokens * threshold)
    return not context_tokens < trigger_at


# Correction / explicit-negative phrases that mark a turn as important. Matched on
# word boundaries: a plain substring test also fires on words that merely contain
# a keyword ("prefix", "suffix" and "fixture" all contain "fix").
_CORRECTION_PHRASES = re.compile(
    r"\b(?:wrong|incorrect|fix|don't use|do not use|instead use|change to)\b",
    re.IGNORECASE,
)


def _turn_importance(turn: list[Message]) -> int:
    """Score a turn for adaptive keep_recent. Higher = more important to keep.

    Additive signals:
    - +3 once if the turn text (all messages joined) contains a correction phrase
      from `_CORRECTION_PHRASES`;
    - +1 for every match of every regex in `_CRITICAL_PATTERNS` over the turn text;
    - +3 per message with a truthy ``is_compression_critical`` attribute;
    - +1 per tool result whose tool name is in `_CRITICAL_TOOL_NAMES`;
    - -2 per user message whose stripped content is at most 20 characters.

    The total is clamped to be non-negative, so 0 means "filler".
    """
    text = "\n".join((m.content or "") for m in turn)
    score = 0
    # Strong signals: user corrections and explicit negatives.
    if _CORRECTION_PHRASES.search(text):
        score += 3
    for pattern in _CRITICAL_PATTERNS:
        score += len(pattern.findall(text))
    for m in turn:
        if getattr(m, "is_compression_critical", False):
            score += 3
        if m.role == "tool" and m.name in _CRITICAL_TOOL_NAMES:
            score += 1
        if m.role == "user" and len((m.content or "").strip()) <= 20:
            # Very short user messages are usually social/filler; deprioritize
            score -= 2
    return max(0, score)


def partition_messages(
    messages: list[Message],
    keep_recent: int,
    keep_recent_messages: int | None = None,
    adaptive: bool = True,
) -> tuple[list[Message], list[Message]]:
    """
    Returns (to_summarize, to_keep).

    System messages are excluded from both lists (the caller re-adds them). The last
    `keep_recent` conversational turns are kept verbatim; everything older goes into
    to_summarize. Whole turns are moved, so tool call groups (assistant + tool
    results) always stay together.

    If there are no more than `keep_recent` turns and `keep_recent_messages` is
    given, the in-flight turn may be split instead (see
    `partition_current_turn_messages`); otherwise nothing is summarized.

    When adaptive=True, older turns with a positive `_turn_importance` score (user
    corrections, errors, critical tools) are kept in place of zero-score ("filler")
    recent turns. Two kinds of turn are never swapped:
    - existing summary turns: they are always folded into the new summary;
    - the newest turn: it holds the request currently being worked on.
    Kept turns stay in chronological order.

    `keep_recent` values below 1 are treated as 1.
    """
    non_system = [m for m in messages if m.role != "system"]

    # Group into turns: each turn starts with a user message
    turns: list[list[Message]] = []
    current: list[Message] = []
    for msg in non_system:
        if msg.role == "user" and current:
            turns.append(current)
            current = [msg]
        else:
            current.append(msg)
    if current:
        turns.append(current)

    # The newest turn is always kept. Clamping also avoids the `turns[-0:]` slicing
    # pitfall, which made keep_recent=0 treat *every* turn as recent (never compress).
    keep_recent = max(1, keep_recent)

    if len(turns) <= keep_recent:
        if keep_recent_messages is not None:
            intra_turn = partition_current_turn_messages(turns, keep_recent_messages)
            if intra_turn is not None:
                return intra_turn
        return [], non_system  # nothing old enough to compress

    split = len(turns) - keep_recent
    recent_turns = turns[split:]
    old_turns = turns[:split]

    if adaptive:
        # Important old turns that should not be lost. Summary turns are excluded:
        # keeping one verbatim would stack summaries instead of folding them.
        important_old = [
            t
            for t in old_turns
            if not any(getattr(m, "is_summary", False) for m in t)
            and _turn_importance(t) > 0
        ]
        # Filler turns in the recent window that can be swapped out. The newest
        # turn is excluded so the active request is never summarized away.
        filler_recent = [t for t in recent_turns[:-1] if _turn_importance(t) == 0]
        # Pair the oldest recent filler with the most recent important old turn.
        # Identity (not list equality) decides which turns are kept.
        kept_ids = {id(t) for t in recent_turns}
        for filler, important in zip(filler_recent, reversed(important_old)):
            kept_ids.discard(id(filler))
            kept_ids.add(id(important))
        # Rebuild both sides in original order so context stays chronological.
        recent_turns = [t for t in turns if id(t) in kept_ids]
        old_turns = [t for t in turns if id(t) not in kept_ids]

    to_summarize = [m for turn in old_turns for m in turn]
    to_keep = [m for turn in recent_turns for m in turn]
    return to_summarize, to_keep


def partition_current_turn_messages(
    turns: list[list[Message]],
    keep_recent_messages: int,
) -> tuple[list[Message], list[Message]] | None:
    """
    Mid-turn fallback for long autonomous tool loops.

    A long chain of assistant/tool iterations after one user message is one
    conversational turn, so turn-based compression may have nothing to compress.
    Keep the current user request and the newest messages verbatim, then summarize
    older messages from the same in-flight turn (plus every earlier turn).

    The kept tail is the last ``keep_recent_messages`` messages. If it would start
    on a tool result, its start moves back (keeping more) so an assistant tool call
    is never separated from its results.

    Returns (to_summarize, to_keep), or None when the current turn is too short or
    fewer than two messages would be summarized.
    """
    if not turns:
        return None

    current_turn = turns[-1]
    if len(current_turn) <= keep_recent_messages + 1:
        return None

    # The user message that opened the turn (if any) always stays.
    head = [current_turn[0]] if current_turn and current_turn[0].role == "user" else []
    tail_start = max(len(head), len(current_turn) - keep_recent_messages)
    # Never start the tail on a tool result. Doing so would split a tool call group
    # (the assistant call summarized, its result kept), which the module contract
    # forbids; backends may also reject such an orphan tool message.
    while (
        len(head) < tail_start < len(current_turn)
        and current_turn[tail_start].role == "tool"
    ):
        tail_start -= 1

    to_summarize = [m for turn in turns[:-1] for m in turn] + current_turn[len(head):tail_start]
    to_keep = head + current_turn[tail_start:]

    if len(to_summarize) < 2:
        return None

    return to_summarize, to_keep


def _format_for_summary(messages: list[Message]) -> tuple[str, list[str]]:
    """
    Render messages as plain text for the summarization prompt.

    Rendering rules:
    - existing summaries (``is_summary``) are included as-is;
    - user messages become ``User: ...``, with a note when images are attached;
    - assistant text becomes ``Assistant: ...``. Each tool call is shown with a
      120-character JSON preview of its arguments, followed by the tool results that
      immediately follow it. Results from `_CRITICAL_TOOL_NAMES` tools, or flagged
      ``is_compression_critical``, are kept verbatim when at most 4000 characters;
      any other result is cut to 300 characters (plus "…");
    - anything else (orphan tool results, empty assistant messages, other roles)
      is skipped.

    Returns (text, images) where images is a flat list of base64 strings
    collected from all user messages. Vision-capable models will receive
    the images alongside the text; non-vision models silently ignore them.
    """
    lines: list[str] = []
    images: list[str] = []
    i = 0
    while i < len(messages):
        m = messages[i]

        if m.is_summary:
            # Existing summary — include as-is (already compressed)
            lines.append(m.content or "")
            i += 1

        elif m.role == "user":
            if m.images:
                images.extend(m.images)
                img_note = f" [+ {len(m.images)} image(s)]"
            else:
                img_note = ""
            if m.content:
                lines.append(f"User: {m.content}{img_note}")
            elif img_note:
                lines.append(f"User:{img_note}")
            i += 1

        elif m.role == "assistant" and m.tool_calls:
            # Render tool calls + their results as a compact block
            if m.content:
                lines.append(f"Assistant: {m.content}")
            for tc in m.tool_calls:
                # ensure_ascii=False keeps non-ASCII paths/values verbatim. Without
                # it they become \uXXXX escapes, which inflate the preview and can
                # be cut mid-escape at 120 chars. default=str keeps an unusual
                # argument type from aborting the whole compression.
                args_preview = json.dumps(tc.arguments, ensure_ascii=False, default=str)[:120]
                lines.append(f"[Tool call: {tc.name}; arguments preview: {args_preview}]")
            i += 1
            while i < len(messages) and messages[i].role == "tool":
                tool_msg = messages[i]
                result = tool_msg.content or ""
                # Critical tool results and explicit critical flag survive verbatim
                # up to a larger budget, so exact errors/file contents are preserved.
                critical = (
                    getattr(tool_msg, "is_compression_critical", False)
                    or tool_msg.name in _CRITICAL_TOOL_NAMES
                )
                if critical and len(result) <= 4000:
                    preview = result
                else:
                    preview = result[:300] + ("…" if len(result) > 300 else "")
                lines.append(f"[Tool result: {tool_msg.name}; preview: {preview}]")
                i += 1

        elif m.role == "assistant" and m.content:
            lines.append(f"Assistant: {m.content}")
            i += 1

        else:
            i += 1  # skip orphan tool messages

    return "\n".join(lines), images


# Character cap on the rendered transcript handed to the summarizer. This keeps the
# summarization request well short of the model's window, leaving room for output.
_MAX_SUMMARY_INPUT_CHARS = 24000

# If the existing summaries about to be folded exceed this combined length
# (characters), they are first consolidated by a separate meta-summary pass.
# One third of the input cap, i.e. 8,000 characters.
_META_SUMMARY_THRESHOLD = _MAX_SUMMARY_INPUT_CHARS // 3

# System prompt for the meta-summary (summary-of-summaries) pass.
_META_SUMMARY_SYSTEM = " ".join(
    (
        "You are condensing multiple conversation summaries into a single compact summary.",
        "Preserve all key facts, decisions, file paths, config values, errors, and user preferences.",
        "Eliminate redundancy between overlapping summaries.",
        "Write tight bullet points. Do not include filler or meta-commentary.",
    )
)


async def _meta_summarize(
    summaries: list[Message],
    llm: LLMBackend,
    model: "list[str] | str | None",
    temperature: float,
) -> Message:
    """Consolidate multiple existing summary messages into one compact meta-summary.

    The summaries are joined with ``---`` separators and sent to ``llm`` under
    `_META_SUMMARY_SYSTEM` with ``think=False``, ``tools=None`` and
    ``max_tokens=1500``. An empty reply becomes a placeholder text.

    Returns a hidden (``is_display=False``) user-role message flagged
    ``is_summary=True`` and ``is_context=True``. Exceptions from ``llm.complete``
    propagate; the caller in this module treats consolidation as best-effort.
    """
    # `or ""` guards summaries without content, which would otherwise make
    # str.join raise TypeError and abort the consolidation.
    combined = "\n\n---\n\n".join((m.content or "") for m in summaries)
    prompt = [
        Message(role="system", content=_META_SUMMARY_SYSTEM),
        Message(role="user", content=combined),
    ]
    response = await llm.complete(
        prompt,
        tools=None,
        temperature=temperature,
        model=model,
        think=False,
        max_tokens=1500,
    )
    text = (response.content or "").strip() or "(consolidated summary unavailable)"
    return Message(
        role="user",
        content=f"[Consolidated Context Summary]\n{text}",
        is_summary=True,
        is_display=False,
        is_context=True,
        created_at=datetime.now(timezone.utc),
    )


def _build_summary_system_prompt(profile: "AgentProfile | None") -> str:
    """Build the system prompt used by the summarization LLM.

    Starts from `_SUMMARY_TEMPLATE_INSTRUCTIONS`. If ``profile`` names a
    ``compression_prompt_file``, that file is looked up in the profile's directory
    (``<profiles root>/<profile.id>/``). Its stripped text is appended after a
    separator header. Profile roots are tried in order:

    1. ``navi/profiles`` relative to the current working directory. This only
       resolves when the process runs from the repository root.
    2. The ``profiles`` package next to this module (``navi/core/../profiles``).
       This resolves regardless of the working directory.

    Missing, unreadable, undecodable or empty files are skipped. The profile
    addition is best-effort and never blocks compression.
    """
    prompt_file = getattr(profile, "compression_prompt_file", None)
    if not prompt_file:
        return _SUMMARY_TEMPLATE_INSTRUCTIONS

    profile_roots = (
        Path("navi/profiles"),
        Path(__file__).resolve().parent.parent / "profiles",
    )
    for root in profile_roots:
        prompt_path = root / profile.id / prompt_file
        if not prompt_path.is_file():
            continue
        try:
            extra = prompt_path.read_text(encoding="utf-8").strip()
        except (OSError, UnicodeDecodeError):
            continue
        if extra:
            return (
                _SUMMARY_TEMPLATE_INSTRUCTIONS
                + "\n\n---\n\n[Profile-specific compression instructions]\n\n"
                + extra
            )
    return _SUMMARY_TEMPLATE_INSTRUCTIONS


async def compress_context(
    context: list[Message],
    llm: LLMBackend,
    model: "list[str] | str | None",
    temperature: float,
    keep_recent: int,
    max_tokens: int | None = None,
    keep_recent_messages: int | None = None,
    profile: "AgentProfile | None" = None,
) -> tuple[list[Message], str] | None:
    """
    Summarize the older part of an LLM context and return the shortened context.

    Only ``context`` is read; the full display history (session.messages) is never
    touched. The result is ``(system messages + [summary message] + kept messages,
    summary_text)``, or None when fewer than two messages would be summarized.

    Steps:
    1. Partition with `partition_messages`. In mid-turn mode (``keep_recent_messages``
       set and > 2), a second attempt keeps only the 2 newest messages of the
       current turn. That attempt runs if the first yields fewer than two messages.
    2. If more than one existing summary is about to be folded and their combined
       length exceeds `_META_SUMMARY_THRESHOLD`, consolidate them first with
       `_meta_summarize`. This is best-effort: failures are ignored.
    3. Render the old messages, capped at `_MAX_SUMMARY_INPUT_CHARS` characters.
       Ask ``llm`` for a structured summary, attaching images from old user
       messages (vision-capable models can use them; others ignore the field).
    4. Wrap the reply in a hidden user-role message flagged ``is_summary=True``.

    The summarizer is the same ``model`` already in use (profile.model passed via
    WorkerContext), so no model swap or extra loading is involved.

    Profile settings override the arguments when set (truthy):
      - compression_keep_recent -> keep_recent
      - compression_max_tokens -> max_tokens
      - compression_prompt_file -> appended to summary system prompt

    Exceptions propagate to the caller (CompressionWorker catches them).
    """
    turns_to_keep = getattr(profile, "compression_keep_recent", None) or keep_recent
    output_limit = getattr(profile, "compression_max_tokens", None) or max_tokens

    head_system = [msg for msg in context if msg.role == "system"]
    old, recent = partition_messages(
        context,
        turns_to_keep,
        keep_recent_messages=keep_recent_messages,
    )

    # Mid-turn fallback: nothing to fold at the requested granularity, so retry
    # keeping only the 2 newest messages of the current turn.
    if len(old) < 2 and keep_recent_messages is not None and keep_recent_messages > 2:
        old, recent = partition_messages(context, turns_to_keep, keep_recent_messages=2)

    if len(old) < 2:
        return None  # nothing substantial to compress

    # Consolidate several large earlier summaries before they crowd the input.
    earlier_summaries = [msg for msg in old if msg.is_summary]
    needs_consolidation = (
        len(earlier_summaries) > 1
        and sum(len(msg.content or "") for msg in earlier_summaries) > _META_SUMMARY_THRESHOLD
    )
    if needs_consolidation:
        try:
            consolidated = await _meta_summarize(earlier_summaries, llm, model, temperature)
        except Exception:
            consolidated = None  # best-effort: fall back to the raw summaries
        if consolidated is not None:
            old = [consolidated, *(msg for msg in old if not msg.is_summary)]

    transcript, images = _format_for_summary(old)
    if len(transcript) > _MAX_SUMMARY_INPUT_CHARS:
        # Leave the summarizer room to generate its output.
        transcript = transcript[:_MAX_SUMMARY_INPUT_CHARS] + "\n…[truncated]"

    request = [
        Message(role="system", content=_build_summary_system_prompt(profile)),
        Message(role="user", content=transcript, images=images or None),
    ]

    # think=False: compression must be fast — extended reasoning wastes context and hangs
    response = await llm.complete(
        request,
        tools=None,
        temperature=temperature,
        model=model,
        think=False,
        max_tokens=output_limit,
    )
    summary_text = (response.content or "").strip() or "(summary unavailable)"

    summary_msg = Message(
        role="user",
        content="[Context Summary - historical context only, not a new user request]\n" + summary_text,
        is_summary=True,
        is_display=False,
        created_at=datetime.now(timezone.utc),
    )
    return [*head_system, summary_msg, *recent], summary_text


class ContextCompressor:
    """High-level context compression with retry strategy and hard-truncate fallback.

    Thin wrapper around `compress_context` that adds:
    1. Retry with keep_recent + 4 on LLM failure.
    2. Hard-truncate fallback (drop oldest messages without summarizing).

    It also provides a local token estimate and a pre-flight context-size check.
    `compress_and_save_session` applies a compression result to a session object
    and persists it through a session store.
    """

    @staticmethod
    def estimate_context_tokens(context: list[Message]) -> int:
        """Conservative local estimate used before the next LLM call returns real token counts.

        Counts message text plus the JSON-serialized arguments of any tool calls.
        Assistant messages that call tools often have empty ``content`` but carry
        their payload (e.g. a file body to write) in the arguments. Counting only
        ``content`` would undercount them.

        Uses ~3 chars per token (more conservative than the naive 4) because code and
        punctuation are often 1 token per character. Images counted at 500 tokens each
        (rough vision-model estimate).
        """
        chars = 0
        image_count = 0
        for m in context:
            chars += len(m.content or "")
            for tc in getattr(m, "tool_calls", None) or ():
                chars += len(json.dumps(tc.arguments, ensure_ascii=False, default=str))
            # Count every attached image, not just one per message.
            image_count += len(m.images or ())
        return chars // 3 + image_count * 500

    def __init__(self) -> None:
        # Active agent profile: source of compression_* overrides and the prompt file.
        self._profile: "AgentProfile | None" = None

    def set_profile(self, profile: "AgentProfile | None") -> None:
        """Tell the compressor which profile is active so it can use profile-specific settings."""
        self._profile = profile

    async def compress_session(
        self,
        context: list[Message],
        llm: LLMBackend,
        model: "list[str] | str | None",
        temperature: float,
        keep_recent: int,
        max_tokens: int | None = None,
        keep_recent_messages: int | None = None,
    ) -> tuple[list[Message], str] | None:
        """Compress context with retry + hard-truncate fallback.

        Returns (new_context, summary_text) or None if nothing changed.
        Does NOT mutate the session — the caller is responsible for updating
        session.context, session.context_token_count, and persisting.

        The profile's ``compression_keep_recent`` / ``compression_max_tokens``
        override ``keep_recent`` / ``max_tokens`` when set (truthy).
        """
        effective_keep_recent = getattr(self._profile, "compression_keep_recent", None) or keep_recent
        effective_max_tokens = getattr(self._profile, "compression_max_tokens", None) or max_tokens

        # Attempt 1: normal compression
        try:
            return await compress_context(
                context=context,
                llm=llm,
                model=model,
                temperature=temperature,
                keep_recent=effective_keep_recent,
                max_tokens=effective_max_tokens,
                keep_recent_messages=keep_recent_messages,
                profile=self._profile,
            )
        except Exception:
            # Attempt 2: keep more recent turns verbatim.
            # NOTE(review): compress_context re-applies profile.compression_keep_recent
            # over its keep_recent argument. When the profile sets that value, the
            # "+ 4" turn window below is overridden; only keep_recent_messages grows.
            # Confirm whether that is intended.
            try:
                return await compress_context(
                    context=context,
                    llm=llm,
                    model=model,
                    temperature=temperature,
                    keep_recent=effective_keep_recent + 4,
                    max_tokens=effective_max_tokens,
                    keep_recent_messages=(keep_recent_messages + 4)
                    if keep_recent_messages is not None
                    else None,
                    profile=self._profile,
                )
            except Exception:
                # Attempt 3: hard-truncate fallback
                return self._hard_truncate(context)

    def _hard_truncate(
        self, context: list[Message]
    ) -> tuple[list[Message], str] | None:
        """Last-resort fallback: drop oldest non-system messages without summarizing.

        Keeps whole conversational turns so tool call groups are never split.
        Turns are taken from the newest backwards until at least 6 messages and at
        least 2 turns are kept. System messages are always kept.

        Returns None when nothing would be dropped. That happens when the context
        has at most 6 non-system messages, or when the retained turns already
        cover every message (e.g. one very long in-flight turn). In that case,
        reporting a truncation would be misleading.
        """
        system_msgs = [m for m in context if m.role == "system"]
        non_system = [m for m in context if m.role != "system"]

        _HARD_TRUNCATE_KEEP = 6
        if len(non_system) <= _HARD_TRUNCATE_KEEP:
            return None

        # Group non-system into turns (user message starts a turn)
        turns: list[list[Message]] = []
        current: list[Message] = []
        for msg in non_system:
            if msg.role == "user" and current:
                turns.append(current)
                current = [msg]
            else:
                current.append(msg)
        if current:
            turns.append(current)

        # Keep at least 2 recent turns, or enough messages to exceed _HARD_TRUNCATE_KEEP
        kept_turns: list[list[Message]] = []
        kept_count = 0
        for turn in reversed(turns):
            kept_turns.insert(0, turn)
            kept_count += len(turn)
            if kept_count >= _HARD_TRUNCATE_KEEP and len(kept_turns) >= 2:
                break

        to_keep = [m for turn in kept_turns for m in turn]
        if len(to_keep) == len(non_system):
            # Every turn had to be kept: nothing was dropped, so report no change.
            return None

        new_context = system_msgs + to_keep
        summary_text = (
            "[Context was too large to summarize. Old messages were truncated to prevent "
            "the model from exceeding its context window. Some earlier details may have been lost.]"
        )
        return new_context, summary_text

    def check_context_size(self, context: list[Message]) -> None:
        """Raise ContextTooLargeError before an LLM call if the context is dangerously large.

        The estimate from `estimate_context_tokens` is compared with the window
        (``settings.ollama_num_ctx``) minus ``settings.output_reserve_tokens``. The
        error message attributes the overflow to the last message of ``context``
        and reports how much room the earlier messages left for it.
        """
        # NOTE(review): imported lazily, presumably to avoid an import cycle — confirm.
        from navi.exceptions import ContextTooLargeError

        if not context:
            return

        output_reserve = settings.output_reserve_tokens
        total = self.estimate_context_tokens(context)
        available = settings.ollama_num_ctx - output_reserve

        if total > available:
            existing = self.estimate_context_tokens(context[:-1])
            new = self.estimate_context_tokens(context[-1:])
            remaining = available - existing
            raise ContextTooLargeError(
                f"Context too large: new content is ~{new:,} estimated tokens, "
                f"but only ~{max(0, remaining):,} tokens are available "
                f"(window {settings.ollama_num_ctx:,}, already used ~{existing:,}, "
                f"output_reserve {output_reserve:,}). "
                "Split the file into smaller parts or delegate to a subagent."
            )

    async def compress_and_save_session(
        self,
        session,
        session_store,
        llm: LLMBackend,
        model: "list[str] | str | None",
        temperature: float,
        session_id: str,
        reason: str,
        keep_recent: int,
        max_tokens: int | None = None,
        keep_recent_messages: int | None = None,
    ) -> ContextCompressed | None:
        """Compress the session context and persist the changes to the session store.

        Steps:
        1. Compress ``session.context`` via `compress_session`.
        2. Set ``is_context = False`` on every non-system message of
           ``session.messages`` that is no longer in the new context.
        3. Append to ``session.messages`` the new summary message (if any) and a
           ``role="system", is_compression=True`` marker carrying the summary text.
        4. Replace ``session.context``, re-estimate
           ``session.context_token_count`` and save.
        5. If ``settings.session_messages_window`` is positive and exceeded, archive
           older messages and drop them from the in-memory lists.

        ``reason`` is accepted but not used here.

        Returns a ContextCompressed event, or None if nothing was compressed.
        """
        count_before = len(session.context)
        result = await self.compress_session(
            context=session.context,
            llm=llm,
            model=model,
            temperature=temperature,
            keep_recent=keep_recent,
            max_tokens=max_tokens,
            keep_recent_messages=keep_recent_messages,
        )
        if result is None:
            return None

        new_context, summary_text = result

        # Mark messages in session.messages as not context if they are no longer in new_context
        # and are not system messages.
        new_context_ids = {id(m) for m in new_context}
        for msg in session.messages:
            if msg.role != "system" and id(msg) not in new_context_ids:
                msg.is_context = False

        # Add the summary message to session.messages if it's not already there
        summary_msg = next((m for m in new_context if m.is_summary), None)
        if summary_msg and summary_msg not in session.messages:
            summary_msg.is_display = False
            session.messages.append(summary_msg)

        # Add a system message with is_compression=True and content=summary_text
        session.messages.append(
            Message(
                role="system",
                content=summary_text,
                is_compression=True,
                is_context=False,
                created_at=datetime.now(timezone.utc),
            )
        )

        session.context = new_context
        session.context_token_count = self.estimate_context_tokens(new_context)
        await session_store.save(session)

        # Archive old messages if the hot table exceeds the configured window.
        if settings.session_messages_window > 0 and session.db_next_sequence > settings.session_messages_window:
            threshold = session.db_next_sequence - settings.session_messages_window
            archived = await session_store.archive_old_messages(session_id, threshold)
            if archived > 0:
                session.messages = [m for m in session.messages if m.sequence_number >= threshold]
                # NOTE(review): this filter also removes system messages and the new
                # summary from the context if their sequence_number is below threshold,
                # and raises if it is None. Confirm they always carry current sequence
                # numbers. context_token_count is not re-estimated after this filter.
                session.context = [m for m in session.context if m.sequence_number >= threshold]
                session.archive_threshold = threshold

        return ContextCompressed(
            messages_before=count_before,
            messages_after=len(new_context),
            summary=summary_text,
            context_tokens=session.context_token_count,
            max_context_tokens=settings.ollama_num_ctx,
        )