Newer
Older
navi-1 / navi / memory / sqlite_store.py
"""SQLite-backed memory store — used when DATABASE_URL is not set."""

import sqlite3
import uuid
from datetime import datetime, timezone

import aiosqlite

# Schema applied by SqliteMemoryStore.__init__ through executescript. Every
# statement uses IF NOT EXISTS, so running it against an existing database
# leaves already-present tables (and their schema) untouched.
#   memory_facts         - one row per (category, key); that UNIQUE constraint
#                          is the ON CONFLICT target used by upsert_fact.
#   memory_summary       - the store only ever reads and writes the row id=1.
#   session_memory_state - the last extraction timestamp per session_id.
# Timestamp columns are TEXT; the store fills them with UTC ISO-8601 strings
# produced by datetime.now(timezone.utc).isoformat().
_DDL = """
CREATE TABLE IF NOT EXISTS memory_facts (
    id          TEXT PRIMARY KEY,
    category    TEXT NOT NULL,
    key         TEXT NOT NULL,
    value       TEXT NOT NULL,
    created_at  TEXT NOT NULL,
    updated_at  TEXT NOT NULL,
    source_session_id TEXT,
    UNIQUE(category, key)
);

CREATE TABLE IF NOT EXISTS memory_summary (
    id          INTEGER PRIMARY KEY DEFAULT 1,
    content     TEXT NOT NULL,
    generated_at TEXT NOT NULL
);

CREATE TABLE IF NOT EXISTS session_memory_state (
    session_id  TEXT PRIMARY KEY,
    extracted_at TEXT NOT NULL
);
"""


class SqliteMemoryStore:
    def __init__(self, db_path: str) -> None:
        self._db_path = db_path
        with sqlite3.connect(db_path) as conn:
            conn.executescript(_DDL)
            conn.commit()

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
    ) -> None:
        now = datetime.now(timezone.utc).isoformat()
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                """INSERT INTO memory_facts (id, category, key, value, created_at, updated_at, source_session_id)
                   VALUES (?, ?, ?, ?, ?, ?, ?)
                   ON CONFLICT(category, key) DO UPDATE SET
                       value              = excluded.value,
                       updated_at         = excluded.updated_at,
                       source_session_id  = excluded.source_session_id""",
                (str(uuid.uuid4()), category, key, value, now, now, source_session_id),
            )
            await db.commit()

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        terms = [t for t in query.lower().split() if len(t) > 1]
        if not terms:
            return await self.get_all_facts(limit=limit)

        conditions = " OR ".join(
            ["(LOWER(category) LIKE ? OR LOWER(key) LIKE ? OR LOWER(value) LIKE ?)"] * len(terms)
        )
        params: list = [f"%{t}%" for t in terms for _ in range(3)]

        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                f"SELECT id, category, key, value, updated_at FROM memory_facts "
                f"WHERE {conditions} ORDER BY updated_at DESC LIMIT ?",
                params + [limit],
            ) as cur:
                rows = await cur.fetchall()
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        async with aiosqlite.connect(self._db_path) as db:
            if category:
                cur = await db.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER(?) AND LOWER(category)=LOWER(?)",
                    (key, category),
                )
            else:
                cur = await db.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER(?)", (key,)
                )
            await db.commit()
            return cur.rowcount

    async def get_all_facts(self, limit: int | None = None) -> list[dict]:
        q = "SELECT id, category, key, value, updated_at FROM memory_facts ORDER BY category, updated_at DESC"
        if limit:
            q += f" LIMIT {limit}"
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(q) as cur:
                rows = await cur.fetchall()
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute("SELECT COUNT(*) FROM memory_facts") as cur:
                row = await cur.fetchone()
        return row[0] if row else 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute("SELECT content FROM memory_summary WHERE id=1") as cur:
                row = await cur.fetchone()
        return row[0] if row else None

    async def set_summary(self, content: str) -> None:
        now = datetime.now(timezone.utc).isoformat()
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, ?, ?)
                   ON CONFLICT(id) DO UPDATE SET content=excluded.content, generated_at=excluded.generated_at""",
                (content, now),
            )
            await db.commit()

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        now = datetime.now(timezone.utc).isoformat()
        async with aiosqlite.connect(self._db_path) as db:
            await db.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES (?, ?)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=excluded.extracted_at""",
                (session_id, now),
            )
            await db.commit()

    async def get_extracted_at(self, session_id: str) -> str | None:
        async with aiosqlite.connect(self._db_path) as db:
            async with db.execute(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=?", (session_id,)
            ) as cur:
                row = await cur.fetchone()
        return row[0] if row else None


def _row_to_dict(row: tuple) -> dict:
    return {"id": row[0], "category": row[1], "key": row[2], "value": row[3], "updated_at": row[4]}