"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support."""
import asyncio
import math
import re
import uuid
from datetime import datetime, timezone, timedelta
from typing import TYPE_CHECKING
import asyncpg
import structlog
from navi.config import settings
if TYPE_CHECKING:
from navi.llm.base import LLMBackend
# Characters treated as word boundaries inside a search query
# ("-", "_", "/", "\" and "."); _normalize_query replaces them with spaces.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Anything that is neither a word character nor whitespace; also replaced
# with a space by _normalize_query.
_NOISE = re.compile(r"[^\w\s]", re.UNICODE)
log = structlog.get_logger()
# When the store holds at most this many facts, the text fallback of
# search_facts returns the facts without term matching.
_AUTO_DUMP_THRESHOLD = 60
# Vector search hits are kept only when their `<=>` distance to the query
# embedding is strictly below this value.
_VECTOR_CUTOFF_DISTANCE = 0.3
def _normalize_query(query: str) -> list[str]:
    """Split a free-text query into lowercase search terms.

    Separator characters and punctuation are replaced by spaces before the
    text is lowercased and split on whitespace; tokens of a single character
    are discarded.
    """
    cleaned = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    terms: list[str] = []
    for token in cleaned.lower().split():
        if len(token) > 1:
            terms.append(token)
    return terms
def _build_ddl(pgvector_available: bool) -> list[str]:
"""Return DDL statements depending on whether pgvector is installed."""
embedding_col = f"embedding vector({settings.embedding_dimensions})," if pgvector_available else ""
embedding_idx = (
"CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding ON memory_facts USING hnsw (embedding vector_cosine_ops)"
if pgvector_available else ""
)
stmts = [
"CREATE EXTENSION IF NOT EXISTS pg_trgm",
"""CREATE TABLE IF NOT EXISTS memory_facts (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
source_session_id TEXT,
%s
source TEXT NOT NULL DEFAULT 'conversation',
confidence SMALLINT NOT NULL DEFAULT 70,
expires_at TIMESTAMPTZ,
last_verified_at TIMESTAMPTZ,
source_context TEXT,
UNIQUE(category, key)
)""" % embedding_col,
"CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL",
"CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)",
"CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin (category gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops)",
"CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops)",
"""CREATE TABLE IF NOT EXISTS memory_summary (
id INTEGER PRIMARY KEY DEFAULT 1,
content TEXT NOT NULL,
generated_at TIMESTAMPTZ NOT NULL
)""",
"""CREATE TABLE IF NOT EXISTS session_memory_state (
session_id TEXT PRIMARY KEY,
extracted_at TIMESTAMPTZ NOT NULL
)""",
]
if embedding_idx:
stmts.insert(2, embedding_idx)
return stmts
def _vector_to_str(vec: list[float]) -> str | None:
"""Serialize a vector to the PostgreSQL vector literal format.
Returns None if the vector contains NaN, Infinity, or is empty,
since pgvector rejects those values.
"""
if not vec or not all(math.isfinite(v) for v in vec):
return None
return "[" + ",".join(str(v) for v in vec) + "]"
class MemoryStore:
    """PostgreSQL-backed store for user facts, a single summary row and
    per-session extraction timestamps.

    The connection pool and schema are created lazily on first use. When an
    embedding backend is configured and the ``vector`` extension is present,
    facts carry an embedding and ``search_facts`` tries vector search before
    falling back to ILIKE term matching.
    """

    def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None:
        """Remember the DSN and optional embedding backend; no I/O happens here."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes the one-time pool creation / schema setup.
        self._lock = asyncio.Lock()
        self._embedding_backend = embedding_backend
        # Cache for the pg_extension lookup done by _has_pgvector().
        self._pgvector_checked = False
        self._pgvector_available = False

    def set_embedding_backend(self, backend: "LLMBackend | None") -> None:
        """Replace (or clear, with ``None``) the backend used to embed text."""
        self._embedding_backend = backend

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call.

        Individual DDL failures are logged and skipped. If schema setup
        itself raises (e.g. no connection can be acquired), the freshly
        created pool is terminated so it does not leak, and the error
        propagates.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished setup.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        # Try to create pgvector extension first
                        pgvector_available = False
                        try:
                            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
                            row = await conn.fetchval(
                                "SELECT 1 FROM pg_extension WHERE extname = 'vector'"
                            )
                            pgvector_available = bool(row)
                        except Exception:
                            log.warning("memory.pgvector_not_available", exc_info=True)
                        for stmt in _build_ddl(pgvector_available):
                            try:
                                await conn.execute(stmt)
                            except Exception:
                                log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True)
                except BaseException:
                    # terminate() is synchronous, so it is safe even while
                    # this task is being cancelled.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    async def _has_pgvector(self) -> bool:
        """Report whether the ``vector`` extension is installed.

        A successful lookup is cached for the lifetime of the store. A failed
        lookup returns ``False`` for this call only, so a transient database
        error does not disable vector features permanently.
        """
        if self._pgvector_checked:
            return self._pgvector_available
        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                row = await conn.fetchval(
                    "SELECT 1 FROM pg_extension WHERE extname = 'vector'"
                )
        except Exception:
            log.warning("memory.pgvector_check_failed", exc_info=True)
            return False
        self._pgvector_available = bool(row)
        self._pgvector_checked = True
        return self._pgvector_available

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed one text; return ``None`` if embedding is unavailable or fails."""
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]:
        """Embed a batch of texts.

        Always returns exactly one entry per input text; an entry is ``None``
        when no embedding is available for it. Any backend error, or a
        response whose length does not match the input, yields all ``None``.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return [None] * len(texts)
        try:
            vectors = await self._embedding_backend.embed(
                texts=texts,
                model=settings.embedding_model,
            )
            if len(vectors) != len(texts):
                # Pairing a short/long response with the inputs by position
                # could attach an embedding to the wrong text.
                log.warning(
                    "memory.embed_batch_size_mismatch",
                    expected=len(texts),
                    got=len(vectors),
                )
                return [None] * len(texts)
            return [v if v else None for v in vectors]
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return [None] * len(texts)

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Generate embeddings for facts that have none.

        Facts are visited once each, in ``id`` order, so a fact whose
        embedding cannot be generated is skipped rather than refetched
        forever.

        Args:
            batch_size: Number of facts embedded per backend call.

        Returns:
            The number of facts that received an embedding. ``0`` when no
            embedding backend is configured or pgvector is unavailable.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return 0
        pool = await self._get_pool()
        updated = 0
        last_id: str | None = None
        while True:
            # Connections are held only for the queries themselves, not
            # across the embedding call or the rate-limit sleep.
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    """SELECT id, value FROM memory_facts
                    WHERE embedding IS NULL
                      AND ($1::text IS NULL OR id > $1::text)
                    ORDER BY id
                    LIMIT $2""",
                    last_id, batch_size,
                )
            if not rows:
                break
            last_id = rows[-1]["id"]
            embeddings = await self._generate_embeddings([r["value"] for r in rows])
            args: list[tuple[str, str]] = []
            for row, emb in zip(rows, embeddings):
                vec_str = _vector_to_str(emb) if emb else None
                if vec_str:
                    args.append((vec_str, row["id"]))
            if args:
                async with pool.acquire() as conn:
                    await conn.executemany(
                        "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                        args,
                    )
                updated += len(args)
            if len(rows) < batch_size:
                # A short batch means there is nothing left after it.
                break
            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            await asyncio.sleep(2)
        return updated

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        """Insert a fact, or update the existing one with the same (category, key).

        On update, ``id`` and ``created_at`` of the existing row are kept and
        every other supplied field is overwritten. When vector support is
        active but no embedding could be generated, the stored embedding is
        kept only if the value is unchanged; otherwise it is reset to NULL so
        that it never describes an outdated value.
        """
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        vec_str = _vector_to_str(embedding) if embedding else None
        # Same precondition _generate_embedding uses: only then is the
        # embedding column expected to be in play.
        vectors_enabled = bool(self._embedding_backend) and await self._has_pgvector()
        fact_id = str(uuid.uuid4())
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if vectors_enabled:
                await conn.execute(
                    """INSERT INTO memory_facts
                    (id, category, key, value, created_at, updated_at, source_session_id,
                     embedding, source, confidence, expires_at, source_context)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector, $9, $10, $11, $12)
                    ON CONFLICT(category, key) DO UPDATE SET
                        value = EXCLUDED.value,
                        updated_at = EXCLUDED.updated_at,
                        source_session_id = EXCLUDED.source_session_id,
                        embedding = CASE
                            WHEN EXCLUDED.embedding IS NOT NULL THEN EXCLUDED.embedding
                            WHEN memory_facts.value IS DISTINCT FROM EXCLUDED.value THEN NULL
                            ELSE memory_facts.embedding
                        END,
                        source = EXCLUDED.source,
                        confidence = EXCLUDED.confidence,
                        expires_at = EXCLUDED.expires_at,
                        source_context = EXCLUDED.source_context""",
                    fact_id, category, key, value, now, now,
                    source_session_id, vec_str, source, confidence, expires_at, source_context,
                )
            else:
                await conn.execute(
                    """INSERT INTO memory_facts
                    (id, category, key, value, created_at, updated_at, source_session_id,
                     source, confidence, expires_at, source_context)
                    VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                    ON CONFLICT(category, key) DO UPDATE SET
                        value = EXCLUDED.value,
                        updated_at = EXCLUDED.updated_at,
                        source_session_id = EXCLUDED.source_session_id,
                        source = EXCLUDED.source,
                        confidence = EXCLUDED.confidence,
                        expires_at = EXCLUDED.expires_at,
                        source_context = EXCLUDED.source_context""",
                    fact_id, category, key, value, now, now,
                    source_session_id, source, confidence, expires_at, source_context,
                )

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to ``limit`` non-expired facts relevant to ``query``.

        Strategy, in order:
          1. Vector search (if a backend and pgvector are available); hits
             beyond ``_VECTOR_CUTOFF_DISTANCE`` are dropped. An empty result
             or any error falls through to the text search.
          2. If the query has no usable terms, or the store holds at most
             ``_AUTO_DUMP_THRESHOLD`` facts, return the facts unfiltered.
          3. Facts matching every term (ILIKE on category/key/value).
          4. Facts matching any term, ranked by number of matching terms.
        """
        # 1. Try vector search if pgvector + embedding backend are available
        if self._embedding_backend and await self._has_pgvector():
            query_embedding = await self._generate_embedding(query)
            vec_str = _vector_to_str(query_embedding) if query_embedding else None
            if vec_str:
                try:
                    pool = await self._get_pool()
                    async with pool.acquire() as conn:
                        # Rows without an embedding have a NULL distance,
                        # which cannot be compared to the cutoff; exclude
                        # them in SQL.
                        rows = await conn.fetch(
                            """SELECT id, category, key, value, updated_at,
                                   source, confidence, expires_at, source_context,
                                   embedding <=> $1::vector AS distance
                            FROM memory_facts
                            WHERE embedding IS NOT NULL
                              AND (expires_at IS NULL OR expires_at > now())
                            ORDER BY embedding <=> $1::vector
                            LIMIT $2""",
                            vec_str, limit,
                        )
                    results = [
                        _row_to_dict(r)
                        for r in rows
                        if r["distance"] < _VECTOR_CUTOFF_DISTANCE
                    ]
                    if results:
                        log.debug("memory.vector_search", hits=len(results), query=query[:40])
                        return results
                except Exception:
                    log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            else:
                log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
        else:
            reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector"
            log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40])

        # 2. Text fallback (original ILIKE logic)
        terms = _normalize_query(query)
        if not terms or await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit, include_expired=False)

        # One bind parameter per term, shared by the three column checks.
        # _normalize_query has already stripped "%" and "_", so terms cannot
        # inject LIKE wildcards.
        patterns = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(patterns) + 1)
        ]
        limit_idx = len(patterns) + 1
        columns = (
            "id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context"
        )
        # Same expiry rule as the vector path above.
        not_expired = "(expires_at IS NULL OR expires_at > now())"
        and_where = " AND ".join(term_conds)
        or_where = " OR ".join(term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {columns} FROM memory_facts "
                f"WHERE {not_expired} AND ({and_where}) "
                f"ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *patterns, limit,
            )
            if rows:
                return [_row_to_dict(r) for r in rows]
            # Nothing matched every term: rank by how many terms match.
            score_expr = " + ".join(
                f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
            )
            rows = await conn.fetch(
                f"SELECT {columns}, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {not_expired} AND ({or_where}) "
                f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                *patterns, limit,
            )
            return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts by key (case-insensitive), optionally within one category.

        Returns:
            The number of deleted rows.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # The command status looks like "DELETE <count>".
        return int(result.split()[1])

    async def get_all_facts(
        self,
        limit: int | None = None,
        *,
        include_expired: bool = True,
    ) -> list[dict]:
        """Return facts ordered by category, newest first within a category.

        Args:
            limit: Maximum number of facts; ``None`` or ``0`` means no limit.
            include_expired: When false, facts whose ``expires_at`` lies in
                the past are left out.
        """
        q = (
            "SELECT id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context FROM memory_facts"
        )
        if not include_expired:
            q += " WHERE (expires_at IS NULL OR expires_at > now())"
        q += " ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            q += f" LIMIT ${len(params) + 1}"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the total number of stored facts (expired ones included)."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or ``None`` if none has been set."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store ``content`` as the single summary row (id 1), stamping the current UTC time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                ON CONFLICT(id) DO UPDATE SET
                    content = EXCLUDED.content,
                    generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record the current UTC time as the extraction time for ``session_id``."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return the recorded extraction time for ``session_id`` as an ISO-8601
        string, or ``None`` if the session has no record."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)
def _row_to_dict(row: asyncpg.Record) -> dict:
val = row["updated_at"]
updated_at = val.isoformat() if hasattr(val, "isoformat") else str(val)
expires_at = row.get("expires_at")
expires_str = expires_at.isoformat() if hasattr(expires_at, "isoformat") else None
return {
"id": row["id"],
"category": row["category"],
"key": row["key"],
"value": row["value"],
"updated_at": updated_at,
"source": row.get("source", "conversation"),
"confidence": row.get("confidence", 70),
"expires_at": expires_str,
"source_context": row.get("source_context", ""),
}