Newer
Older
navi-1 / navi / memory / store.py
"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support."""

import asyncio
import math
import re
import uuid
from datetime import datetime, timezone, timedelta
from typing import TYPE_CHECKING

import asyncpg
import structlog

from navi.config import settings

if TYPE_CHECKING:
    from navi.llm.base import LLMBackend

# Characters that split words inside a search query: hyphen, underscore,
# forward slash, backslash and dot.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Any character that is neither a (Unicode) word character nor whitespace.
_NOISE = re.compile(r"[^\w\s]", flags=re.UNICODE)

# While the facts table holds at most this many rows, text search skips
# filtering and returns every fact.
_AUTO_DUMP_THRESHOLD: int = 60
# Largest pgvector cosine distance (the `<=>` operator) still accepted as a
# vector-search hit.
_VECTOR_CUTOFF_DISTANCE: float = 0.3

log = structlog.get_logger()


def _normalize_query(query: str) -> list[str]:
    """Split *query* into lowercase search terms.

    Word separators (hyphen, underscore, slash, backslash, dot) and every
    other character that is neither a word character nor whitespace are
    turned into spaces; tokens shorter than two characters are discarded.
    """
    spaced = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    tokens = spaced.lower().split()
    return [tok for tok in tokens if len(tok) >= 2]


def _build_ddl(pgvector_available: bool) -> list[str]:
    """Return the schema DDL statements, in execution order.

    Every statement is idempotent, so the list can be run on each startup.
    When *pgvector_available* is true, ``memory_facts`` gets an ``embedding``
    column sized by ``settings.embedding_dimensions`` plus an HNSW cosine
    index. The column is additionally added with ``ALTER TABLE ... ADD COLUMN
    IF NOT EXISTS`` because ``CREATE TABLE IF NOT EXISTS`` leaves an existing
    table untouched: a table created before pgvector was installed would
    otherwise never gain the column, the index creation would fail, and every
    embedding write would error.
    """
    embedding_col = (
        f"embedding vector({settings.embedding_dimensions}),"
        if pgvector_available else ""
    )
    stmts = [
        f"""CREATE TABLE IF NOT EXISTS memory_facts (
            id          TEXT PRIMARY KEY,
            category    TEXT NOT NULL,
            key         TEXT NOT NULL,
            value       TEXT NOT NULL,
            created_at  TIMESTAMPTZ NOT NULL,
            updated_at  TIMESTAMPTZ NOT NULL,
            source_session_id TEXT,
            {embedding_col}
            source      TEXT NOT NULL DEFAULT 'conversation',
            confidence  SMALLINT NOT NULL DEFAULT 70,
            expires_at  TIMESTAMPTZ,
            last_verified_at TIMESTAMPTZ,
            source_context TEXT,
            UNIQUE(category, key)
        )""",
    ]
    if pgvector_available:
        stmts += [
            # Must run before the index below, which needs the column.
            "ALTER TABLE memory_facts ADD COLUMN IF NOT EXISTS embedding "
            f"vector({settings.embedding_dimensions})",
            "CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding "
            "ON memory_facts USING hnsw (embedding vector_cosine_ops)",
        ]
    stmts += [
        "CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)",
        # GIN trigram indexes — only if pg_trgm extension is already installed.
        # CREATE EXTENSION requires superuser/CREATE privilege, so we skip it here.
        """DO $$
        BEGIN
            IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN
                CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin (category gin_trgm_ops);
                CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops);
                CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops);
            END IF;
        END $$;""",
        """CREATE TABLE IF NOT EXISTS memory_summary (
            id          INTEGER PRIMARY KEY DEFAULT 1,
            content     TEXT NOT NULL,
            generated_at TIMESTAMPTZ NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS session_memory_state (
            session_id  TEXT PRIMARY KEY,
            extracted_at TIMESTAMPTZ NOT NULL
        )""",
    ]
    return stmts


def _vector_to_str(vec: list[float]) -> str | None:
    """Serialize a vector to the PostgreSQL vector literal format.

    Returns None if the vector contains NaN, Infinity, or is empty,
    since pgvector rejects those values.
    """
    if not vec or not all(math.isfinite(v) for v in vec):
        return None
    return "[" + ",".join(str(v) for v in vec) + "]"


class MemoryStore:
    """PostgreSQL-backed store for user facts, a cached summary and session extraction state.

    The asyncpg pool and the schema are created lazily on first use. When the
    pgvector extension is installed, the ``embedding`` column exists and an
    embedding backend is configured, fact values are embedded on write and
    ``search_facts`` ranks by cosine distance; otherwise search falls back to
    ILIKE text matching.
    """

    def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None:
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes the one-time pool creation and schema setup in _get_pool.
        self._lock = asyncio.Lock()
        self._embedding_backend = embedding_backend
        # Whether vectors can be stored/searched (extension installed AND the
        # embedding column exists). Filled in by _get_pool during setup.
        self._pgvector_checked = False
        self._pgvector_available = False

    def set_embedding_backend(self, backend: "LLMBackend | None") -> None:
        """Replace the backend used to embed facts; pass None to disable embeddings."""
        self._embedding_backend = backend

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on the first call.

        Individual DDL statements that fail are logged and skipped so optional
        features (pgvector, pg_trgm) never block startup. If setup is aborted
        by an exception, the new pool is terminated instead of leaked.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        extension_ready = await self._ensure_pgvector(conn)
                        for stmt in _build_ddl(extension_ready):
                            try:
                                await conn.execute(stmt)
                            except Exception:
                                log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True)
                        # The column can be missing even with the extension
                        # installed (e.g. its DDL failed), so require both.
                        vectors_ready = extension_ready and await self._embedding_column_exists(conn)
                except BaseException:
                    pool.terminate()
                    raise
                self._pgvector_available = vectors_ready
                self._pgvector_checked = True
                self._pool = pool
        return self._pool

    @staticmethod
    async def _ensure_pgvector(conn: asyncpg.Connection) -> bool:
        """Try to install pgvector, then report whether it is installed.

        Installing and probing are separate steps: lacking the privilege to
        run CREATE EXTENSION must not hide an extension that is already there.
        """
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        except Exception:
            log.info("memory.pgvector_create_failed", exc_info=True)
        try:
            installed = bool(
                await conn.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
            )
        except Exception:
            log.warning("memory.pgvector_probe_failed", exc_info=True)
            installed = False
        if not installed:
            log.warning("memory.pgvector_not_available")
        return installed

    @staticmethod
    async def _embedding_column_exists(conn: asyncpg.Connection) -> bool:
        """Report whether ``memory_facts`` (resolved via search_path) has an ``embedding`` column."""
        try:
            row = await conn.fetchval(
                "SELECT 1 FROM pg_attribute "
                "WHERE attrelid = to_regclass('memory_facts') "
                "AND attname = 'embedding' AND NOT attisdropped"
            )
        except Exception:
            log.warning("memory.embedding_column_probe_failed", exc_info=True)
            return False
        return bool(row)

    async def _has_pgvector(self) -> bool:
        """Return whether vector embeddings can be stored and searched.

        The answer is computed once during pool setup. If the pool cannot be
        created, return False WITHOUT caching it, so a transient connection
        failure does not disable vector search for the life of the process.
        """
        if self._pgvector_checked:
            return self._pgvector_available
        try:
            await self._get_pool()
        except Exception:
            log.warning("memory.pgvector_check_failed", exc_info=True)
            return False
        return self._pgvector_available

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed a single text; return None when embeddings are unavailable or fail."""
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]:
        """Embed *texts* in one backend call, returning exactly one entry per input.

        An entry is None when that text got no vector. If the backend fails,
        or returns a different number of vectors than inputs, every entry is
        None because the vectors cannot be safely paired with their texts.
        """
        if not texts:
            return []
        missing: list[list[float] | None] = [None] * len(texts)
        if not self._embedding_backend or not await self._has_pgvector():
            return missing
        try:
            vectors = list(
                await self._embedding_backend.embed(
                    texts=texts,
                    model=settings.embedding_model,
                )
                or []
            )
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return missing
        if len(vectors) != len(texts):
            log.warning("memory.embed_batch_mismatch", expected=len(texts), got=len(vectors))
            return missing
        return [v if v else None for v in vectors]

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Embed stored facts whose ``embedding`` is NULL; return the number of rows updated.

        Works through the table *batch_size* rows at a time and sleeps 2s after
        each full batch to rate-limit remote embedding endpoints. Facts that
        cannot be embedded during this run (backend error, empty or non-finite
        vector) are skipped instead of being re-selected, so the loop always
        terminates. Returns 0 immediately when no embedding backend is set or
        vector storage is unavailable.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return 0
        pool = await self._get_pool()
        updated = 0
        skipped: list[str] = []  # ids that failed this run; excluded from later batches
        while True:
            # A connection is taken per statement so none is held while
            # waiting on the embedding backend or sleeping.
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT id, value FROM memory_facts "
                    "WHERE embedding IS NULL AND id <> ALL($2::text[]) "
                    "LIMIT $1",
                    batch_size, skipped,
                )
            if not rows:
                break
            embeddings = await self._generate_embeddings([r["value"] for r in rows])
            args: list[tuple[str, str]] = []
            for row, emb in zip(rows, embeddings):
                vec_str = _vector_to_str(emb) if emb else None
                if vec_str:
                    args.append((vec_str, row["id"]))
                else:
                    skipped.append(row["id"])
            if args:
                async with pool.acquire() as conn:
                    await conn.executemany(
                        "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                        args,
                    )
                updated += len(args)
            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            if len(rows) == batch_size:
                await asyncio.sleep(2)
        return updated

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        """Insert a fact, or update the existing fact with the same (category, key).

        On conflict every descriptive column is overwritten, while ``id`` and
        ``created_at`` keep their original values. The value is embedded when
        possible. If it cannot be embedded and the value changed, the stored
        embedding is cleared (NULL), so vector search cannot keep matching the
        old text and ``backfill_embeddings`` can recompute it later.
        """
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        vec_str = _vector_to_str(embedding) if embedding else None
        pool = await self._get_pool()

        columns = [
            "id", "category", "key", "value", "created_at", "updated_at",
            "source_session_id", "source", "confidence", "expires_at", "source_context",
        ]
        params: list = [
            str(uuid.uuid4()), category, key, value, now, now,
            source_session_id, source, confidence, expires_at, source_context,
        ]
        placeholders = [f"${i}" for i in range(1, len(params) + 1)]
        updates = [
            f"{col} = EXCLUDED.{col}"
            for col in (
                "value", "updated_at", "source_session_id", "source",
                "confidence", "expires_at", "source_context",
            )
        ]
        if vec_str:
            columns.append("embedding")
            params.append(vec_str)
            placeholders.append(f"${len(params)}::vector")
            updates.append("embedding = EXCLUDED.embedding")
        elif self._pgvector_available:
            # CASE without ELSE yields NULL: keep the old vector only when the
            # value is unchanged (memory_facts.* refers to the existing row).
            updates.append(
                "embedding = CASE WHEN memory_facts.value = EXCLUDED.value "
                "THEN memory_facts.embedding END"
            )
        sql = (
            f"INSERT INTO memory_facts ({', '.join(columns)}) "
            f"VALUES ({', '.join(placeholders)}) "
            f"ON CONFLICT (category, key) DO UPDATE SET {', '.join(updates)}"
        )
        async with pool.acquire() as conn:
            await conn.execute(sql, *params)

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to *limit* unexpired facts relevant to *query*.

        Vector search is tried first; only hits whose cosine distance is below
        ``_VECTOR_CUTOFF_DISTANCE`` count. If it is unavailable, fails or finds
        nothing, fall back to text matching (see ``_text_search``).
        """
        hits = await self._vector_search(query, limit)
        if hits:
            return hits
        return await self._text_search(query, limit)

    async def _vector_search(self, query: str, limit: int) -> list[dict]:
        """Nearest-neighbour search over embedded, unexpired facts; [] when unavailable or no hit."""
        if not self._embedding_backend or not await self._has_pgvector():
            reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector"
            log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40])
            return []
        query_embedding = await self._generate_embedding(query)
        vec_str = _vector_to_str(query_embedding) if query_embedding else None
        if not vec_str:
            log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
            return []
        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                # NULL embeddings are excluded: their NULL distance would
                # otherwise reach the Python cutoff comparison below.
                rows = await conn.fetch(
                    """SELECT id, category, key, value, updated_at,
                              source, confidence, expires_at, source_context,
                              embedding <=> $1::vector AS distance
                       FROM memory_facts
                       WHERE embedding IS NOT NULL
                         AND (expires_at IS NULL OR expires_at > now())
                       ORDER BY embedding <=> $1::vector
                       LIMIT $2""",
                    vec_str, limit,
                )
            results = [
                _row_to_dict(r)
                for r in rows
                if r["distance"] < _VECTOR_CUTOFF_DISTANCE
            ]
        except Exception:
            log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            return []
        if results:
            log.debug("memory.vector_search", hits=len(results), query=query[:40])
        return results

    async def _text_search(self, query: str, limit: int) -> list[dict]:
        """ILIKE fallback search over unexpired facts.

        Queries with no usable terms, or tables holding at most
        ``_AUTO_DUMP_THRESHOLD`` facts, return all unexpired facts. Otherwise
        facts matching every term are returned newest first; if none match
        all terms, facts matching any term are ranked by matched-term count.
        """
        terms = _normalize_query(query)
        if not terms or await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit, include_expired=False)

        # One bind parameter per term, reused for all three columns.
        # _normalize_query turns '_' and backslashes into spaces and drops '%',
        # so the terms carry no LIKE metacharacters.
        params = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        ]
        limit_idx = len(params) + 1
        live = "(expires_at IS NULL OR expires_at > now())"
        columns = (
            "id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context"
        )

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {columns} FROM memory_facts "
                f"WHERE {live} AND {' AND '.join(term_conds)} "
                f"ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *params, limit,
            )
            if not rows:
                score_expr = " + ".join(
                    f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
                )
                rows = await conn.fetch(
                    f"SELECT {columns}, ({score_expr}) AS score FROM memory_facts "
                    f"WHERE {live} AND ({' OR '.join(term_conds)}) "
                    f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                    *params, limit,
                )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts whose key (and, when given, category) match case-insensitively.

        Returns the number of rows deleted.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # conn.execute returns the command status tag, e.g. "DELETE 3".
        return int(result.split()[1])

    async def get_all_facts(
        self, limit: int | None = None, *, include_expired: bool = True
    ) -> list[dict]:
        """Return facts ordered by category, then most recently updated first.

        A falsy *limit* (None or 0) means no limit. Pass
        ``include_expired=False`` to drop facts whose ``expires_at`` has passed.
        """
        q = (
            "SELECT id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context FROM memory_facts"
        )
        if not include_expired:
            q += " WHERE expires_at IS NULL OR expires_at > now()"
        q += " ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            q += " LIMIT $1"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the total number of stored facts (expired ones included)."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or None if none has been saved."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store *content* as the single summary row (id=1), stamped with the current UTC time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                   ON CONFLICT(id) DO UPDATE SET
                       content      = EXCLUDED.content,
                       generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record (or refresh) the current UTC time as *session_id*'s extraction time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return *session_id*'s extraction time as an ISO-8601 string, or None if never recorded."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)


def _row_to_dict(row: "asyncpg.Record") -> dict:
    """Convert a ``memory_facts`` row into a plain dict.

    ``row`` must provide ``id``, ``category``, ``key``, ``value`` and
    ``updated_at``. ``updated_at`` is rendered with ``isoformat()`` (falling
    back to ``str()``). ``expires_at`` is rendered with ``isoformat()`` or
    None. Optional columns that are absent or NULL fall back to the schema
    defaults: ``'conversation'`` and ``70``. ``source_context`` falls back to
    ``""``, the value ``upsert_fact`` writes by default, so callers never
    receive None for these columns.
    """
    def column(name: str, default):
        value = row.get(name)
        # Explicit None check: a stored confidence of 0 must not become 70.
        return default if value is None else value

    updated = row["updated_at"]
    expires = row.get("expires_at")
    return {
        "id": row["id"],
        "category": row["category"],
        "key": row["key"],
        "value": row["value"],
        "updated_at": updated.isoformat() if hasattr(updated, "isoformat") else str(updated),
        "source": column("source", "conversation"),
        "confidence": column("confidence", 70),
        "expires_at": expires.isoformat() if hasattr(expires, "isoformat") else None,
        "source_context": column("source_context", ""),
    }