Newer
Older
navi-1 / navi / memory / store.py
"""Persistent memory store — facts about the user, backed by PostgreSQL."""

import asyncio
import re
import uuid
from datetime import datetime, timezone

import asyncpg

_SEPARATORS = re.compile(r'[-_/\\.]')          # treat as word boundaries
_NOISE = re.compile(r'[^\w\s]', re.UNICODE)    # strip remaining punctuation


def _normalize_query(query: str) -> list[str]:
    """Return a clean list of search terms from a raw query string.

    - Hyphens, underscores, slashes, dots → spaces  (web-search → [web, search])
    - All remaining punctuation stripped            (commas, quotes, parens …)
    - Lowercased, split on whitespace
    - Single-character tokens dropped
    """
    q = _SEPARATORS.sub(' ', query)
    q = _NOISE.sub(' ', q)
    return [t for t in q.lower().split() if len(t) > 1]


# Fact-count cutoff used by MemoryStore.search_facts: when memory_facts holds at
# most this many rows, term matching is skipped and get_all_facts(limit=limit)
# is returned instead. The result is still capped at the caller's limit, so it
# is not literally "all" facts.
_AUTO_DUMP_THRESHOLD = 60  # fact-count cutoff at or below which search_facts skips term matching

# Schema bootstrap, executed in list order when MemoryStore first builds its
# connection pool. IF NOT EXISTS makes each statement a no-op for tables that
# already exist — existing tables are NOT altered, so column changes made here
# will not reach an existing database without a separate migration.
_DDL_STATEMENTS = [
    # Facts are unique per (category, key); upsert_fact relies on this
    # constraint for its ON CONFLICT update.
    """CREATE TABLE IF NOT EXISTS memory_facts (
        id          TEXT PRIMARY KEY,
        category    TEXT NOT NULL,
        key         TEXT NOT NULL,
        value       TEXT NOT NULL,
        created_at  TIMESTAMPTZ NOT NULL,
        updated_at  TIMESTAMPTZ NOT NULL,
        source_session_id TEXT,
        UNIQUE(category, key)
    )""",
    # Used as a single-row table: get_summary/set_summary always read and
    # write the row with id = 1.
    """CREATE TABLE IF NOT EXISTS memory_summary (
        id          INTEGER PRIMARY KEY DEFAULT 1,
        content     TEXT NOT NULL,
        generated_at TIMESTAMPTZ NOT NULL
    )""",
    # One timestamp per session, written by mark_session_extracted and read
    # by get_extracted_at.
    """CREATE TABLE IF NOT EXISTS session_memory_state (
        session_id  TEXT PRIMARY KEY,
        extracted_at TIMESTAMPTZ NOT NULL
    )""",
]


class MemoryStore:
    """Async PostgreSQL-backed store for user facts, a summary and session state.

    The asyncpg connection pool is created lazily on first use. The tables in
    ``_DDL_STATEMENTS`` are created (if missing) at that moment.
    """

    def __init__(self, dsn: str) -> None:
        """Remember *dsn*; no connection is opened until the first query.

        Args:
            dsn: PostgreSQL connection string passed to ``asyncpg.create_pool``.
        """
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes pool creation so concurrent first calls build a single pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call.

        Raises:
            Whatever ``asyncpg.create_pool`` or the DDL statements raise. In
            that case nothing is cached and the next call retries from scratch.
        """
        # Fast path: already initialised, no locking needed.
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check: another task may have finished initialising while we waited.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        for stmt in _DDL_STATEMENTS:
                            await conn.execute(stmt)
                except BaseException:
                    # Schema setup failed or was cancelled. Tear the pool down
                    # instead of leaking its connections on every retry.
                    # terminate() is synchronous, so it needs no await even
                    # while a cancellation is propagating.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    async def close(self) -> None:
        """Close the connection pool if one was created.

        Safe to call more than once. A later query lazily creates a new pool.
        """
        async with self._lock:
            pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
    ) -> None:
        """Insert a fact, or update it if (category, key) already exists.

        On conflict only ``value``, ``updated_at`` and ``source_session_id``
        are overwritten. The original ``id`` and ``created_at`` are kept.

        Args:
            category: Fact category (part of the unique key).
            key: Fact key within the category (part of the unique key).
            value: Fact text to store.
            source_session_id: Optional id of the session the fact came from.
        """
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_facts (id, category, key, value, created_at, updated_at, source_session_id)
                   VALUES ($1, $2, $3, $4, $5, $6, $7)
                   ON CONFLICT(category, key) DO UPDATE SET
                       value             = EXCLUDED.value,
                       updated_at        = EXCLUDED.updated_at,
                       source_session_id = EXCLUDED.source_session_id""",
                str(uuid.uuid4()), category, key, value, now, now, source_session_id,
            )

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to *limit* facts relevant to *query*.

        Strategy:

        1. If *query* yields no usable terms (see ``_normalize_query``),
           return ``get_all_facts(limit=limit)``.
        2. If the table holds at most ``_AUTO_DUMP_THRESHOLD`` facts, also
           return ``get_all_facts(limit=limit)`` without any matching.
        3. Otherwise return facts where *every* term is a case-insensitive
           substring of the category, key or value, newest first.
        4. If no fact matches all terms, return facts matching *any* term,
           ranked by how many terms they match, then newest first.

        NOTE(review): in steps 1-2 the result is ordered by category (see
        ``get_all_facts``) and capped at *limit*. It is therefore neither
        "all" facts nor strictly the most recent ones — confirm this is
        intended.

        Returns:
            A list of dicts with keys id, category, key, value, updated_at.
        """
        terms = _normalize_query(query)
        if not terms:
            return await self.get_all_facts(limit=limit)

        # Small DB: skip term matching entirely.
        if await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit)

        # One ILIKE pattern per term. Each placeholder is reused for all three
        # columns. _normalize_query has already replaced '%', '_' and '\' with
        # spaces, so a term cannot inject LIKE wildcards.
        params = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        ]
        limit_idx = len(params) + 1
        and_where = " AND ".join(term_conds)
        or_where = " OR ".join(term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # 1. AND — all terms must match (most precise).
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at FROM memory_facts "
                f"WHERE {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *params, limit,
            )
            if rows:
                return [_row_to_dict(r) for r in rows]

            # 2. OR with a relevance score — facts matching more terms rank higher.
            score_expr = " + ".join(
                f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
            )
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {or_where} "
                f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                *params, limit,
            )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts whose key equals *key* case-insensitively.

        Args:
            key: Fact key to delete (compared with LOWER() on both sides).
            category: If truthy, only delete within this category (also
                compared case-insensitively). ``None`` or ``""`` deletes the
                key in every category.

        Returns:
            The number of rows deleted.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # asyncpg returns the command status tag, e.g. "DELETE 3".
        return int(result.split()[1])

    async def get_all_facts(self, limit: int | None = None) -> list[dict]:
        """Return facts ordered by category, then newest first within a category.

        Args:
            limit: Maximum number of rows. ``None`` or ``0`` means no limit.

        Returns:
            A list of dicts with keys id, category, key, value, updated_at.
        """
        q = "SELECT id, category, key, value, updated_at FROM memory_facts ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            # The limit is the only parameter, so it is always $1.
            q += " LIMIT $1"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the number of rows in ``memory_facts``."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or ``None`` if none has been set."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store *content* as the summary, replacing any previous one.

        The summary lives in the single row with id = 1. ``generated_at`` is
        set to the current UTC time.
        """
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                   ON CONFLICT(id) DO UPDATE SET
                       content      = EXCLUDED.content,
                       generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record (or refresh) the current UTC time as *session_id*'s extraction time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return *session_id*'s extraction time as an ISO-8601 string.

        Returns ``None`` if the session has never been marked. Falls back to
        ``str()`` for values without an ``isoformat`` method.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)


def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a ``memory_facts`` result row into a plain dict.

    The id, category, key and value columns are copied unchanged.
    ``updated_at`` becomes an ISO-8601 string when the value has an
    ``isoformat`` method; otherwise ``str()`` is used.
    """
    stamp = row["updated_at"]
    if hasattr(stamp, "isoformat"):
        stamp = stamp.isoformat()
    else:
        stamp = str(stamp)
    out = {column: row[column] for column in ("id", "category", "key", "value")}
    out["updated_at"] = stamp
    return out