Newer
Older
navi-1 / navi / memory / store.py
"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support."""

import asyncio
import re
import uuid
from datetime import datetime, timezone, timedelta
from typing import TYPE_CHECKING

import asyncpg
import structlog

from navi.config import settings

if TYPE_CHECKING:
    from navi.llm.base import LLMBackend

log = structlog.get_logger()

# Characters that glue words together in identifiers and paths
# (hyphen, underscore, forward/back slash, dot); queries treat them as breaks.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Any character that is neither a word character nor whitespace.
_NOISE = re.compile(r"[^\w\s]", re.UNICODE)

# Text-fallback search returns every fact (up to the limit) instead of
# filtering when the store holds at most this many facts.
_AUTO_DUMP_THRESHOLD = 60
# Largest pgvector distance (``<=>``) a vector-search hit may have to be kept.
_VECTOR_CUTOFF_DISTANCE = 0.3


def _normalize_query(query: str) -> list[str]:
    """Split *query* into lowercase search terms.

    Separator characters (``_SEPARATORS``) and then any remaining non-word,
    non-space characters (``_NOISE``) are replaced by spaces; the result is
    lowercased, split on whitespace, and single-character tokens are dropped.
    """
    spaced = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    tokens = spaced.lower().split()
    return [tok for tok in tokens if len(tok) >= 2]


def _build_ddl(pgvector_available: bool) -> list[str]:
    """Return DDL statements depending on whether pgvector is installed."""
    embedding_col = f"embedding vector({settings.embedding_dimensions})," if pgvector_available else ""
    embedding_idx = (
        "CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding ON memory_facts USING hnsw (embedding vector_cosine_ops)"
        if pgvector_available else ""
    )
    stmts = [
        """CREATE TABLE IF NOT EXISTS memory_facts (
            id          TEXT PRIMARY KEY,
            category    TEXT NOT NULL,
            key         TEXT NOT NULL,
            value       TEXT NOT NULL,
            created_at  TIMESTAMPTZ NOT NULL,
            updated_at  TIMESTAMPTZ NOT NULL,
            source_session_id TEXT,
            %s
            source      TEXT NOT NULL DEFAULT 'conversation',
            confidence  SMALLINT NOT NULL DEFAULT 70,
            expires_at  TIMESTAMPTZ,
            last_verified_at TIMESTAMPTZ,
            source_context TEXT,
            UNIQUE(category, key)
        )""" % embedding_col,
        "CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL",
        "CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)",
        """CREATE TABLE IF NOT EXISTS memory_summary (
            id          INTEGER PRIMARY KEY DEFAULT 1,
            content     TEXT NOT NULL,
            generated_at TIMESTAMPTZ NOT NULL
        )""",
        """CREATE TABLE IF NOT EXISTS session_memory_state (
            session_id  TEXT PRIMARY KEY,
            extracted_at TIMESTAMPTZ NOT NULL
        )""",
    ]
    if embedding_idx:
        stmts.insert(1, embedding_idx)
    return stmts


def _vector_to_str(vec: list[float]) -> str:
    """Serialize a vector to the PostgreSQL vector literal format."""
    return "[" + ",".join(str(v) for v in vec) + "]"


class MemoryStore:
    """PostgreSQL-backed store of long-term facts about the user.

    Tables: ``memory_facts`` (one row per ``(category, key)``), a single-row
    ``memory_summary`` and ``session_memory_state`` extraction timestamps.
    The connection pool and schema are created lazily on first use.

    When the pgvector extension is installed, ``memory_facts`` has an
    ``embedding`` column, and an embedding backend is configured, fact values
    are embedded on write and ``search_facts`` tries a vector search before
    falling back to ILIKE text matching.
    """

    # Columns selected for every fact query; ``_row_to_dict`` reads them.
    _FACT_COLUMNS = (
        "id, category, key, value, updated_at, source, confidence, "
        "expires_at, source_context"
    )
    # SQL predicate that keeps only facts which have not expired.
    _NOT_EXPIRED = "(expires_at IS NULL OR expires_at > now())"

    def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None:
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes pool creation / schema setup and close().
        self._lock = asyncio.Lock()
        self._embedding_backend = embedding_backend
        # True once schema setup has run; _pgvector_available is only
        # meaningful after that.
        self._pgvector_checked = False
        self._pgvector_available = False

    def set_embedding_backend(self, backend: "LLMBackend | None") -> None:
        """Replace the backend used to embed values and queries (None disables vectors)."""
        self._embedding_backend = backend

    # ── Connection & schema ──────────────────────────────────────────────────

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call.

        If schema setup raises, the half-initialised pool is terminated (so
        its connections are not leaked) and the next call retries.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        vector_ready = await self._init_schema(conn)
                except BaseException:
                    # terminate() is synchronous, so it is safe even on cancellation.
                    pool.terminate()
                    raise
                self._pgvector_available = vector_ready
                self._pgvector_checked = True
                self._pool = pool
        return self._pool

    async def _init_schema(self, conn: asyncpg.Connection) -> bool:
        """Run the idempotent DDL and report whether vector storage is usable.

        Vector storage needs both the pgvector extension and an ``embedding``
        column on ``memory_facts``. The column is probed explicitly because a
        table created before pgvector was available may lack it; writing
        vectors would then fail.
        """
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
        except Exception:
            log.warning("memory.pgvector_not_available", exc_info=True)

        # Checked separately: the extension may already exist even if this
        # role could not run CREATE EXTENSION.
        extension_ok = False
        try:
            extension_ok = bool(
                await conn.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
            )
        except Exception:
            log.warning("memory.pgvector_not_available", exc_info=True)

        for stmt in _build_ddl(extension_ok):
            try:
                await conn.execute(stmt)
            except Exception:
                log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True)

        if not extension_ok:
            return False
        try:
            # to_regclass resolves the table through search_path, like the
            # unqualified queries in this class do.
            has_column = await conn.fetchval(
                "SELECT 1 FROM pg_attribute "
                "WHERE attrelid = to_regclass('memory_facts') "
                "AND attname = 'embedding' AND NOT attisdropped"
            )
        except Exception:
            log.warning("memory.embedding_column_check_failed", exc_info=True)
            return False
        if not has_column:
            log.warning("memory.embedding_column_missing")
        return bool(has_column)

    async def _has_pgvector(self) -> bool:
        """Return whether embeddings can be stored and searched.

        A database error is reported as False but not cached, so vector
        features resume once the database is reachable again.
        """
        if not self._pgvector_checked:
            try:
                await self._get_pool()
            except Exception:
                log.warning("memory.pgvector_check_failed", exc_info=True)
                return False
        return self._pgvector_available

    async def close(self) -> None:
        """Close the connection pool if open; a later call re-creates it."""
        async with self._lock:
            pool, self._pool = self._pool, None
            self._pgvector_checked = False
        if pool is not None:
            await pool.close()

    # ── Embeddings ───────────────────────────────────────────────────────────

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed one text; None when vectors are unavailable or embedding fails."""
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]:
        """Embed several texts; failed or unavailable entries are None."""
        if not self._embedding_backend or not await self._has_pgvector():
            return [None] * len(texts)
        try:
            vectors = await self._embedding_backend.embed(
                texts=texts,
                model=settings.embedding_model,
            )
            return [v if v else None for v in vectors]
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return [None] * len(texts)

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Embed stored facts that have no vector yet; return how many were updated.

        Facts are visited once each, in ``id`` order (keyset pagination), so a
        fact whose embedding fails is skipped instead of being refetched
        forever. Between full batches the loop sleeps 2 s to rate-limit the
        embedding endpoint; no connection is held while sleeping or embedding.
        Returns 0 immediately when no backend or no vector storage is available.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            log.warning("memory.backfill_skipped", reason="vectors_unavailable")
            return 0

        pool = await self._get_pool()
        updated = 0
        last_id = ""  # every non-empty TEXT id sorts after ""
        while True:
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT id, value FROM memory_facts "
                    "WHERE embedding IS NULL AND id > $1 ORDER BY id LIMIT $2",
                    last_id, batch_size,
                )
            if not rows:
                break
            last_id = rows[-1]["id"]

            embeddings = await self._generate_embeddings([r["value"] for r in rows])
            pairs = [
                (_vector_to_str(emb), row["id"])
                for row, emb in zip(rows, embeddings)
                if emb
            ]
            if pairs:
                async with pool.acquire() as conn:
                    await conn.executemany(
                        "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                        pairs,
                    )
                updated += len(pairs)

            if len(rows) < batch_size:
                break
            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            await asyncio.sleep(2)
        return updated

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        """Insert a fact, or update the existing row with the same ``(category, key)``.

        On update, ``id`` and ``created_at`` are preserved. If *value* was
        embedded, the new vector is stored. If embedding is unavailable or
        failed, a previously stored vector is kept only when the value is
        unchanged; otherwise it is cleared so it cannot match on stale text
        (``backfill_embeddings`` can regenerate it).
        """
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        vector_ready = await self._has_pgvector()
        pool = await self._get_pool()

        columns = [
            "id", "category", "key", "value", "created_at", "updated_at",
            "source_session_id", "source", "confidence", "expires_at", "source_context",
        ]
        args: list = [
            str(uuid.uuid4()), category, key, value, now, now,
            source_session_id, source, confidence, expires_at, source_context,
        ]
        placeholders = [f"${i}" for i in range(1, len(args) + 1)]
        # Identity columns and created_at keep their stored values on conflict.
        updates = [
            f"{col} = EXCLUDED.{col}"
            for col in columns
            if col not in ("id", "category", "key", "created_at")
        ]
        if embedding:
            columns.append("embedding")
            args.append(_vector_to_str(embedding))
            placeholders.append(f"${len(args)}::vector")
            updates.append("embedding = EXCLUDED.embedding")
        elif vector_ready:
            # SET expressions see the old row, so this compares old vs new value.
            updates.append(
                "embedding = CASE WHEN memory_facts.value = EXCLUDED.value "
                "THEN memory_facts.embedding ELSE NULL END"
            )

        sql = (
            f"INSERT INTO memory_facts ({', '.join(columns)}) "
            f"VALUES ({', '.join(placeholders)}) "
            f"ON CONFLICT (category, key) DO UPDATE SET {', '.join(updates)}"
        )
        async with pool.acquire() as conn:
            await conn.execute(sql, *args)

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to *limit* unexpired facts relevant to *query*.

        1. Vector search when an embedding backend and vector storage are
           available; hits with distance >= ``_VECTOR_CUTOFF_DISTANCE`` are
           dropped.
        2. Otherwise, or when no vector hit survives, ILIKE matching on
           category/key/value: facts matching every term first, else facts
           ranked by how many terms they match. Queries without usable terms,
           and stores with at most ``_AUTO_DUMP_THRESHOLD`` facts, return
           unexpired facts ordered by category then recency instead.
        """
        hits = await self._vector_search(query, limit)
        if hits:
            return hits
        return await self._text_search(query, limit)

    async def _vector_search(self, query: str, limit: int) -> list[dict]:
        """Nearest-neighbour search; [] when unavailable, failed, or nothing is close."""
        if not self._embedding_backend or not await self._has_pgvector():
            reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector"
            log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40])
            return []

        query_embedding = await self._generate_embedding(query)
        if not query_embedding:
            log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
            return []

        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                # Rows without a vector are excluded: their NULL distance would
                # otherwise fill the LIMIT and break the cutoff comparison.
                rows = await conn.fetch(
                    f"SELECT {self._FACT_COLUMNS}, embedding <=> $1::vector AS distance "
                    f"FROM memory_facts "
                    f"WHERE embedding IS NOT NULL AND {self._NOT_EXPIRED} "
                    f"ORDER BY embedding <=> $1::vector LIMIT $2",
                    _vector_to_str(query_embedding), limit,
                )
        except Exception:
            log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            return []

        results = [
            _row_to_dict(r)
            for r in rows
            if r["distance"] is not None and r["distance"] < _VECTOR_CUTOFF_DISTANCE
        ]
        if results:
            log.debug("memory.vector_search", hits=len(results), query=query[:40])
        return results

    async def _text_search(self, query: str, limit: int) -> list[dict]:
        """ILIKE fallback over category/key/value, restricted to unexpired facts."""
        terms = _normalize_query(query)
        if not terms or await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit, include_expired=False)

        # One bound parameter per term, reused for all three columns. Terms
        # cannot contain % or _ because _normalize_query strips them.
        params = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        ]
        limit_ph = f"${len(params) + 1}"
        select = f"SELECT {self._FACT_COLUMNS}"

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Prefer facts that match every term...
            rows = await conn.fetch(
                f"{select} FROM memory_facts "
                f"WHERE {self._NOT_EXPIRED} AND {' AND '.join(term_conds)} "
                f"ORDER BY updated_at DESC LIMIT {limit_ph}",
                *params, limit,
            )
            if not rows:
                # ...otherwise rank partial matches by how many terms they hit.
                score_expr = " + ".join(
                    f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
                )
                rows = await conn.fetch(
                    f"{select}, ({score_expr}) AS score FROM memory_facts "
                    f"WHERE {self._NOT_EXPIRED} AND ({' OR '.join(term_conds)}) "
                    f"ORDER BY score DESC, updated_at DESC LIMIT {limit_ph}",
                    *params, limit,
                )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts whose key (and category, if given) match case-insensitively.

        Returns the number of deleted rows, parsed from the ``DELETE n`` status.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        return int(result.split()[-1])

    async def get_all_facts(
        self, limit: int | None = None, *, include_expired: bool = True
    ) -> list[dict]:
        """Return facts ordered by category, then most recently updated first.

        Args:
            limit: maximum number of rows; None or 0 means no limit.
            include_expired: when False, facts whose ``expires_at`` has passed
                are left out.
        """
        sql = f"SELECT {self._FACT_COLUMNS} FROM memory_facts"
        if not include_expired:
            sql += f" WHERE {self._NOT_EXPIRED}"
        sql += " ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            sql += " LIMIT $1"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the total number of stored facts (expired ones included)."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or None if none has been saved."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store *content* as the single summary row, stamped with the current UTC time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                   ON CONFLICT(id) DO UPDATE SET
                       content      = EXCLUDED.content,
                       generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record (or refresh) the time facts were last extracted from *session_id*."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return the last extraction time for *session_id* as ISO text, or None."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)


def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a ``memory_facts`` row into a plain dict.

    ``updated_at`` is rendered with ``isoformat()`` (or ``str()`` if it is
    not a datetime-like value). ``expires_at`` becomes ISO text, or None when
    it is NULL, absent or not datetime-like. ``source``, ``confidence`` and
    ``source_context`` fall back to ``"conversation"``, ``70`` and ``""`` when
    the row does not carry those columns.
    """
    result = {name: row[name] for name in ("id", "category", "key", "value")}
    stamp = row["updated_at"]
    result["updated_at"] = stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp)
    result["source"] = row.get("source", "conversation")
    result["confidence"] = row.get("confidence", 70)
    expiry = row.get("expires_at")
    result["expires_at"] = expiry.isoformat() if hasattr(expiry, "isoformat") else None
    result["source_context"] = row.get("source_context", "")
    return result