"""Fact CRUD and search — vector + ILIKE fallback."""
import re
import uuid
from datetime import datetime, timezone
import asyncpg
import structlog
from navi.config import settings
from ._embeddings import _vector_to_str
# Module-level structlog logger used by the fact operations in this file.
log = structlog.get_logger()
# Text-search fallback: when the caller has at most this many facts,
# search_facts skips term matching and returns every fact instead.
_AUTO_DUMP_THRESHOLD: int = 60
# Vector search keeps only hits whose distance is strictly below this value.
_VECTOR_CUTOFF_DISTANCE: float = 0.3
# Characters that split words in a query: hyphen, underscore, slash, backslash, dot.
_SEPARATORS: re.Pattern[str] = re.compile(r"[-_/\\.]")
# Any remaining character that is neither a word character nor whitespace.
_NOISE: re.Pattern[str] = re.compile(r"[^\w\s]", re.UNICODE)
def _normalize_query(query: str) -> list[str]:
    """Split *query* into lowercase search terms.

    Separator characters (hyphen, underscore, slash, backslash, dot) and any
    other non-word, non-space character are turned into spaces, then the text
    is lowercased and split on whitespace. Terms shorter than two characters
    are dropped. Because ``%``, ``_`` and backslash never survive, the terms
    contain no LIKE metacharacters.
    """
    spaced = _NOISE.sub(" ", _SEPARATORS.sub(" ", query))
    tokens = spaced.lower().split()
    return [tok for tok in tokens if len(tok) >= 2]
class FactMixin:
"""Fact storage operations.
Expected on the composite class:
_get_pool() -> asyncpg.Pool
_generate_embedding(text) -> list[float] | None
_has_pgvector() -> bool
"""
async def upsert_fact(
self,
category: str,
key: str,
value: str,
user_id: str | None = None,
source_session_id: str | None = None,
source: str = "conversation",
confidence: int = 70,
expires_at: datetime | None = None,
source_context: str = "",
) -> None:
now = datetime.now(timezone.utc)
embedding = await self._generate_embedding(value)
vec_str = _vector_to_str(embedding) if embedding else None
pool = await self._get_pool()
async with pool.acquire() as conn:
if vec_str:
await conn.execute(
"""INSERT INTO memory_facts
(id, user_id, category, key, value, created_at, updated_at, source_session_id,
embedding, source, confidence, expires_at, source_context)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9::vector, $10, $11, $12, $13)
ON CONFLICT(user_id, category, key) DO UPDATE SET
value = EXCLUDED.value,
updated_at = EXCLUDED.updated_at,
source_session_id = EXCLUDED.source_session_id,
embedding = EXCLUDED.embedding,
source = EXCLUDED.source,
confidence = EXCLUDED.confidence,
expires_at = EXCLUDED.expires_at,
source_context = EXCLUDED.source_context""",
str(uuid.uuid4()), user_id, category, key, value, now, now,
source_session_id, vec_str, source, confidence, expires_at, source_context,
)
else:
await conn.execute(
"""INSERT INTO memory_facts
(id, user_id, category, key, value, created_at, updated_at, source_session_id,
source, confidence, expires_at, source_context)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT(user_id, category, key) DO UPDATE SET
value = EXCLUDED.value,
updated_at = EXCLUDED.updated_at,
source_session_id = EXCLUDED.source_session_id,
source = EXCLUDED.source,
confidence = EXCLUDED.confidence,
expires_at = EXCLUDED.expires_at,
source_context = EXCLUDED.source_context""",
str(uuid.uuid4()), user_id, category, key, value, now, now,
source_session_id, source, confidence, expires_at, source_context,
)
async def search_facts(self, query: str, user_id: str | None = None, limit: int = 15) -> list[dict]:
# 1. Try vector search if pgvector + embedding backend are available
if self._embedding_backend and await self._has_pgvector():
query_embedding = await self._generate_embedding(query)
vec_str = _vector_to_str(query_embedding) if query_embedding else None
if vec_str:
try:
pool = await self._get_pool()
async with pool.acquire() as conn:
if user_id is None:
rows = await conn.fetch(
"""SELECT id, category, key, value, updated_at,
source, confidence, expires_at, source_context,
embedding <=> $1::vector AS distance
FROM memory_facts
WHERE (expires_at IS NULL OR expires_at > now())
AND user_id IS NULL
ORDER BY embedding <=> $1::vector
LIMIT $2""",
vec_str, limit,
)
else:
rows = await conn.fetch(
"""SELECT id, category, key, value, updated_at,
source, confidence, expires_at, source_context,
embedding <=> $1::vector AS distance
FROM memory_facts
WHERE (expires_at IS NULL OR expires_at > now())
AND user_id = $3
ORDER BY embedding <=> $1::vector
LIMIT $2""",
vec_str, limit, user_id,
)
results = [
_row_to_dict(r)
for r in rows
if r["distance"] < _VECTOR_CUTOFF_DISTANCE
]
if results:
log.debug("memory.vector_search", hits=len(results), query=query[:40])
return results
except Exception:
log.warning("memory.vector_search_failed", query=query[:40], exc_info=True)
else:
log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40])
else:
reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector"
log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40])
# 2. Text fallback (original ILIKE logic)
terms = _normalize_query(query)
if not terms:
return await self.get_all_facts(user_id=user_id, limit=limit)
if await self.fact_count(user_id=user_id) <= _AUTO_DUMP_THRESHOLD:
return await self.get_all_facts(user_id=user_id, limit=limit)
base_params: list = []
term_conds: list[str] = []
for term in terms:
like = f"%{term}%"
i = len(base_params) + 1
term_conds.append(
f"(category ILIKE ${i} OR key ILIKE ${i + 1} OR value ILIKE ${i + 2})"
)
base_params.extend([like, like, like])
user_filter = "user_id IS NULL" if user_id is None else f"user_id = ${len(base_params) + 1}"
if user_id is not None:
base_params.append(user_id)
limit_idx = len(base_params) + 1
and_where = " AND ".join(term_conds)
or_where = " OR ".join(term_conds)
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(
f"SELECT id, category, key, value, updated_at, source, confidence, "
f"expires_at, source_context FROM memory_facts "
f"WHERE {user_filter} AND {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}",
*base_params, limit,
)
if rows:
return [_row_to_dict(r) for r in rows]
score_expr = " + ".join(
f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds
)
rows = await conn.fetch(
f"SELECT id, category, key, value, updated_at, source, confidence, "
f"expires_at, source_context, ({score_expr}) AS score "
f"FROM memory_facts WHERE {user_filter} AND ({or_where}) "
f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}",
*base_params, limit,
)
return [_row_to_dict(r) for r in rows]
async def delete_fact(self, key: str, category: str | None = None, user_id: str | None = None) -> int:
pool = await self._get_pool()
async with pool.acquire() as conn:
if category:
if user_id is None:
result = await conn.execute(
"DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2) AND user_id IS NULL",
key, category,
)
else:
result = await conn.execute(
"DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2) AND user_id = $3",
key, category, user_id,
)
else:
if user_id is None:
result = await conn.execute(
"DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND user_id IS NULL", key
)
else:
result = await conn.execute(
"DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND user_id = $2", key, user_id
)
return int(result.split()[1])
async def get_all_facts(
self,
user_id: str | None = None,
limit: int | None = None,
offset: int = 0,
search: str | None = None,
sort_by: str = "category",
sort_order: str = "desc",
all_users: bool = False,
) -> list[dict]:
q = (
"SELECT id, category, key, value, updated_at, source, confidence, "
"expires_at, source_context FROM memory_facts "
)
params: list = []
param_idx = 0
def add_param(value):
nonlocal param_idx
param_idx += 1
params.append(value)
return f"${param_idx}"
conditions = []
if not all_users:
if user_id is None:
conditions.append("user_id IS NULL")
else:
conditions.append(f"user_id = {add_param(user_id)}")
if search:
like = f"%{search}%"
conditions.append(
f"(key ILIKE {add_param(like)} OR value ILIKE {add_param(like)} OR category ILIKE {add_param(like)})"
)
if conditions:
q += "WHERE " + " AND ".join(conditions) + " "
allowed_cols = {"category", "key", "updated_at", "confidence", "source"}
col = sort_by if sort_by in allowed_cols else "category"
order = "DESC" if sort_order == "desc" else "ASC"
if col == "category":
q += f"ORDER BY {col} {order}, updated_at DESC"
else:
q += f"ORDER BY {col} {order}"
if limit is not None:
q += f" LIMIT {add_param(limit)}"
if offset:
q += f" OFFSET {add_param(offset)}"
pool = await self._get_pool()
async with pool.acquire() as conn:
rows = await conn.fetch(q, *params)
return [_row_to_dict(r) for r in rows]
async def fact_count(
self,
user_id: str | None = None,
all_users: bool = False,
search: str | None = None,
) -> int:
q = "SELECT COUNT(*) FROM memory_facts"
params: list = []
conditions = []
if not all_users:
if user_id is None:
conditions.append("user_id IS NULL")
else:
conditions.append(f"user_id = ${len(params) + 1}")
params.append(user_id)
if search:
like = f"%{search}%"
conditions.append(
f"(key ILIKE ${len(params) + 1} OR value ILIKE ${len(params) + 2} OR category ILIKE ${len(params) + 3})"
)
params.extend([like, like, like])
if conditions:
q += " WHERE " + " AND ".join(conditions)
pool = await self._get_pool()
async with pool.acquire() as conn:
return await conn.fetchval(q, *params) or 0
async def get_categories(self, user_id: str | None = None) -> list[str]:
pool = await self._get_pool()
async with pool.acquire() as conn:
if user_id is None:
rows = await conn.fetch(
"SELECT DISTINCT category FROM memory_facts WHERE user_id IS NULL ORDER BY category"
)
else:
rows = await conn.fetch(
"SELECT DISTINCT category FROM memory_facts WHERE user_id = $1 ORDER BY category",
user_id,
)
return [r["category"] for r in rows]
def _row_to_dict(row: asyncpg.Record) -> dict:
    """Convert a ``memory_facts`` row into a plain dict.

    Timestamps become ISO-8601 strings. A non-datetime ``updated_at`` is
    stringified, and a missing or non-datetime ``expires_at`` becomes
    ``None``. Optional columns absent from the row fall back to the
    defaults ``"conversation"``, ``70`` and ``""``.
    """
    out = {name: row[name] for name in ("id", "category", "key", "value")}
    stamp = row["updated_at"]
    out["updated_at"] = stamp.isoformat() if hasattr(stamp, "isoformat") else str(stamp)
    out["source"] = row.get("source", "conversation")
    out["confidence"] = row.get("confidence", 70)
    expiry = row.get("expires_at")
    out["expires_at"] = expiry.isoformat() if hasattr(expiry, "isoformat") else None
    out["source_context"] = row.get("source_context", "")
    return out