Newer
Older
navi-1 / navi / memory / _embeddings.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 29 Apr 3 KB Split memory/store.py into focused mixins
"""Embedding generation and backfill — depends on pgvector + an LLM backend."""

import asyncio
import math
from typing import TYPE_CHECKING

import structlog

from navi.config import settings

if TYPE_CHECKING:
    from navi.llm.base import LLMBackend

# Module-level structlog logger used by the embedding helpers in this file.
log = structlog.get_logger()

# NOTE(review): not referenced anywhere in this module; presumably a pgvector
# distance cutoff consumed by similarity-search code elsewhere in the memory
# package — confirm before changing or removing.
_VECTOR_CUTOFF_DISTANCE = 0.3


def _vector_to_str(vec: list[float]) -> str | None:
    """Serialize a vector to the PostgreSQL vector literal format.

    Returns None if the vector contains NaN, Infinity, or is empty,
    since pgvector rejects those values.
    """
    if not vec or not all(math.isfinite(v) for v in vec):
        return None
    return "[" + ",".join(str(v) for v in vec) + "]"


class EmbeddingMixin:
    """Provides embedding generation and pgvector helpers.

    Expected on the composite class:
        _embedding_backend: LLMBackend | None
        _pgvector_checked: bool
        _pgvector_available: bool
        _get_pool() -> asyncpg.Pool
    """

    # Declared (not assigned) so type checkers know about the state the
    # composite class is expected to provide; no class attributes are created.
    _embedding_backend: "LLMBackend | None"
    _pgvector_checked: bool
    _pgvector_available: bool

    async def _has_pgvector(self) -> bool:
        """Return whether the ``vector`` extension is installed in the database.

        A completed probe is cached via ``_pgvector_checked``. A probe that
        raises (e.g. the pool cannot connect yet) is logged and reported as
        unavailable but is not cached, so a transient error at startup does not
        disable embeddings for the rest of the process lifetime.
        """
        if self._pgvector_checked:
            return self._pgvector_available
        try:
            pool = await self._get_pool()
            async with pool.acquire() as conn:
                row = await conn.fetchval(
                    "SELECT 1 FROM pg_extension WHERE extname = 'vector'"
                )
        except Exception:
            log.warning("memory.pgvector_check_failed", exc_info=True)
            self._pgvector_available = False
            return False
        self._pgvector_available = bool(row)
        self._pgvector_checked = True
        return self._pgvector_available

    async def _generate_embedding(self, text: str) -> list[float] | None:
        """Embed a single text with ``settings.embedding_model``.

        Returns None when no embedding backend is configured, pgvector is not
        available, the backend returns an empty result, or the backend call
        raises (the failure is logged, not propagated).
        """
        if not self._embedding_backend or not await self._has_pgvector():
            return None
        try:
            vectors = await self._embedding_backend.embed(
                texts=[text],
                model=settings.embedding_model,
            )
            if vectors and vectors[0]:
                return vectors[0]
        except Exception:
            log.warning("memory.embed_failed", text=text[:60], exc_info=True)
        return None

    async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]:
        """Embed a batch of texts; the result has exactly one entry per input.

        An entry is None when the backend returned an empty vector for that
        text. Every entry is None when no backend or pgvector is available,
        when the backend call raises (logged), or when the backend returns a
        different number of vectors than texts — in that case results can no
        longer be matched to inputs by position, so none are trusted.
        """
        if not texts:
            # Nothing to embed; avoid a pointless remote call.
            return []
        if not self._embedding_backend or not await self._has_pgvector():
            return [None] * len(texts)
        try:
            vectors = await self._embedding_backend.embed(
                texts=texts,
                model=settings.embedding_model,
            )
            results = [v if v else None for v in vectors]
        except Exception:
            log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True)
            return [None] * len(texts)
        if len(results) != len(texts):
            log.warning(
                "memory.embed_batch_size_mismatch",
                expected=len(texts),
                got=len(results),
            )
            return [None] * len(texts)
        return results

    async def backfill_embeddings(self, batch_size: int = 8) -> int:
        """Populate ``memory_facts.embedding`` for rows where it is NULL.

        Rows are processed ``batch_size`` at a time. A row whose embedding
        cannot be produced (backend error, empty or non-finite vector) is
        excluded from later batches of this call. The run stops early when a
        whole batch yields nothing, leaving the remaining rows for a future
        call. Without this, failed rows would stay NULL and be re-selected
        forever.

        Args:
            batch_size: Number of rows fetched and embedded per round trip.

        Returns:
            Number of rows whose embedding was written.
        """
        if not self._embedding_backend or not await self._has_pgvector():
            # No way to produce or store vectors; nothing can be backfilled.
            return 0
        pool = await self._get_pool()
        updated = 0
        skipped: list[object] = []  # ids that failed to embed during this call
        while True:
            # Connections are acquired per DB step so a pooled connection is
            # not held while waiting on the remote embed call or the sleep.
            async with pool.acquire() as conn:
                rows = await conn.fetch(
                    "SELECT id, value FROM memory_facts"
                    " WHERE embedding IS NULL AND NOT (id = ANY($2))"
                    " LIMIT $1",
                    batch_size,
                    skipped,
                )
            if not rows:
                break
            embeddings = await self._generate_embeddings([r["value"] for r in rows])
            args: list[tuple[str, object]] = []
            for row, emb in zip(rows, embeddings):
                vec_str = _vector_to_str(emb) if emb else None
                if vec_str:
                    args.append((vec_str, row["id"]))
                else:
                    skipped.append(row["id"])
            if not args:
                # Whole batch failed (likely backend outage); stop instead of
                # sweeping the rest of the table with failing calls.
                log.warning("memory.backfill_stalled", batch=len(rows), updated=updated)
                break
            async with pool.acquire() as conn:
                await conn.executemany(
                    "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2",
                    args,
                )
            updated += len(args)
            # Rate-limit against Ollama Cloud (or any remote embed endpoint)
            if len(rows) == batch_size:
                await asyncio.sleep(2)
        return updated