"""Embedding generation and backfill — depends on pgvector + an LLM backend."""
import asyncio
import math
from typing import TYPE_CHECKING
import structlog
from navi.config import settings
if TYPE_CHECKING:
from navi.llm.base import LLMBackend
# NOTE(review): not referenced in this part of the file; by its name presumably a
# maximum vector distance used for similarity search elsewhere — confirm with callers.
_VECTOR_CUTOFF_DISTANCE: float = 0.3
# Module-wide structured logger shared by the helpers below.
log = structlog.get_logger()
def _vector_to_str(vec: list[float]) -> str | None:
"""Serialize a vector to the PostgreSQL vector literal format.
Returns None if the vector contains NaN, Infinity, or is empty,
since pgvector rejects those values.
"""
if not vec or not all(math.isfinite(v) for v in vec):
return None
return "[" + ",".join(str(v) for v in vec) + "]"
class EmbeddingMixin:
    """Provides embedding generation and pgvector helpers.

    Expected on the composite class:
        _embedding_backend: LLMBackend | None
        _pgvector_checked: bool
        _pgvector_available: bool
        _get_pool() -> asyncpg.Pool
    """

    async def _has_pgvector(self) -> bool:
        """Return whether the ``vector`` extension is installed in the database.

        A successful lookup is cached on the instance via ``_pgvector_checked`` /
        ``_pgvector_available``. A failed lookup (pool or query error) is logged
        and reported as False for this call only. It is deliberately not cached,
        so a transient database outage does not disable embeddings for the rest
        of the process lifetime.
        """
        if self._pgvector_checked:
            return self._pgvector_available
        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                row = await conn.fetchval(
                    "SELECT 1 FROM pg_extension WHERE extname = 'vector'"
                )
        except Exception:
            log.warning("memory.pgvector_check_failed", exc_info=True)
            self._pgvector_available = False
            return False
        self._pgvector_available = bool(row)
        self._pgvector_checked = True
        return self._pgvector_available

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed a single text with the configured backend.

        Returns None when any of the following holds:
        - no backend is configured;
        - pgvector is unavailable;
        - the backend raised (logged as ``memory.embed_failed``);
        - the backend returned no vector or an empty one.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            # Best-effort: a failed embedding must not break the caller.
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(
        self, texts: list[str]
    ) -> list[list[float] | None]:
        """Embed a batch of texts, returning one entry per input text.

        The result always has ``len(texts)`` entries aligned with ``texts``. An
        entry is None when that text produced no usable (non-empty) vector.

        Every entry is None when any of the following holds:
        - no backend is configured;
        - pgvector is unavailable;
        - the backend raised;
        - the backend returned a different number of vectors than texts, in
          which case the alignment cannot be trusted.
        """
        if not texts:
            return []
        if not self._embedding_backend or not await self._has_pgvector():
            return [None] * len(texts)
        try:
            vectors = await self._embedding_backend.embed(
                texts=texts,
                model=settings.embedding_model,
            )
            results = [v if v else None for v in vectors]
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return [None] * len(texts)
        if len(results) != len(texts):
            # Callers zip results with their inputs; a short or long reply
            # would silently pair vectors with the wrong rows.
            log.warning(
                "memory.embed_batch_size_mismatch",
                expected=len(texts),
                received=len(results),
            )
            return [None] * len(texts)
        return results

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Fill ``memory_facts.embedding`` for rows where it is NULL.

        Rows are walked in ascending ``id`` order with keyset pagination
        (``id >`` the last id seen), so each row is visited at most once per
        call. Rows whose embedding cannot be produced are left NULL and skipped
        rather than re-selected. Causes include a backend error, an empty
        vector, or non-finite components. Re-selecting them by
        ``embedding IS NULL`` alone would make a permanently failing row spin
        this loop forever. NULL rows that appear with an id below the cursor
        during the run are left for a later call.

        A pool connection is held only around each SELECT / UPDATE. It is not
        held across the (possibly remote, slow) embedding call or the
        rate-limit sleep.

        Args:
            batch_size: Rows fetched and embedded per round trip.

        Returns:
            Number of rows an UPDATE was issued for.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            # No embeddings can be produced; scanning the table would be useless.
            return 0

        first_page_sql = (
            "SELECT id, value FROM memory_facts "
            "WHERE embedding IS NULL ORDER BY id LIMIT $1"
        )
        next_page_sql = (
            "SELECT id, value FROM memory_facts "
            "WHERE embedding IS NULL AND id > $2 ORDER BY id LIMIT $1"
        )

        pool = await self._get_pool()
        updated = 0
        last_id = None
        have_cursor = False
        while True:
            async with pool.acquire() as conn:
                if have_cursor:
                    rows = await conn.fetch(next_page_sql, batch_size, last_id)
                else:
                    rows = await conn.fetch(first_page_sql, batch_size)
            if not rows:
                break
            # Advance the cursor before embedding so failed rows are never re-read.
            last_id = rows[-1]["id"]
            have_cursor = True

            ids = [r["id"] for r in rows]
            texts = [r["value"] for r in rows]
            embeddings = await self._generate_embeddings(texts)

            args: list[tuple[str, str]] = []
            for fact_id, emb in zip(ids, embeddings):
                vec_str = _vector_to_str(emb) if emb else None
                if vec_str:
                    args.append((vec_str, fact_id))

            skipped = len(rows) - len(args)
            if skipped:
                log.warning("memory.backfill_skipped", count=skipped)

            if args:
                async with pool.acquire() as conn:
                    await conn.executemany(
                        "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                        args,
                    )
                updated += len(args)

            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            if len(rows) == batch_size:
                await asyncio.sleep(2)
        return updated