"""Fact CRUD and search — vector + ILIKE fallback."""
import re
import uuid
from datetime import datetime, timezone
import asyncpg
import structlog
from navi.config import settings
from ._embeddings import _vector_to_str
# Module-level structured logger used by the search paths below.
log = structlog.get_logger()
# Separator characters in search queries (hyphen, underscore, slash, backslash,
# dot); _normalize_query replaces each with a space before tokenizing.
_SEPARATORS = re.compile(r"[-_/\\.]")
# Any remaining character that is neither a word character nor whitespace;
# _normalize_query also replaces these with spaces.
_NOISE = re.compile(r"[^\w\s]", re.UNICODE)
# Text-fallback shortcut in search_facts: when the caller has at most this many
# facts (per fact_count), all of them are returned instead of an ILIKE search.
_AUTO_DUMP_THRESHOLD = 60
# Vector hits are kept only when their `<=>` distance (pgvector's cosine
# distance operator) is strictly below this value.
_VECTOR_CUTOFF_DISTANCE = 0.3
def _normalize_query(query: str) -> list[str]:
    """Split a free-text query into lowercase search terms.

    Separator characters (``-``, ``_``, ``/``, ``\\``, ``.``) and every other
    character that is neither a word character nor whitespace are turned into
    spaces. The result is lowercased and split on whitespace, and tokens
    shorter than two characters are discarded.
    """
    cleaned = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    tokens = cleaned.lower().split()
    # Single characters are too unselective to be useful ILIKE terms.
    return list(filter(lambda token: len(token) >= 2, tokens))
class FactMixin:
    """Fact storage operations.

    Expected on the composite class:
        async _get_pool() -> asyncpg.Pool
        async _generate_embedding(text) -> list[float] | None
        async _has_pgvector() -> bool
        _embedding_backend: truthy when an embedding backend is configured
            (checked by ``search_facts`` before attempting vector search)
    """

    async def upsert_fact(
        self,
        category: str,
        key: str,
        value: str,
        user_id: str | None = None,
        source_session_id: str | None = None,
        source: str = "conversation",
        confidence: int = 70,
        expires_at: datetime | None = None,
        source_context: str = "",
    ) -> None:
        """Insert a fact, or update the existing row for ``(user_id, category, key)``.

        An embedding of ``value`` is requested first. When one is produced it is
        written to the ``embedding`` column (cast to ``vector``). Otherwise the
        statement does not mention that column at all.

        Args:
            category: Fact category; part of the upsert conflict target.
            key: Fact key; part of the upsert conflict target.
            value: Fact text. This is also the text that gets embedded.
            user_id: Owner of the fact, or ``None`` for facts without a user.
            source_session_id: Session the fact originated from, if any.
            source: Origin label stored with the fact.
            confidence: Confidence value stored with the fact.
            expires_at: Optional expiry. The vector search path skips rows
                whose expiry has passed.
            source_context: Free-form context stored with the fact.
        """
        now = datetime.now(timezone.utc)
        embedding = await self._generate_embedding(value)
        vec_str = _vector_to_str(embedding) if embedding else None

        # Columns, placeholders and arguments are derived from the same lists so
        # the VALUES arity can never drift from the column list again.
        columns = [
            "id", "user_id", "category", "key", "value", "created_at", "updated_at",
            "source_session_id", "source", "confidence", "expires_at", "source_context",
        ]
        args: list = [
            str(uuid.uuid4()), user_id, category, key, value, now, now,
            source_session_id, source, confidence, expires_at, source_context,
        ]
        placeholders = [f"${i}" for i in range(1, len(columns) + 1)]
        if vec_str:
            columns.append("embedding")
            args.append(vec_str)
            placeholders.append(f"${len(columns)}::vector")
        # NOTE(review): without an embedding, an existing row keeps its previous
        # embedding, which was computed from the old value.

        # On conflict, refresh everything except the identity columns and created_at.
        refreshed = [
            col for col in columns
            if col not in ("id", "user_id", "category", "key", "created_at")
        ]
        set_clause = ", ".join(f"{col} = EXCLUDED.{col}" for col in refreshed)
        # NOTE(review): unless the unique index on (user_id, category, key) is
        # NULLS NOT DISTINCT (PostgreSQL 15+), rows with user_id NULL never
        # conflict, so each call inserts a new row. Confirm against the schema.
        sql = (
            f"INSERT INTO memory_facts ({', '.join(columns)}) "
            f"VALUES ({', '.join(placeholders)}) "
            f"ON CONFLICT (user_id, category, key) DO UPDATE SET {set_clause}"
        )
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(sql, *args)

    async def search_facts(self, query: str, user_id: str | None = None, limit: int = 15) -> list[dict]:
        """Return facts relevant to ``query`` for ``user_id`` (``None`` = facts without a user).

        Vector search is tried first. If it is unavailable, fails, or finds no
        hit under ``_VECTOR_CUTOFF_DISTANCE``, the ILIKE text search is used
        instead.
        """
        results = await self._fact_vector_search(query, user_id, limit)
        if results:
            return results
        return await self._fact_text_search(query, user_id, limit)

    async def _fact_vector_search(self, query: str, user_id: str | None, limit: int) -> list[dict]:
        """Run a pgvector nearest-neighbour search.

        Returns ``[]`` when search is unavailable, fails, or has no close hit.
        """
        if not self._embedding_backend:
            log.warning("memory.vector_fallback_to_ilike", reason="no_embedding_backend", query=query[:40])
            return []
        if not await self._has_pgvector():
            log.warning("memory.vector_fallback_to_ilike", reason="no_pgvector", query=query[:40])
            return []

        query_embedding = await self._generate_embedding(query)
        vec_str = _vector_to_str(query_embedding) if query_embedding else None
        if not vec_str:
            log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
            return []

        # $3 is only referenced (and only bound) when there is a user_id.
        # Always referencing it made asyncpg reject the NULL-user case.
        if user_id is None:
            user_clause, user_args = "user_id IS NULL", ()
        else:
            user_clause, user_args = "user_id = $3", (user_id,)

        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    f"""SELECT id, category, key, value, updated_at,
                              source, confidence, expires_at, source_context,
                              embedding <=> $1::vector AS distance
                       FROM memory_facts
                       WHERE embedding IS NOT NULL
                         AND (expires_at IS NULL OR expires_at > now())
                         AND {user_clause}
                       ORDER BY embedding <=> $1::vector
                       LIMIT $2""",
                    vec_str, limit, *user_args,
                )
        except Exception:
            # Best effort: any database failure falls back to text search.
            log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
            return []

        # `embedding IS NOT NULL` above guarantees a non-NULL distance here;
        # a NULL would make this comparison raise TypeError.
        results = [_row_to_dict(r) for r in rows if r["distance"] < _VECTOR_CUTOFF_DISTANCE]
        if results:
            log.debug("memory.vector_search", hits=len(results), query=query[:40])
        return results

    async def _fact_text_search(self, query: str, user_id: str | None, limit: int) -> list[dict]:
        """Run the ILIKE search: all-terms match first, then any-term match ranked by term hits."""
        terms = _normalize_query(query)
        if not terms:
            return await self.get_all_facts(user_id=user_id, limit=limit)
        # Small stores: return everything rather than risk missing a fact.
        if await self.fact_count(user_id=user_id) <= _AUTO_DUMP_THRESHOLD:
            return await self.get_all_facts(user_id=user_id, limit=limit)
        # NOTE(review): unlike the vector path, neither this search nor
        # get_all_facts filters out expired facts. Confirm whether that is intended.

        # _normalize_query strips '%', '_' and '\', so terms contain no LIKE
        # metacharacters. Each term binds one placeholder, reused for all three columns.
        params: list = [f"%{term}%" for term in terms]
        term_conds = [
            f"(category ILIKE ${i} OR key ILIKE ${i} OR value ILIKE ${i})"
            for i in range(1, len(params) + 1)
        ]
        if user_id is None:
            user_filter = "user_id IS NULL"
        else:
            params.append(user_id)
            user_filter = f"user_id = ${len(params)}"
        params.append(limit)
        limit_ref = f"${len(params)}"

        columns = "id, category, key, value, updated_at, source, confidence, expires_at, source_context"
        and_where = " AND ".join(term_conds)
        or_where = " OR ".join(term_conds)
        score_expr = " + ".join(f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds)

        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Pass 1: facts matching every term, newest first.
            rows = await conn.fetch(
                f"SELECT {columns} FROM memory_facts "
                f"WHERE {user_filter} AND {and_where} "
                f"ORDER BY updated_at DESC LIMIT {limit_ref}",
                *params,
            )
            if rows:
                return [_row_to_dict(r) for r in rows]
            # Pass 2: facts matching any term, ranked by how many terms match.
            rows = await conn.fetch(
                f"SELECT {columns}, ({score_expr}) AS score "
                f"FROM memory_facts WHERE {user_filter} AND ({or_where}) "
                f"ORDER BY score DESC, updated_at DESC LIMIT {limit_ref}",
                *params,
            )
        return [_row_to_dict(r) for r in rows]

    async def delete_fact(self, key: str, category: str | None = None, user_id: str | None = None) -> int:
        """Delete facts whose key matches ``key`` case-insensitively.

        If ``category`` is non-empty, it must also match case-insensitively.
        ``user_id=None`` targets only facts without a user.

        Returns:
            The number of rows deleted, parsed from the command tag (e.g. ``"DELETE 2"``).
        """
        conditions = ["LOWER(key)=LOWER($1)"]
        args: list = [key]
        if category:
            args.append(category)
            conditions.append(f"LOWER(category)=LOWER(${len(args)})")
        if user_id is None:
            conditions.append("user_id IS NULL")
        else:
            args.append(user_id)
            conditions.append(f"user_id = ${len(args)}")
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                "DELETE FROM memory_facts WHERE " + " AND ".join(conditions), *args
            )
        return int(result.split()[1])

    async def get_all_facts(self, user_id: str | None = None, limit: int | None = None) -> list[dict]:
        """Return every fact for ``user_id``, ordered by category, then newest first.

        ``user_id=None`` selects facts without a user. A falsy ``limit``
        (``None`` or ``0``) means no limit.
        """
        q = (
            "SELECT id, category, key, value, updated_at, source, confidence, "
            "expires_at, source_context FROM memory_facts "
        )
        params: list = []
        if user_id is None:
            q += "WHERE user_id IS NULL "
        else:
            q += "WHERE user_id = $1 "
            params.append(user_id)
        q += "ORDER BY category, updated_at DESC"
        if limit:
            q += f" LIMIT ${len(params) + 1}"
            params.append(limit)
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(q, *params)
        return [_row_to_dict(r) for r in rows]

    async def fact_count(self, user_id: str | None = None) -> int:
        """Count facts for ``user_id`` (``None`` = facts without a user), including expired ones."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            if user_id is None:
                return await conn.fetchval("SELECT COUNT(*) FROM memory_facts WHERE user_id IS NULL") or 0
            return await conn.fetchval("SELECT COUNT(*) FROM memory_facts WHERE user_id = $1", user_id) or 0
def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a ``memory_facts`` row into a plain dict.

    ``updated_at`` is ISO-formatted when it has ``isoformat``, otherwise it is
    ``str()``-ed. ``expires_at`` is ISO-formatted when it has ``isoformat``,
    otherwise it becomes ``None``. ``source``, ``confidence`` and
    ``source_context`` fall back to defaults when the row has no such key.
    """
    raw_updated = row["updated_at"]
    raw_expires = row.get("expires_at")
    fact = {
        "id": row["id"],
        "category": row["category"],
        "key": row["key"],
        "value": row["value"],
    }
    if hasattr(raw_updated, "isoformat"):
        fact["updated_at"] = raw_updated.isoformat()
    else:
        fact["updated_at"] = str(raw_updated)
    fact["source"] = row.get("source", "conversation")
    fact["confidence"] = row.get("confidence", 70)
    fact["expires_at"] = raw_expires.isoformat() if hasattr(raw_expires, "isoformat") else None
    fact["source_context"] = row.get("source_context", "")
    return fact