"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support."""
import asyncio
import re
import uuid
from datetime import datetime, timezone, timedelta
from typing import TYPE_CHECKING
import asyncpg
import structlog
from navi.config import settings
if TYPE_CHECKING:
from navi.llm.base import LLMBackend
# Characters treated as word separators in search queries: hyphen, underscore,
# slash, backslash and dot. _normalize_query() replaces each with a space.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Anything that is neither a word character nor whitespace; _normalize_query()
# replaces each match with a space after the separators have been handled.
_NOISE = re.compile(r"[^\w\s]", re.UNICODE)
# Module-wide structlog logger.
log = structlog.get_logger()
# When the store holds at most this many facts, the text-search fallback
# returns all facts (up to the caller's limit) instead of filtering by term.
_AUTO_DUMP_THRESHOLD = 60
# Vector-search hits are kept only when their `<=>` distance to the query
# embedding is strictly below this value.
_VECTOR_CUTOFF_DISTANCE = 0.3
def _normalize_query(query: str) -> list[str]:
    """Split a free-text query into lowercase search terms.

    Separator characters and any remaining punctuation are turned into
    spaces before the text is split on whitespace.  Tokens of a single
    character are discarded.
    """
    spaced = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    terms: list[str] = []
    for token in spaced.lower().split():
        if len(token) >= 2:
            terms.append(token)
    return terms
def _build_ddl(pgvector_available: bool) -> list[str]:
"""Return DDL statements depending on whether pgvector is installed."""
embedding_col = f"embedding vector({settings.embedding_dimensions})," if pgvector_available else ""
embedding_idx = (
"CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding ON memory_facts USING hnsw (embedding vector_cosine_ops)"
if pgvector_available else ""
)
stmts = [
"""CREATE TABLE IF NOT EXISTS memory_facts (
id TEXT PRIMARY KEY,
category TEXT NOT NULL,
key TEXT NOT NULL,
value TEXT NOT NULL,
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL,
source_session_id TEXT,
%s
source TEXT NOT NULL DEFAULT 'conversation',
confidence SMALLINT NOT NULL DEFAULT 70,
expires_at TIMESTAMPTZ,
last_verified_at TIMESTAMPTZ,
source_context TEXT,
UNIQUE(category, key)
)""" % embedding_col,
"CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL",
"CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)",
"""CREATE TABLE IF NOT EXISTS memory_summary (
id INTEGER PRIMARY KEY DEFAULT 1,
content TEXT NOT NULL,
generated_at TIMESTAMPTZ NOT NULL
)""",
"""CREATE TABLE IF NOT EXISTS session_memory_state (
session_id TEXT PRIMARY KEY,
extracted_at TIMESTAMPTZ NOT NULL
)""",
]
if embedding_idx:
stmts.insert(1, embedding_idx)
return stmts
def _vector_to_str(vec: list[float]) -> str:
"""Serialize a vector to the PostgreSQL vector literal format."""
return "[" + ",".join(str(v) for v in vec) + "]"
class MemoryStore:
    """Async store for user facts, one summary row and per-session extraction marks.

    Data lives in PostgreSQL and is reached through a lazily created asyncpg
    pool.  When the ``vector`` extension is installed and an embedding backend
    is configured, facts also carry an embedding used for similarity search;
    otherwise searches fall back to ILIKE matching.
    """

    # Column list shared by every query whose rows go through _row_to_dict().
    _FACT_COLUMNS = (
        "id, category, key, value, updated_at, source, confidence, "
        "expires_at, source_context"
    )
    # Predicate selecting facts that have not expired yet.
    _NOT_EXPIRED = "(expires_at IS NULL OR expires_at > now())"

    def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None:
        """Remember the connection settings; no I/O happens until first use.

        Args:
            dsn: Connection string handed to ``asyncpg.create_pool``.
            embedding_backend: Optional backend whose ``embed`` coroutine
                produces vectors; without it only text search is used.
        """
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        self._lock = asyncio.Lock()
        self._embedding_backend = embedding_backend
        self._pgvector_checked = False
        self._pgvector_available = False

    def set_embedding_backend(self, backend: "LLMBackend | None") -> None:
        """Replace (or remove, with ``None``) the embedding backend."""
        self._embedding_backend = backend

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and the schema on first call."""
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished setup
            # while this one was waiting.
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        await self._ensure_schema(conn)
                except BaseException:
                    # Do not leak the freshly opened connections when setup
                    # fails before the pool could be stored on the instance.
                    pool.terminate()
                    raise
                self._pool = pool
        return self._pool

    async def _ensure_schema(self, conn) -> None:
        """Best-effort creation of the pgvector extension and all tables/indexes."""
        pgvector_available = False
        try:
            await conn.execute("CREATE EXTENSION IF NOT EXISTS vector")
            row = await conn.fetchval("SELECT 1 FROM pg_extension WHERE extname = 'vector'")
            pgvector_available = bool(row)
            # Reuse the probe result so _has_pgvector() agrees with the DDL
            # that is about to run and does not need a second query.
            self._pgvector_available = pgvector_available
            self._pgvector_checked = True
        except Exception:
            log.warning("memory.pgvector_not_available", exc_info=True)
        for stmt in _build_ddl(pgvector_available):
            try:
                await conn.execute(stmt)
            except Exception:
                # A failing statement must not prevent the remaining ones
                # from running.
                log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True)

    async def _has_pgvector(self) -> bool:
        """Return whether the ``vector`` extension is installed (cached after first check).

        A failed check is logged and cached as "not available".
        """
        if self._pgvector_checked:
            return self._pgvector_available
        try:
            pool = await self._get_pool()
            if not self._pgvector_checked:
                async with pool.acquire() as conn:
                    row = await conn.fetchval(
                        "SELECT 1 FROM pg_extension WHERE extname = 'vector'"
                    )
                self._pgvector_available = bool(row)
        except Exception:
            log.warning("memory.pgvector_check_failed", exc_info=True)
            self._pgvector_available = False
        self._pgvector_checked = True
        return self._pgvector_available

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed a single text; return ``None`` when embedding is unavailable or fails."""
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]:
        """Embed several texts in one backend call.

        Returns:
            A list with exactly one entry per input text; an entry is
            ``None`` when no vector could be produced for that text.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return [None] * len(texts)
        try:
            vectors = await self._embedding_backend.embed(
                texts=texts,
                model=settings.embedding_model,
            )
            if len(vectors) != len(texts):
                # Vectors could not be matched to texts reliably; treat the
                # whole batch as failed rather than mis-assign embeddings.
                log.warning(
                    "memory.embed_batch_failed",
                    count=len(texts),
                    returned=len(vectors),
                )
                return [None] * len(texts)
            return [v if v else None for v in vectors]
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return [None] * len(texts)

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Generate embeddings for facts that do not have one yet.

        Facts are visited once per call in ``id`` order, ``batch_size`` at a
        time.  Rows whose embedding cannot be generated are skipped and stay
        eligible for a later call, so a failing backend cannot make this loop
        forever.

        Args:
            batch_size: Number of facts embedded per backend call.

        Returns:
            The number of facts whose embedding was written.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            # Nothing can be embedded or stored; avoid touching the database.
            return 0
        pool = await self._get_pool()
        updated = 0
        last_id: str | None = None
        while True:
            # Keyset pagination: only look at ids after the last one seen so
            # rows that failed to embed are never fetched again in this run.
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT id, value FROM memory_facts "
                    "WHERE embedding IS NULL AND ($1::text IS NULL OR id > $1::text) "
                    "ORDER BY id LIMIT $2",
                    last_id, batch_size,
                )
            if not rows:
                break
            last_id = rows[-1]["id"]
            # The connection is released while the embed call runs.
            embeddings = await self._generate_embeddings([r["value"] for r in rows])
            pending = [
                (_vector_to_str(emb), row["id"])
                for row, emb in zip(rows, embeddings)
                if emb
            ]
            if pending:
                async with pool.acquire() as conn:
                    for vec_str, fact_id in pending:
                        await conn.execute(
                            "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                            vec_str, fact_id,
                        )
                        updated += 1
            if len(rows) < batch_size:
                # Short page: there is nothing left after it.
                break
            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            await asyncio.sleep(2)
        return updated

    # ── Facts ────────────────────────────────────────────────────────────────

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        """Insert a fact, or overwrite the existing one with the same (category, key).

        On conflict the row keeps its ``id`` and ``created_at``; every other
        supplied column is replaced.  When an embedding can be generated for
        ``value`` it is stored alongside the fact.
        """
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        # Embedding was expected (backend + pgvector) but could not be
        # produced: an embedding left over from a previous value would then
        # describe text the row no longer contains.
        clear_stale = (
            embedding is None
            and bool(self._embedding_backend)
            and await self._has_pgvector()
        )
        fact_id = str(uuid.uuid4())
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if embedding:
                await conn.execute(
                    """INSERT INTO memory_facts
                           (id, category, key, value, created_at, updated_at, source_session_id,
                            embedding, source, confidence, expires_at, source_context)
                       VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector, $9, $10, $11, $12)
                       ON CONFLICT(category, key) DO UPDATE SET
                           value = EXCLUDED.value,
                           updated_at = EXCLUDED.updated_at,
                           source_session_id = EXCLUDED.source_session_id,
                           embedding = EXCLUDED.embedding,
                           source = EXCLUDED.source,
                           confidence = EXCLUDED.confidence,
                           expires_at = EXCLUDED.expires_at,
                           source_context = EXCLUDED.source_context""",
                    fact_id, category, key, value, now, now,
                    source_session_id, _vector_to_str(embedding), source, confidence,
                    expires_at, source_context,
                )
                return
            # Drop the old embedding only when the value really changed, so a
            # transient embed failure does not discard a still-valid vector.
            stale_clause = (
                """embedding = CASE
                               WHEN memory_facts.value IS DISTINCT FROM EXCLUDED.value THEN NULL
                               ELSE memory_facts.embedding
                           END,"""
                if clear_stale
                else ""
            )
            await conn.execute(
                f"""INSERT INTO memory_facts
                       (id, category, key, value, created_at, updated_at, source_session_id,
                        source, confidence, expires_at, source_context)
                   VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)
                   ON CONFLICT(category, key) DO UPDATE SET
                       {stale_clause}
                       value = EXCLUDED.value,
                       updated_at = EXCLUDED.updated_at,
                       source_session_id = EXCLUDED.source_session_id,
                       source = EXCLUDED.source,
                       confidence = EXCLUDED.confidence,
                       expires_at = EXCLUDED.expires_at,
                       source_context = EXCLUDED.source_context""",
                fact_id, category, key, value, now, now,
                source_session_id, source, confidence, expires_at, source_context,
            )

    async def search_facts(self, query: str, limit: int = 15) -> list[dict]:
        """Return up to ``limit`` non-expired facts relevant to ``query``.

        Vector similarity is tried first; when it is unavailable, fails, or
        yields nothing close enough, an ILIKE-based text search is used.
        """
        hits = await self._vector_search(query, limit)
        if hits:
            return hits
        return await self._text_search(query, limit)

    async def _vector_search(self, query: str, limit: int) -> list[dict]:
        """Similarity search; returns ``[]`` when unavailable, failing or without close hits."""
        if not self._embedding_backend or not await self._has_pgvector():
            return []
        try:
            query_embedding = await self._generate_embedding(query)
            if not query_embedding:
                return []
            vec_str = _vector_to_str(query_embedding)
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                # Rows without an embedding are excluded: their distance
                # would be NULL and cannot be compared with the cutoff.
                rows = await conn.fetch(
                    f"""SELECT {self._FACT_COLUMNS},
                               embedding <=> $1::vector AS distance
                        FROM memory_facts
                        WHERE embedding IS NOT NULL
                          AND {self._NOT_EXPIRED}
                        ORDER BY embedding <=> $1::vector
                        LIMIT $2""",
                    vec_str, limit,
                )
            hits = [
                _row_to_dict(r)
                for r in rows
                if r["distance"] < _VECTOR_CUTOFF_DISTANCE
            ]
        except Exception:
            log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            return []
        if hits:
            log.debug("memory.vector_search", hits=len(hits), query=query[:40])
        return hits

    async def _text_search(self, query: str, limit: int) -> list[dict]:
        """ILIKE search: facts matching every term, else ranked by number of matching terms."""
        terms = _normalize_query(query)
        # With no usable terms, or a small store, hand back everything.
        if not terms or await self.fact_count() <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(limit=limit, include_expired=False)

        params = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        ]
        limit_idx = len(params) + 1
        all_terms = " AND ".join(term_conds)
        any_term = " OR ".join(term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                f"SELECT {self._FACT_COLUMNS} FROM memory_facts "
                f"WHERE {self._NOT_EXPIRED} AND {all_terms} "
                f"ORDER BY updated_at DESC LIMIT ${limit_idx}",
                *params, limit,
            )
            if rows:
                return [_row_to_dict(r) for r in rows]
            # Nothing matched every term: rank by how many terms matched.
            score_expr = " + ".join(
                f"CASE WHEN {cond} THEN 1 ELSE 0 END" for cond in term_conds
            )
            rows = await conn.fetch(
                f"SELECT {self._FACT_COLUMNS}, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {self._NOT_EXPIRED} AND ({any_term}) "
                f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
                *params, limit,
            )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None) -> int:
        """Delete facts by key (case-insensitive), optionally restricted to a category.

        Returns:
            The number of rows deleted.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if category:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)",
                    key, category,
                )
            else:
                result = await conn.execute(
                    "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key
                )
        # The command status has the form "DELETE <count>".
        return int(result.split()[1])

    async def get_all_facts(
        self,
        limit: int | None = None,
        *,
        include_expired: bool = True,
    ) -> list[dict]:
        """Return facts ordered by category, newest first within each category.

        Args:
            limit: Maximum number of rows; ``None`` or ``0`` means no limit.
            include_expired: When false, facts whose ``expires_at`` has
                passed are left out.
        """
        q = f"SELECT {self._FACT_COLUMNS} FROM memory_facts"
        if not include_expired:
            q += f" WHERE {self._NOT_EXPIRED}"
        q += " ORDER BY category, updated_at DESC"
        params: list = []
        if limit:
            q += f" LIMIT ${len(params) + 1}"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self) -> int:
        """Return the total number of stored facts (expired ones included)."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0

    # ── Summary ───────────────────────────────────────────────────────────────

    async def get_summary(self) -> str | None:
        """Return the stored summary text, or ``None`` when none has been saved."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1")

    async def set_summary(self, content: str) -> None:
        """Store ``content`` as the single summary row, stamping the current UTC time."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2)
                   ON CONFLICT(id) DO UPDATE SET
                       content = EXCLUDED.content,
                       generated_at = EXCLUDED.generated_at""",
                content, now,
            )

    # ── Session extraction tracking ───────────────────────────────────────────

    async def mark_session_extracted(self, session_id: str) -> None:
        """Record the current UTC time as the extraction time of ``session_id``."""
        now = datetime.now(timezone.utc)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2)
                   ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""",
                session_id, now,
            )

    async def get_extracted_at(self, session_id: str) -> str | None:
        """Return the recorded extraction time of ``session_id`` as a string.

        Returns:
            The ISO-8601 form of the timestamp (or ``str()`` of the stored
            value when it has no ``isoformat``), or ``None`` when the session
            has no record.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id
            )
        if row is None:
            return None
        val = row["extracted_at"]
        return val.isoformat() if hasattr(val, "isoformat") else str(val)
def _row_to_dict(row: asyncpg.Record) -> dict:
val = row["updated_at"]
updated_at = val.isoformat() if hasattr(val, "isoformat") else str(val)
expires_at = row.get("expires_at")
expires_str = expires_at.isoformat() if hasattr(expires_at, "isoformat") else None
return {
"id": row["id"],
"category": row["category"],
"key": row["key"],
"value": row["value"],
"updated_at": updated_at,
"source": row.get("source", "conversation"),
"confidence": row.get("confidence", 70),
"expires_at": expires_str,
"source_context": row.get("source_context", ""),
}