"""Persistent memory store — facts about the user, backed by PostgreSQL."""
import asyncio
import uuid
from datetime import datetime, timezone
import asyncpg
# Schema bootstrap: MemoryStore._get_pool() executes these statements in order
# on a single connection the first time it creates the pool. Each uses
# CREATE TABLE IF NOT EXISTS, so re-running against an existing database is a
# no-op — note this also means changes here are NOT applied to tables that
# already exist (there is no migration step).
_DDL_STATEMENTS = [
# Remembered facts. UNIQUE(category, key) is the conflict target that
# MemoryStore.upsert_fact relies on for its ON CONFLICT upsert.
"""CREATE TABLE IF NOT EXISTS memory_facts (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
source_session_id TEXT,
UNIQUE(category, key)
)""",
# Generated summary text. MemoryStore reads and writes only the row with
# id = 1 (the DEFAULT 1 matches that convention).
"""CREATE TABLE IF NOT EXISTS memory_summary (
id INTEGER PRIMARY KEY DEFAULT 1,
content TEXT NOT NULL,
generated_at TIMESTAMPTZ NOT NULL
)""",
# Per-session timestamp of the last fact extraction, written by
# MemoryStore.mark_session_extracted and read by get_extracted_at.
"""CREATE TABLE IF NOT EXISTS session_memory_state (
session_id TEXT PRIMARY KEY,
extracted_at TIMESTAMPTZ NOT NULL
)""",
]
class MemoryStore:
    """Async, PostgreSQL-backed store for remembered user facts.

    Besides individual facts (``memory_facts``) it keeps a single generated
    summary (``memory_summary``) and per-session extraction timestamps
    (``session_memory_state``). The asyncpg pool is created lazily on first
    use, and the schema in ``_DDL_STATEMENTS`` is applied at that moment.
    """

    def __init__(self, dsn: str) -> None:
        """Remember *dsn*; no connection is opened until the first query."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Guards lazy pool creation (and close()) so concurrent first callers
        # share one pool instead of each creating their own.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call.

        If applying the schema fails, the freshly created pool is terminated
        before the error propagates, so its connections are not leaked and the
        next call starts over from scratch.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished setup
            # while this one was waiting.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        for stmt in _DDL_STATEMENTS:
                            await conn.execute(stmt)
                except BaseException:
                    # terminate() is synchronous, so cleanup still happens
                    # even when the failure is a task cancellation.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    async def close(self) -> None:
        """Close the connection pool if one was opened.

        Safe to call more than once; a later query lazily opens a new pool.
        """
        async with self._lock:
            pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
    ) -> None:
        """Insert a fact, or overwrite an existing one with the same (category, key).

        On conflict, ``value``, ``updated_at`` and ``source_session_id`` are
        replaced (a ``None`` session id overwrites with NULL); the existing
        row's ``id`` and ``created_at`` are kept.
        """
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        await pool.execute(
            """INSERT INTO memory_facts
                   (id, category, key, value, created_at, updated_at, source_session_id)
               VALUES ($1, $2, $3, $4, $5, $6, $7)
               ON CONFLICT (category, key) DO UPDATE SET
                   value = EXCLUDED.value,
                   updated_at = EXCLUDED.updated_at,
                   source_session_id = EXCLUDED.source_session_id""",
            str(uuid.uuid4()), category, key, value, now, now, source_session_id,
        )

    @staticmethod
    def _like_pattern(term: str) -> str:
        """Return a LIKE/ILIKE pattern matching *term* as a literal substring.

        PostgreSQL's default LIKE escape character is backslash, so backslash,
        percent and underscore characters in *term* are prefixed with one
        before the term is wrapped in ``%`` wildcards.
        """
        escaped = term.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        return f"%{escaped}%"

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return facts matching any whitespace-separated term of *query*.

        Terms of a single character are ignored. Each remaining term is matched
        case-insensitively as a literal substring of ``category``, ``key`` or
        ``value``; a row is returned if any term matches any of those columns.
        Results are ordered newest-updated first and capped at *limit*. When
        *query* yields no usable terms, this falls back to ``get_all_facts``
        with the same *limit*.
        """
        terms = [t for t in query.lower().split() if len(t) > 1]
        if not terms:
            return await self.get_all_facts(limit=limit)
        # One bound parameter per term, referenced by all three column tests.
        params: list[object] = [self._like_pattern(t) for t in terms]
        conditions = " OR ".join(
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        )
        params.append(limit)
        pool = await self._get_pool()
        rows = await pool.fetch(
            "SELECT id, category, key, value, updated_at FROM memory_facts "
            f"WHERE {conditions} ORDER BY updated_at DESC LIMIT ${len(params)}",
            *params,
        )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts whose key equals *key*, ignoring case.

        If *category* is non-empty, only facts in that category (also compared
        case-insensitively) are deleted; if it is ``None`` or ``""``, matching
        keys are deleted from every category.

        Returns:
            The number of rows deleted.
        """
        pool = await self._get_pool()
        if category:
            result = await pool.execute(
                "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                key, category,
            )
        else:
            result = await pool.execute(
                "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
            )
        # asyncpg returns the command status tag, e.g. "DELETE 3".
        return int(result.split()[-1])

    async def get_all_facts(self, limit: int | None = None) -> list[dict]:
        """Return all facts ordered by category, then most recently updated first.

        *limit* caps the number of rows; ``None`` means no cap. ``0`` now
        yields an empty list, consistent with ``search_facts``.
        """
        q = (
            "SELECT id, category, key, value, updated_at FROM memory_facts "
            "ORDER BY category, updated_at DESC"
        )
        params: list[object] = []
        if limit is not None:
            q += " LIMIT $1"
            params.append(limit)
        pool = await self._get_pool()
        rows = await pool.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the total number of stored facts."""
        pool = await self._get_pool()
        return await pool.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or ``None`` if none has been set."""
        pool = await self._get_pool()
        return await pool.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store *content* as the single summary row (id = 1), replacing any previous one."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        await pool.execute(
            """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
               ON CONFLICT (id) DO UPDATE SET
                   content = EXCLUDED.content,
                   generated_at = EXCLUDED.generated_at""",
            content, now,
        )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record (or refresh) the current UTC time as *session_id*'s extraction time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        await pool.execute(
            """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
               ON CONFLICT (session_id) DO UPDATE SET extracted_at = EXCLUDED.extracted_at""",
            session_id, now,
        )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return *session_id*'s last extraction time as an ISO-8601 string.

        Returns ``None`` if the session has never been marked as extracted.
        """
        pool = await self._get_pool()
        # extracted_at is NOT NULL, so None here can only mean "no row".
        val = await pool.fetchval(
            "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
        )
        if val is None:
            return None
        return val.isoformat() if hasattr(val, "isoformat") else str(val)
def _row_to_dict(row: asyncpg.Record) -> dict:
    """Flatten a ``memory_facts`` row into a plain, serialisable dict.

    The resulting keys are ``id``, ``category``, ``key``, ``value`` and
    ``updated_at``, in that order. ``updated_at`` is rendered with
    ``isoformat()`` when the value provides it, and with ``str()`` otherwise.
    """
    stamp = row["updated_at"]
    if hasattr(stamp, "isoformat"):
        stamp_text = stamp.isoformat()
    else:
        stamp_text = str(stamp)
    fact = {column: row[column] for column in ("id", "category", "key", "value")}
    fact["updated_at"] = stamp_text
    return fact