# Newer
# Older
# navi-1 / navi / memory / extractor.py
# @Eugene Sukhodolskiy Eugene Sukhodolskiy on 12 May 8 KB Clarify knowledge persistence prompts
"""
Fact extraction and summary generation for the memory system.

Flow (triggered when a session is considered complete):
1. Format session.messages as plain text
2. Ask LLM to extract stable facts about the user → upsert into memory_facts
3. If new facts were found → regenerate summary from all facts
"""

import json
import structlog

from navi.llm.base import LLMBackend, Message

from .store import MemoryStore

# Module-level structlog logger; the functions in this module emit their
# "memory.*" events through it.
log = structlog.get_logger()

# System prompt for the fact-extraction LLM call made by _extract_facts().
# It describes the four transcript entry kinds that _extract_facts() renders
# ("User:", "Assistant:", "[Tool call]", "[Tool result]") and asks the model
# for a bare JSON array of objects with the fields category, key, value,
# source and source_context.
#
# NOTE(review): the "Extract" list asks for "Technical environment: OS, tools,
# servers, devices, IPs, running services", while the "Do NOT extract" list
# excludes "Infrastructure inventory, service topology, network routes, server
# facts" — these two instructions appear to overlap; confirm which is intended.
# NOTE(review): the prompt only names the sources "conversation" and
# "tool_call"; the parser in _extract_facts() additionally accepts
# "auto_discovery" and "user_explicit".
_EXTRACT_SYSTEM = """\
You extract stable facts about the user from a session transcript.

The transcript contains four types of entries:
1. User messages — what the user explicitly said
2. Assistant messages — Navi's own responses
3. Tool calls — which tools Navi executed and their arguments
4. Tool results — the raw output those tools returned (may be truncated)

Facts discovered through tool results are MORE RELIABLE than conversation.
If a tool result contradicts something said in chat — trust the tool result.

Extract facts that are:
- Persistent characteristics: name, age, location, occupation, family situation
- Technical environment: OS, tools, servers, devices, IPs, running services
- Preferences: communication style, coding habits, likes/dislikes, workflow patterns
- Ongoing projects or goals

Do NOT extract:
- Topics that were discussed or questions that were asked
- Temporary states ("was tired", "was busy today")
- Information about third parties that isn't about the user
- Directory-specific project notes, one-off commands, file paths, task progress
- Infrastructure inventory, service topology, network routes, server facts, or local operational notes
- Already-known facts that appear in the transcript

For each fact, indicate its source:
- "conversation" — the user explicitly stated it or it was inferred from chat
- "tool_call" — discovered through a tool execution result

Return ONLY a valid JSON array. No markdown, no prose, no comments.
Return empty [] if nothing new should be extracted.
Schema:
[
  {"category": "profile", "key": "name", "value": "Eugene", "source": "conversation", "source_context": "user introduced themselves"},
  {"category": "preferences", "key": "prefers_dark_ui", "value": "true", "source": "conversation", "source_context": "user asked for dark UI"}
]

Valid categories: profile, preferences, technical, projects, other"""

# System prompt for the summary LLM call made by _regenerate_summary().
# The accompanying user message is the stored fact list rendered as
# "[category]" headers followed by indented "key: value" lines.
_SUMMARY_SYSTEM = """\
You are writing a memory summary for an AI assistant about its user.
Summarize the facts below in 2-4 short paragraphs (max 400 words).
Write from the assistant's perspective: what you know about the user.
Be specific and concrete. Cover the most important identifying details first,
then preferences and ongoing context.
Do not add facts not present below. Do not include task progress, local directory notes,
one-off commands, infrastructure inventory, service topology, network routes, or server facts."""


async def extract_and_update(
    session,
    llm: LLMBackend,
    model: str,
    memory_store: MemoryStore,
) -> None:
    """
    Run fact extraction for one session, then refresh the owner's summary.

    Sessions whose ``user_id`` attribute is missing or ``None`` (legacy,
    unowned sessions) are ignored entirely. For every other session the
    extraction is attempted, the session is marked as extracted in the store
    regardless of how many facts were stored (including when extraction
    stored nothing), and the summary is regenerated only when at least one
    fact was upserted.

    NOTE(review): repeated calls are presumably harmless because facts are
    upserted and/or callers consult the "extracted" mark — neither is
    enforced in this function; confirm against MemoryStore and the caller.
    """
    owner_id = getattr(session, "user_id", None)
    if owner_id is None:
        # Unowned session: there is no per-user memory to write into.
        return

    stored = await _extract_facts(session, llm, model, memory_store)
    log.info("memory.extracted", session_id=session.id, facts_added=stored)

    # Marked unconditionally, i.e. also when nothing was stored.
    await memory_store.mark_session_extracted(session.id)

    if stored <= 0:
        # Nothing was upserted, so the existing summary is left untouched.
        return
    await _regenerate_summary(llm, model, memory_store, user_id=owner_id)


# Tool results longer than this many characters are cut to this length and
# suffixed with " ... [truncated]" when the transcript is built.
_MAX_TOOL_RESULT_LEN = 500
# Transcripts longer than this many characters keep only the first and last
# halves of this budget, joined by a "[transcript truncated]" marker.
_MAX_TRANSCRIPT_CHARS = 12_000


# Categories a stored fact may carry; anything else is filed under "other".
_FACT_CATEGORIES = frozenset({"profile", "preferences", "technical", "projects", "other"})
# Sources accepted from the model; anything else falls back to "conversation".
_FACT_SOURCES = frozenset({"conversation", "tool_call", "auto_discovery", "user_explicit"})
# Confidence assigned per source; sources not listed here get the default.
_SOURCE_CONFIDENCE = {"tool_call": 95, "auto_discovery": 95, "user_explicit": 90}
_DEFAULT_CONFIDENCE = 70


def _fact_text(raw) -> str:
    """Render one decoded JSON field of a fact as stripped text.

    ``None`` (a missing field or JSON ``null``) becomes ``""`` so it is treated
    as absent instead of being stored as the literal string ``"None"``.
    Strings are stripped. Any other JSON value (bool, number, list, object) is
    rendered in JSON notation, e.g. ``True`` -> ``"true"``, matching the
    string form shown in the extraction prompt's schema.
    """
    if raw is None:
        return ""
    if isinstance(raw, str):
        return raw.strip()
    return json.dumps(raw, ensure_ascii=False)


def _format_transcript(messages) -> str:
    """Render session messages as the plain-text transcript sent to the LLM.

    Returns ``""`` when no message produced a transcript line. Long tool
    results and over-long transcripts are shortened using the module-level
    size limits.
    """
    # Map tool_call_id -> tool_name so tool results can be labelled. This is
    # built in a separate pass so a result is labelled wherever its call sits.
    tool_names: dict[str, str] = {}
    for msg in messages:
        if msg.role == "assistant" and msg.tool_calls:
            for call in msg.tool_calls:
                tool_names[call.id] = call.name

    lines: list[str] = []
    for msg in messages:
        if msg.role == "user" and msg.content:
            lines.append(f"User: {msg.content}")
        elif msg.role == "assistant":
            if msg.content:
                lines.append(f"Assistant: {msg.content}")
            if msg.tool_calls:
                for call in msg.tool_calls:
                    # Arguments are clipped so one large call cannot dominate.
                    args = str(call.arguments)[:200]
                    lines.append(f"[Tool call] {call.name}({args})")
        elif msg.role == "tool" and msg.content:
            tool_name = tool_names.get(msg.tool_call_id or "", "unknown")
            content = msg.content
            if len(content) > _MAX_TOOL_RESULT_LEN:
                content = content[:_MAX_TOOL_RESULT_LEN] + " ... [truncated]"
            lines.append(f"[Tool result] {tool_name}: {content}")

    transcript = "\n".join(lines)
    if len(transcript) > _MAX_TRANSCRIPT_CHARS:
        # Keep early context + recent tail; drop the middle
        half = _MAX_TRANSCRIPT_CHARS // 2
        transcript = transcript[:half] + "\n\n... [transcript truncated] ...\n\n" + transcript[-half:]
    return transcript


def _normalize_fact(fact) -> dict | None:
    """Validate one decoded fact and return its normalized fields.

    Returns ``None`` when ``fact`` is not a JSON object or when its key or
    value is empty after normalization. Otherwise returns a dict with
    ``category``, ``key``, ``value``, ``source``, ``confidence`` and
    ``source_context``.
    """
    if not isinstance(fact, dict):
        return None

    key = _fact_text(fact.get("key"))
    value = _fact_text(fact.get("value"))
    if not (key and value):
        return None

    category = _fact_text(fact.get("category", "other")).lower()
    if category not in _FACT_CATEGORIES:
        category = "other"

    source = _fact_text(fact.get("source", "conversation")).lower()
    if source not in _FACT_SOURCES:
        source = "conversation"

    return {
        "category": category,
        "key": key,
        "value": value,
        "source": source,
        # Confidence mapping based on source reliability
        "confidence": _SOURCE_CONFIDENCE.get(source, _DEFAULT_CONFIDENCE),
        "source_context": _fact_text(fact.get("source_context")),
    }


async def _extract_facts(session, llm: LLMBackend, model: str, store: MemoryStore) -> int:
    """Ask the LLM for facts in ``session`` and upsert them into ``store``.

    The session messages are rendered to a text transcript, sent to the model
    together with ``_EXTRACT_SYSTEM``, and the first ``[`` .. last ``]`` span
    of the reply is parsed as a JSON array of fact objects.

    Returns the number of facts passed to ``store.upsert_fact``. Returns 0
    (after logging a warning where applicable) when the transcript is empty,
    the LLM call raises, the reply contains no JSON array, or the array cannot
    be parsed. Exceptions raised by ``store.upsert_fact`` propagate.
    """
    user_id = getattr(session, "user_id", None)

    transcript = _format_transcript(session.messages)
    if not transcript:
        return 0

    prompt = [
        Message(role="system", content=_EXTRACT_SYSTEM),
        Message(role="user", content=transcript),
    ]

    # Best-effort: a failing LLM call must not break the caller's flow.
    try:
        response = await llm.complete(prompt, tools=None, temperature=0.1, model=model)
        raw = (response.content or "").strip()
    except Exception:
        log.warning("memory.extract_llm_error", session_id=session.id, exc_info=True)
        return 0

    # Find JSON array in response (model may add surrounding text)
    start = raw.find("[")
    end = raw.rfind("]") + 1
    if start == -1 or end == 0:
        return 0

    try:
        facts = json.loads(raw[start:end])
    except json.JSONDecodeError:
        log.warning("memory.extract_parse_error", session_id=session.id, raw=raw[:300])
        return 0

    count = 0
    for fact in facts:
        fields = _normalize_fact(fact)
        if fields is None:
            continue
        # TODO: Semantic deduplication — before upsert, search for semantically
        # similar facts and merge/update instead of creating duplicates.
        # Problem: LLM generates different keys for the same fact across sessions.
        # Solution: vector search + similarity threshold before upsert.
        await store.upsert_fact(
            category=fields["category"],
            key=fields["key"],
            value=fields["value"],
            user_id=user_id,
            source_session_id=session.id,
            source=fields["source"],
            confidence=fields["confidence"],
            source_context=fields["source_context"],
        )
        count += 1

    return count


async def _regenerate_summary(llm: LLMBackend, model: str, store: MemoryStore, user_id: str | None = None) -> None:
    """Rebuild the stored memory summary from every fact of ``user_id``.

    All facts are read from ``store``, rendered as "[category]" sections of
    "key: value" lines, and sent to the LLM with ``_SUMMARY_SYSTEM``. A
    non-empty reply is saved via ``store.set_summary``. Nothing is written
    when there are no facts, the LLM call raises (a warning is logged), or
    the reply is empty.
    """
    all_facts = await store.get_all_facts(user_id=user_id)
    if not all_facts:
        return

    # Bucket facts per category. Categories are emitted alphabetically; inside
    # a category the facts keep the order in which the store returned them.
    grouped: dict[str, list] = {}
    for fact in all_facts:
        bucket = grouped.setdefault(fact["category"], [])
        bucket.append(fact)

    rendered: list[str] = []
    for category in sorted(grouped):
        rendered.append(f"[{category}]")
        rendered.extend(f"  {fact['key']}: {fact['value']}" for fact in grouped[category])
    fact_listing = "\n".join(rendered)

    messages = [
        Message(role="system", content=_SUMMARY_SYSTEM),
        Message(role="user", content=fact_listing),
    ]

    # Best-effort: on an LLM failure the previous summary is left as it is.
    try:
        response = await llm.complete(messages, tools=None, temperature=0.3, model=model)
        text = (response.content or "").strip()
    except Exception:
        log.warning("memory.summary_llm_error", exc_info=True)
        return

    if not text:
        return
    await store.set_summary(text, user_id=user_id)
    log.info("memory.summary_updated", fact_count=len(all_facts), user_id=user_id)