Newer
Older
navi-1 / navi / memory / store.py
"""Persistent memory store — facts about the user, backed by PostgreSQL."""

import asyncio
import uuid
from datetime import datetime, timezone

import asyncpg

# Schema bootstrap statements. MemoryStore._get_pool() executes them in order
# on a connection taken from the freshly created pool, before that pool is
# published for use. Every statement is CREATE TABLE IF NOT EXISTS, so running
# them against an already-initialised database does not recreate the tables.
_DDL_STATEMENTS = [
    # One row per remembered fact. The UNIQUE(category, key) constraint is the
    # conflict target used by MemoryStore.upsert_fact().
    """CREATE TABLE IF NOT EXISTS memory_facts (
        id          TEXT PRIMARY KEY,
        category    TEXT NOT NULL,
        key         TEXT NOT NULL,
        value       TEXT NOT NULL,
        created_at  TIMESTAMPTZ NOT NULL,
        updated_at  TIMESTAMPTZ NOT NULL,
        source_session_id TEXT,
        UNIQUE(category, key)
    )""",
    # Summary text. get_summary()/set_summary() only ever read and write the
    # row with id=1, which is also the column default.
    """CREATE TABLE IF NOT EXISTS memory_summary (
        id          INTEGER PRIMARY KEY DEFAULT 1,
        content     TEXT NOT NULL,
        generated_at TIMESTAMPTZ NOT NULL
    )""",
    # Per-session timestamp written by mark_session_extracted() and read back
    # by get_extracted_at().
    """CREATE TABLE IF NOT EXISTS session_memory_state (
        session_id  TEXT PRIMARY KEY,
        extracted_at TIMESTAMPTZ NOT NULL
    )""",
]


class MemoryStore:
    def __init__(self, dsn: str) -> None:
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        if self._pool is not None:
            return self._pool
        async with self._lock:
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                async with pool.acquire() as conn:
                    for stmt in _DDL_STATEMENTS:
                        await conn.execute(stmt)
                self._pool = pool
        return self._pool

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
    ) -> None:
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_facts (id, category, key, value, created_at, updated_at, source_session_id)
                   VALUES ($1, $2, $3, $4, $5, $6, $7)
                   ON CONFLICT(category, key) DO UPDATE SET
                       value             = EXCLUDED.value,
                       updated_at        = EXCLUDED.updated_at,
                       source_session_id = EXCLUDED.source_session_id""",
                str(uuid.uuid4()), category, key, value, now, now, source_session_id,
            )

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        terms = [t for t in query.lower().split() if len(t) > 1]
        if not terms:
            return await self.get_all_facts(limit=limit)

        params: list = []
        conditions_parts: list[str] = []
        for term in terms:
            like = f"%{term}%"
            base = len(params) + 1
            conditions_parts.append(
                f"(category ILIKE ${base} OR key ILIKE ${base + 1} OR value ILIKE ${base + 2})"
            )
            params.extend([like, like, like])

        limit_idx = len(params) + 1
        params.append(limit)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at FROM memory_facts "
                f"WHERE {' OR '.join(conditions_parts)} ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *params,
            )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # asyncpg returns "DELETE N"
        return int(result.split()[1])

    async def get_all_facts(self, limit: int | None = None) -> list[dict]:
        q = "SELECT id, category, key, value, updated_at FROM memory_facts ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            q += f" LIMIT ${len(params) + 1}"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                   ON CONFLICT(id) DO UPDATE SET
                       content      = EXCLUDED.content,
                       generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)


def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a fact row into a plain dict with ``updated_at`` as a string.

    The timestamp is rendered with ``isoformat()`` when the value offers it,
    and with ``str()`` otherwise. Only the five fact columns are copied.
    """
    stamp = row["updated_at"]
    fact = {column: row[column] for column in ("id", "category", "key", "value")}
    if hasattr(stamp, "isoformat"):
        fact["updated_at"] = stamp.isoformat()
    else:
        fact["updated_at"] = str(stamp)
    return fact