"""Persistent memory store — facts about the user, backed by PostgreSQL."""
import asyncio
import re
import uuid
from datetime import datetime, timezone
import asyncpg
_SEPARATORS = re.compile(r'[-_/\\.]') # treat as word boundaries
_NOISE = re.compile(r'[^\w\s]', re.UNICODE) # strip remaining punctuation
def _normalize_query(query: str) -> list[str]:
"""Return a clean list of search terms from a raw query string.
- Hyphens, underscores, slashes, dots → spaces (web-search → [web, search])
- All remaining punctuation stripped (commas, quotes, parens …)
- Lowercased, split on whitespace
- Single-character tokens dropped
"""
q = _SEPARATORS.sub(' ', query)
q = _NOISE.sub(' ', q)
return [t for t in q.lower().split() if len(t) > 1]
# Small-table cutoff used by MemoryStore.search_facts: when the table holds at
# most this many facts, term matching is skipped and the result comes straight
# from get_all_facts (still capped by the caller's ``limit``).
_AUTO_DUMP_THRESHOLD = 60 # if the DB has ≤ this many facts, skip search and return unfiltered facts
# Schema bootstrap: MemoryStore._get_pool executes these in order the first
# time it creates the connection pool. Every statement uses IF NOT EXISTS, so
# running them against an already-initialised database is harmless.
_DDL_STATEMENTS = [
# One row per fact. UNIQUE(category, key) is the conflict target that
# upsert_fact relies on.
"""CREATE TABLE IF NOT EXISTS memory_facts (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
source_session_id TEXT,
UNIQUE(category, key)
)""",
# Summary text; set_summary/get_summary only ever read and write the row
# with id = 1.
"""CREATE TABLE IF NOT EXISTS memory_summary (
id INTEGER PRIMARY KEY DEFAULT 1,
content TEXT NOT NULL,
generated_at TIMESTAMPTZ NOT NULL
)""",
# Per-session marker recording when mark_session_extracted was last called.
"""CREATE TABLE IF NOT EXISTS session_memory_state (
session_id TEXT PRIMARY KEY,
extracted_at TIMESTAMPTZ NOT NULL
)""",
]
class MemoryStore:
    """Async store for user facts, a summary text, and session extraction marks.

    Data lives in PostgreSQL and is reached through an ``asyncpg`` pool that
    is created lazily; the statements in ``_DDL_STATEMENTS`` are executed
    once, right after the pool is created.
    """

    def __init__(self, dsn: str) -> None:
        """Remember *dsn*; no connection is opened until a method needs one."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Guards first-time pool creation so concurrent callers share one pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and applying the DDL on first use.

        If acquiring a connection or executing the DDL fails, the freshly
        created pool is terminated before the error propagates, so a failed
        initialisation does not leave open connections behind and a later
        call can retry from scratch.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check: another coroutine may have finished creating the pool
            # while this one was waiting for the lock.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        for stmt in _DDL_STATEMENTS:
                            await conn.execute(stmt)
                except BaseException:
                    # terminate() is synchronous, so cleanup also works when
                    # the failure is a task cancellation.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
    ) -> None:
        """Insert a fact, or update the existing one with the same (category, key).

        On conflict only ``value``, ``updated_at`` and ``source_session_id``
        are overwritten; the existing row keeps its ``id`` and ``created_at``.
        """
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO memory_facts (id, category, key, value, created_at, updated_at, source_session_id)\n"
                "VALUES ($1, $2, $3, $4, $5, $6, $7)\n"
                "ON CONFLICT(category, key) DO UPDATE SET\n"
                "value = EXCLUDED.value,\n"
                "updated_at = EXCLUDED.updated_at,\n"
                "source_session_id = EXCLUDED.source_session_id",
                str(uuid.uuid4()), category, key, value, now, now, source_session_id,
            )

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to *limit* facts relevant to *query*.

        Strategy, in order:

        1. No usable search terms, or the table holds at most
           ``_AUTO_DUMP_THRESHOLD`` facts: return ``get_all_facts(limit=limit)``
           without any term matching.
        2. Facts where every term matches ``category``, ``key`` or ``value``
           (case-insensitive substring), newest first.
        3. If step 2 finds nothing: facts matching any term, ranked by the
           number of matching terms, then newest first.
        """
        terms = _normalize_query(query)
        if not terms:
            return await self.get_all_facts(limit=limit)

        # Small table — skip term matching and answer from get_all_facts
        # (ordered by category, then newest first).
        # NOTE(review): when the fact count is above `limit` but at or below
        # the threshold, this returns a truncated slice that ignores the
        # query entirely — confirm that is the intended behaviour.
        if await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit)

        # One bind parameter per term, shared by the three column checks.
        # Terms come from _normalize_query, which replaces every non-word
        # character and '_' with a space, so they carry no LIKE wildcards.
        patterns = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(patterns) + 1)
        ]
        limit_idx = len(patterns) + 1
        and_where = " AND ".join(term_conds)
        or_where = " OR ".join(term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # 1. AND — all terms must match (most precise)
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at FROM memory_facts "
                f"WHERE {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *patterns, limit,
            )
            if rows:
                return [_row_to_dict(r) for r in rows]

            # 2. OR with relevance score — facts matching more terms rank higher
            score_expr = " + ".join(
                f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
            )
            rows = await conn.fetch(
                f"SELECT id, category, key, value, updated_at, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {or_where} "
                f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                *patterns, limit,
            )
            return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts whose key matches *key* case-insensitively.

        When *category* is truthy the match is restricted to that category
        (also case-insensitive); otherwise matching keys are removed from
        every category. Returns the number of rows deleted.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # asyncpg returns "DELETE N"
        return int(result.split()[1])

    async def get_all_facts(self, limit: int | None = None) -> list[dict]:
        """Return facts ordered by category, then newest first.

        ``limit=None`` returns every fact. Any integer is passed to SQL
        ``LIMIT`` as-is, so ``limit=0`` yields an empty list — the same
        result ``search_facts`` gives for ``limit=0`` on its matching paths.
        """
        q = "SELECT id, category, key, value, updated_at FROM memory_facts ORDER BY category, updated_at DESC"
        params: list = []
        # `is not None` rather than truthiness: 0 is a real limit, not "no limit".
        if limit is not None:
            q += f" LIMIT ${len(params) + 1}"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the number of rows in ``memory_facts``."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or ``None`` if none has been set."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store *content* as the summary (row id 1), stamping the current UTC time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)\n"
                "ON CONFLICT(id) DO UPDATE SET\n"
                "content = EXCLUDED.content,\n"
                "generated_at = EXCLUDED.generated_at",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record the current UTC time as the extraction time for *session_id*."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)\n"
                "ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return the recorded extraction time for *session_id* as a string.

        Returns ``None`` when the session has no row. The value is rendered
        with ``isoformat()`` when the driver returns an object that has it,
        otherwise with ``str()``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)
def _row_to_dict(row: asyncpg.Record) -> dict:
val = row["updated_at"]
updated_at = val.isoformat() if hasattr(val, "isoformat") else str(val)
return {
"id": row["id"],
"category": row["category"],
"key": row["key"],
"value": row["value"],
"updated_at": updated_at,
}