Newer
Older
navi-1 / navi / store / __init__.py
"""KV store backed by PostgreSQL for session-scoped data.

Used by todo, scratchpad, and any future per-session state that must
survive server restarts.
"""

import asyncio
from typing import Any

import asyncpg

# Schema bootstrap for session_store. KvStore._get_pool runs this whole string
# once per KvStore instance, as one multi-statement execute with no bind
# parameters. Every statement is written to be safe to re-run
# (IF NOT EXISTS / a conditional UPDATE).
#
# NOTE(review): PostgreSQL automatically builds a unique B-tree index for the
# UNIQUE (user_id, session_id, scope, key) constraint. idx_session_store_lookup
# covers the same columns in the same order, so it appears redundant: it adds
# write and storage cost without helping lookups. Consider dropping it.
# NOTE(review): user_id is declared NOT NULL, so the trailing UPDATE cannot
# match rows in a table created by this DDL. Presumably it backfills a table
# created by an older schema where user_id was nullable (CREATE TABLE IF NOT
# EXISTS does not alter an existing table). Confirm, and remove it once every
# deployment is migrated.
_DDL = """
CREATE TABLE IF NOT EXISTS session_store (
    id         SERIAL PRIMARY KEY,
    user_id    TEXT NOT NULL DEFAULT '',
    session_id TEXT NOT NULL,
    scope      TEXT NOT NULL,
    key        TEXT NOT NULL,
    value      TEXT NOT NULL DEFAULT '',
    updated_at TIMESTAMPTZ NOT NULL DEFAULT now(),
    UNIQUE (user_id, session_id, scope, key)
);
CREATE INDEX IF NOT EXISTS idx_session_store_lookup
    ON session_store (user_id, session_id, scope, key);
UPDATE session_store SET user_id = '' WHERE user_id IS NULL;
"""


def _norm_uid(user_id: str | None) -> str:
    return user_id or ""


class KvStore:
    """PostgreSQL-backed key/value store namespaced by user, session and scope.

    Each value is addressed by (user_id, session_id, scope, key). A ``None``
    user id is stored as ``''``. The ``session_store`` schema is created
    lazily by the first call that needs the pool.
    """

    def __init__(self, pool: asyncpg.Pool) -> None:
        # The pool is never closed here; its lifecycle stays with the creator.
        self._pool = pool
        # Becomes True only after _DDL has executed without raising.
        self._initialized = False
        # Serializes the one-time DDL among coroutines sharing this instance.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the pool, running the schema DDL first if not yet done.

        If the DDL raises, ``_initialized`` stays False, so the next call
        retries it.
        """
        if self._initialized:
            return self._pool
        async with self._lock:
            # Re-check: another coroutine may have finished the DDL while
            # this one was waiting on the lock.
            if not self._initialized:
                await self._pool.execute(_DDL)
                self._initialized = True
        return self._pool

    async def get(self, user_id: str | None, session_id: str, scope: str, key: str) -> str | None:
        """Return the value stored under one key, or ``None`` if it is absent."""
        pool = await self._get_pool()
        record = await pool.fetchrow(
            "SELECT value FROM session_store WHERE user_id = $1 AND session_id = $2 AND scope = $3 AND key = $4",
            _norm_uid(user_id), session_id, scope, key,
        )
        if not record:
            return None
        return record["value"]

    async def set(self, user_id: str | None, session_id: str, scope: str, key: str, value: str) -> None:
        """Store ``value`` under the key, overwriting any existing value.

        On conflict with an existing row, the value is replaced and
        ``updated_at`` is refreshed to ``now()``.
        """
        owner = _norm_uid(user_id)
        pool = await self._get_pool()
        await pool.execute(
            """
                INSERT INTO session_store (user_id, session_id, scope, key, value, updated_at)
                VALUES ($1, $2, $3, $4, $5, now())
                ON CONFLICT (user_id, session_id, scope, key)
                DO UPDATE SET value = EXCLUDED.value, updated_at = now()
                """,
            owner, session_id, scope, key, value,
        )

    async def get_all(self, user_id: str | None, session_id: str, scope: str) -> dict[str, str]:
        """Return every key/value pair in one scope as a dict (empty if none)."""
        pool = await self._get_pool()
        records = await pool.fetch(
            "SELECT key, value FROM session_store WHERE user_id = $1 AND session_id = $2 AND scope = $3",
            _norm_uid(user_id), session_id, scope,
        )
        # The query has no ORDER BY, so key order is whatever PostgreSQL returns.
        return {record["key"]: record["value"] for record in records}

    async def delete(self, user_id: str | None, session_id: str, scope: str, key: str) -> None:
        """Remove a single key; deleting a key that does not exist is a no-op."""
        owner = _norm_uid(user_id)
        pool = await self._get_pool()
        await pool.execute(
            "DELETE FROM session_store WHERE user_id = $1 AND session_id = $2 AND scope = $3 AND key = $4",
            owner, session_id, scope, key,
        )

    async def clear_scope(self, user_id: str | None, session_id: str, scope: str) -> None:
        """Remove every key in one (user, session, scope) namespace."""
        owner = _norm_uid(user_id)
        pool = await self._get_pool()
        await pool.execute(
            "DELETE FROM session_store WHERE user_id = $1 AND session_id = $2 AND scope = $3",
            owner, session_id, scope,
        )