diff --git a/navi/memory/_ddl.py b/navi/memory/_ddl.py new file mode 100644 index 0000000..65b3461 --- /dev/null +++ b/navi/memory/_ddl.py @@ -0,0 +1,54 @@ +"""DDL for memory tables — conditional on pgvector/pg_trgm availability.""" + +from navi.config import settings + + +def _build_ddl(pgvector_available: bool) -> list[str]: + """Return DDL statements depending on whether pgvector is installed.""" + embedding_col = f"embedding vector({settings.embedding_dimensions})," if pgvector_available else "" + embedding_idx = ( + "CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding ON memory_facts USING hnsw (embedding vector_cosine_ops)" + if pgvector_available else "" + ) + stmts = [ + """CREATE TABLE IF NOT EXISTS memory_facts ( + id TEXT PRIMARY KEY, + category TEXT NOT NULL, + key TEXT NOT NULL, + value TEXT NOT NULL, + created_at TIMESTAMPTZ NOT NULL, + updated_at TIMESTAMPTZ NOT NULL, + source_session_id TEXT, + %s + source TEXT NOT NULL DEFAULT 'conversation', + confidence SMALLINT NOT NULL DEFAULT 70, + expires_at TIMESTAMPTZ, + last_verified_at TIMESTAMPTZ, + source_context TEXT, + UNIQUE(category, key) + )""" % embedding_col, + "CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL", + "CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)", + # GIN trigram indexes — only if pg_trgm extension is already installed. + # CREATE EXTENSION requires superuser/CREATE privilege, so we skip it here. + """DO $$ + BEGIN + IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN + CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin (category gin_trgm_ops); + CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops); + CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops); + END IF; + END $$;""", + """CREATE TABLE IF NOT EXISTS memory_summary ( + id INTEGER PRIMARY KEY DEFAULT 1, + content TEXT NOT NULL, + generated_at TIMESTAMPTZ NOT NULL + )""", + """CREATE TABLE IF NOT EXISTS session_memory_state ( + session_id TEXT PRIMARY KEY, + extracted_at TIMESTAMPTZ NOT NULL + )""", + ] + if embedding_idx: + stmts.insert(2, embedding_idx) + return stmts diff --git a/navi/memory/_embeddings.py b/navi/memory/_embeddings.py new file mode 100644 index 0000000..8a0e14f --- /dev/null +++ b/navi/memory/_embeddings.py @@ -0,0 +1,111 @@ +"""Embedding generation and backfill — depends on pgvector + an LLM backend.""" + +import asyncio +import math +from typing import TYPE_CHECKING + +import structlog + +from navi.config import settings + +if TYPE_CHECKING: + from navi.llm.base import LLMBackend + +log = structlog.get_logger() + +_VECTOR_CUTOFF_DISTANCE = 0.3 + + +def _vector_to_str(vec: list[float]) -> str | None: + """Serialize a vector to the PostgreSQL vector literal format. + + Returns None if the vector contains NaN, Infinity, or is empty, + since pgvector rejects those values. + """ + if not vec or not all(math.isfinite(v) for v in vec): + return None + return "[" + ",".join(str(v) for v in vec) + "]" + + +class EmbeddingMixin: + """Provides embedding generation and pgvector helpers. + + Expected on the composite class: + _embedding_backend: LLMBackend | None + _pgvector_checked: bool + _pgvector_available: bool + _get_pool() -> asyncpg.Pool + """ + + async def _has_pgvector(self) -> bool: + if self._pgvector_checked: + return self._pgvector_available + try: + pool = await self._get_pool() + async with pool.acquire() as conn: + row = await conn.fetchval( + "SELECT 1 FROM pg_extension WHERE extname = 'vector'" + ) + self._pgvector_available = bool(row) + except Exception: + self._pgvector_available = False + self._pgvector_checked = True + return self._pgvector_available + + async def _generate_embedding(self, text: str) -> list[float] | None: + if not self._embedding_backend or not await self._has_pgvector(): + return None + try: + vectors = await self._embedding_backend.embed( + texts=[text], + model=settings.embedding_model, + ) + if vectors and vectors[0]: + return vectors[0] + except Exception: + log.warning("memory.embed_failed", text=text[:60], exc_info=True) + return None + + async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]: + if not self._embedding_backend or not await self._has_pgvector(): + return [None] * len(texts) + try: + vectors = await self._embedding_backend.embed( + texts=texts, + model=settings.embedding_model, + ) + return [v if v else None for v in vectors] + except Exception: + log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True) + return [None] * len(texts) + + async def backfill_embeddings(self, batch_size: int = 8) -> int: + pool = await self._get_pool() + updated = 0 + async with pool.acquire() as conn: + while True: + rows = await conn.fetch( + "SELECT id, value FROM memory_facts WHERE embedding IS NULL LIMIT $1", + batch_size, + ) + if not rows: + break + ids = [r["id"] for r in rows] + texts = [r["value"] for r in rows] + embeddings = await self._generate_embeddings(texts) + args: list[tuple[str, str]] = [] + for fact_id, emb in zip(ids, embeddings): + if emb: + vec_str = _vector_to_str(emb) + if vec_str: + args.append((vec_str, fact_id)) + if args: + await conn.executemany( + "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2", + args, + ) + updated += len(args) + # Rate-limit against Ollama Cloud (or any remote embed endpoint) + if len(rows) == batch_size: + await asyncio.sleep(2) + return updated diff --git a/navi/memory/_facts.py b/navi/memory/_facts.py new file mode 100644 index 0000000..28aa69d --- /dev/null +++ b/navi/memory/_facts.py @@ -0,0 +1,218 @@ +"""Fact CRUD and search — vector + ILIKE fallback.""" + +import re +import uuid +from datetime import datetime, timezone + +import asyncpg +import structlog + +from navi.config import settings + +from ._embeddings import _vector_to_str + +log = structlog.get_logger() + +_SEPARATORS = re.compile(r"[-_/\\.]") +_NOISE = re.compile(r"[^\w\s]", re.UNICODE) +_AUTO_DUMP_THRESHOLD = 60 +_VECTOR_CUTOFF_DISTANCE = 0.3 + + +def _normalize_query(query: str) -> list[str]: + q = _SEPARATORS.sub(" ", query) + q = _NOISE.sub(" ", q) + return [t for t in q.lower().split() if len(t) > 1] + + +class FactMixin: + """Fact storage operations. + + Expected on the composite class: + _get_pool() -> asyncpg.Pool + _generate_embedding(text) -> list[float] | None + _has_pgvector() -> bool + """ + + async def upsert_fact( + self, + category: str, + key: str, + value: str, + source_session_id: str | None = None, + source: str = "conversation", + confidence: int = 70, + expires_at: datetime | None = None, + source_context: str = "", + ) -> None: + now = datetime.now(timezone.utc) + embedding = await self._generate_embedding(value) + vec_str = _vector_to_str(embedding) if embedding else None + pool = await self._get_pool() + async with pool.acquire() as conn: + if vec_str: + await conn.execute( + """INSERT INTO memory_facts + (id, category, key, value, created_at, updated_at, source_session_id, + embedding, source, confidence, expires_at, source_context) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector, $9, $10, $11, $12) + ON CONFLICT(category, key) DO UPDATE SET + value = EXCLUDED.value, + updated_at = EXCLUDED.updated_at, + source_session_id = EXCLUDED.source_session_id, + embedding = EXCLUDED.embedding, + source = EXCLUDED.source, + confidence = EXCLUDED.confidence, + expires_at = EXCLUDED.expires_at, + source_context = EXCLUDED.source_context""", + str(uuid.uuid4()), category, key, value, now, now, + source_session_id, vec_str, source, confidence, expires_at, source_context, + ) + else: + await conn.execute( + """INSERT INTO memory_facts + (id, category, key, value, created_at, updated_at, source_session_id, + source, confidence, expires_at, source_context) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) + ON CONFLICT(category, key) DO UPDATE SET + value = EXCLUDED.value, + updated_at = EXCLUDED.updated_at, + source_session_id = EXCLUDED.source_session_id, + source = EXCLUDED.source, + confidence = EXCLUDED.confidence, + expires_at = EXCLUDED.expires_at, + source_context = EXCLUDED.source_context""", + str(uuid.uuid4()), category, key, value, now, now, + source_session_id, source, confidence, expires_at, source_context, + ) + + async def search_facts(self, query: str, limit: int = 15) -> list[dict]: + # 1. Try vector search if pgvector + embedding backend are available + if self._embedding_backend and await self._has_pgvector(): + query_embedding = await self._generate_embedding(query) + vec_str = _vector_to_str(query_embedding) if query_embedding else None + if vec_str: + try: + pool = await self._get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + """SELECT id, category, key, value, updated_at, + source, confidence, expires_at, source_context, + embedding <=> $1::vector AS distance + FROM memory_facts + WHERE (expires_at IS NULL OR expires_at > now()) + ORDER BY embedding <=> $1::vector + LIMIT $2""", + vec_str, limit, + ) + results = [ + _row_to_dict(r) + for r in rows + if r["distance"] < _VECTOR_CUTOFF_DISTANCE + ] + if results: + log.debug("memory.vector_search", hits=len(results), query=query[:40]) + return results + except Exception: + log.warning("memory.vector_search_failed", query=query[:40], exc_info=True) + else: + log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40]) + else: + reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector" + log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40]) + + # 2. Text fallback (original ILIKE logic) + terms = _normalize_query(query) + if not terms: + return await self.get_all_facts(limit=limit) + + if await self.fact_count() <= _AUTO_DUMP_THRESHOLD: + return await self.get_all_facts(limit=limit) + + base_params: list = [] + term_conds: list[str] = [] + for term in terms: + like = f"%{term}%" + i = len(base_params) + 1 + term_conds.append( + f"(category ILIKE ${i} OR key ILIKE ${i + 1} OR value ILIKE ${i + 2})" + ) + base_params.extend([like, like, like]) + + limit_idx = len(base_params) + 1 + and_where = " AND ".join(term_conds) + or_where = " OR ".join(term_conds) + + pool = await self._get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch( + f"SELECT id, category, key, value, updated_at, source, confidence, " + f"expires_at, source_context FROM memory_facts " + f"WHERE {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}", + *base_params, limit, + ) + if rows: + return [_row_to_dict(r) for r in rows] + + score_expr = " + ".join( + f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds + ) + rows = await conn.fetch( + f"SELECT id, category, key, value, updated_at, source, confidence, " + f"expires_at, source_context, ({score_expr}) AS score " + f"FROM memory_facts WHERE {or_where} " + f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}", + *base_params, limit, + ) + return [_row_to_dict(r) for r in rows] + + async def delete_fact(self, key: str, category: str | None = None) -> int: + pool = await self._get_pool() + async with pool.acquire() as conn: + if category: + result = await conn.execute( + "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)", + key, category, + ) + else: + result = await conn.execute( + "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key + ) + return int(result.split()[1]) + + async def get_all_facts(self, limit: int | None = None) -> list[dict]: + q = ( + "SELECT id, category, key, value, updated_at, source, confidence, " + "expires_at, source_context FROM memory_facts ORDER BY category, updated_at DESC" + ) + params: list = [] + if limit: + q += f" LIMIT ${len(params) + 1}" + params.append(limit) + pool = await self._get_pool() + async with pool.acquire() as conn: + rows = await conn.fetch(q, *params) + return [_row_to_dict(r) for r in rows] + + async def fact_count(self) -> int: + pool = await self._get_pool() + async with pool.acquire() as conn: + return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0 + + +def _row_to_dict(row: asyncpg.Record) -> dict: + val = row["updated_at"] + updated_at = val.isoformat() if hasattr(val, "isoformat") else str(val) + expires_at = row.get("expires_at") + expires_str = expires_at.isoformat() if hasattr(expires_at, "isoformat") else None + return { + "id": row["id"], + "category": row["category"], + "key": row["key"], + "value": row["value"], + "updated_at": updated_at, + "source": row.get("source", "conversation"), + "confidence": row.get("confidence", 70), + "expires_at": expires_str, + "source_context": row.get("source_context", ""), + } diff --git a/navi/memory/_session_state.py b/navi/memory/_session_state.py new file mode 100644 index 0000000..380f5d1 --- /dev/null +++ b/navi/memory/_session_state.py @@ -0,0 +1,32 @@ +"""Per-session extraction tracking — avoids re-processing the same session.""" + +from datetime import datetime, timezone + + +class SessionStateMixin: + """Session extraction state. + + Expected on the composite class: + _get_pool() -> asyncpg.Pool + """ + + async def mark_session_extracted(self, session_id: str) -> None: + now = datetime.now(timezone.utc) + pool = await self._get_pool() + async with pool.acquire() as conn: + await conn.execute( + """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2) + ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""", + session_id, now, + ) + + async def get_extracted_at(self, session_id: str) -> str | None: + pool = await self._get_pool() + async with pool.acquire() as conn: + row = await conn.fetchrow( + "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id + ) + if row is None: + return None + val = row["extracted_at"] + return val.isoformat() if hasattr(val, "isoformat") else str(val) diff --git a/navi/memory/_summary.py b/navi/memory/_summary.py new file mode 100644 index 0000000..c7b80a7 --- /dev/null +++ b/navi/memory/_summary.py @@ -0,0 +1,28 @@ +"""Conversation summary persistence — single-row table.""" + +from datetime import datetime, timezone + + +class SummaryMixin: + """Summary storage operations. + + Expected on the composite class: + _get_pool() -> asyncpg.Pool + """ + + async def get_summary(self) -> str | None: + pool = await self._get_pool() + async with pool.acquire() as conn: + return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1") + + async def set_summary(self, content: str) -> None: + now = datetime.now(timezone.utc) + pool = await self._get_pool() + async with pool.acquire() as conn: + await conn.execute( + """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2) + ON CONFLICT(id) DO UPDATE SET + content = EXCLUDED.content, + generated_at = EXCLUDED.generated_at""", + content, now, + ) diff --git a/navi/memory/store.py b/navi/memory/store.py index 61bc26f..90e6172 100644 --- a/navi/memory/store.py +++ b/navi/memory/store.py @@ -1,10 +1,13 @@ -"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support.""" +"""Persistent memory store — facts about the user, backed by PostgreSQL with pgvector support. + +MemoryStore is a composite of mixins: + - EmbeddingMixin (pgvector checks, embedding generation, backfill) + - FactMixin (CRUD + search with vector/ILIKE fallback) + - SummaryMixin (conversation summary) + - SessionStateMixin (per-session extraction tracking) +""" import asyncio -import math -import re -import uuid -from datetime import datetime, timezone, timedelta from typing import TYPE_CHECKING import asyncpg @@ -12,87 +15,19 @@ from navi.config import settings +from ._ddl import _build_ddl +from ._embeddings import EmbeddingMixin +from ._facts import FactMixin +from ._session_state import SessionStateMixin +from ._summary import SummaryMixin + if TYPE_CHECKING: from navi.llm.base import LLMBackend -_SEPARATORS = re.compile(r"[-_/\\.]") -_NOISE = re.compile(r"[^\w\s]", re.UNICODE) - log = structlog.get_logger() -_AUTO_DUMP_THRESHOLD = 60 -_VECTOR_CUTOFF_DISTANCE = 0.3 - -def _normalize_query(query: str) -> list[str]: - q = _SEPARATORS.sub(" ", query) - q = _NOISE.sub(" ", q) - return [t for t in q.lower().split() if len(t) > 1] - - -def _build_ddl(pgvector_available: bool) -> list[str]: - """Return DDL statements depending on whether pgvector is installed.""" - embedding_col = f"embedding vector({settings.embedding_dimensions})," if pgvector_available else "" - embedding_idx = ( - "CREATE INDEX IF NOT EXISTS idx_memory_facts_embedding ON memory_facts USING hnsw (embedding vector_cosine_ops)" - if pgvector_available else "" - ) - stmts = [ - """CREATE TABLE IF NOT EXISTS memory_facts ( - id TEXT PRIMARY KEY, - category TEXT NOT NULL, - key TEXT NOT NULL, - value TEXT NOT NULL, - created_at TIMESTAMPTZ NOT NULL, - updated_at TIMESTAMPTZ NOT NULL, - source_session_id TEXT, - %s - source TEXT NOT NULL DEFAULT 'conversation', - confidence SMALLINT NOT NULL DEFAULT 70, - expires_at TIMESTAMPTZ, - last_verified_at TIMESTAMPTZ, - source_context TEXT, - UNIQUE(category, key) - )""" % embedding_col, - "CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL", - "CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)", - # GIN trigram indexes — only if pg_trgm extension is already installed. - # CREATE EXTENSION requires superuser/CREATE privilege, so we skip it here. - """DO $$ - BEGIN - IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN - CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin (category gin_trgm_ops); - CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops); - CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops); - END IF; - END $$;""", - """CREATE TABLE IF NOT EXISTS memory_summary ( - id INTEGER PRIMARY KEY DEFAULT 1, - content TEXT NOT NULL, - generated_at TIMESTAMPTZ NOT NULL - )""", - """CREATE TABLE IF NOT EXISTS session_memory_state ( - session_id TEXT PRIMARY KEY, - extracted_at TIMESTAMPTZ NOT NULL - )""", - ] - if embedding_idx: - stmts.insert(2, embedding_idx) - return stmts - - -def _vector_to_str(vec: list[float]) -> str | None: - """Serialize a vector to the PostgreSQL vector literal format. - - Returns None if the vector contains NaN, Infinity, or is empty, - since pgvector rejects those values. - """ - if not vec or not all(math.isfinite(v) for v in vec): - return None - return "[" + ",".join(str(v) for v in vec) + "]" - - -class MemoryStore: +class MemoryStore(EmbeddingMixin, FactMixin, SummaryMixin, SessionStateMixin): def __init__(self, dsn: str, embedding_backend: "LLMBackend | None" = None) -> None: self._dsn = dsn self._pool: asyncpg.Pool | None = None @@ -111,7 +46,6 @@ if self._pool is None: pool = await asyncpg.create_pool(self._dsn) async with pool.acquire() as conn: - # Try to create pgvector extension first pgvector_available = False try: await conn.execute("CREATE EXTENSION IF NOT EXISTS vector") @@ -127,303 +61,3 @@ log.warning("memory.ddl_failed", stmt=stmt[:80], exc_info=True) self._pool = pool return self._pool - - async def _has_pgvector(self) -> bool: - if self._pgvector_checked: - return self._pgvector_available - try: - pool = await self._get_pool() - async with pool.acquire() as conn: - row = await conn.fetchval( - "SELECT 1 FROM pg_extension WHERE extname = 'vector'" - ) - self._pgvector_available = bool(row) - except Exception: - self._pgvector_available = False - self._pgvector_checked = True - return self._pgvector_available - - async def _generate_embedding(self, text: str) -> list[float] | None: - if not self._embedding_backend or not await self._has_pgvector(): - return None - try: - vectors = await self._embedding_backend.embed( - texts=[text], - model=settings.embedding_model, - ) - if vectors and vectors[0]: - return vectors[0] - except Exception: - log.warning("memory.embed_failed", text=text[:60], exc_info=True) - return None - - async def _generate_embeddings(self, texts: list[str]) -> list[list[float] | None]: - if not self._embedding_backend or not await self._has_pgvector(): - return [None] * len(texts) - try: - vectors = await self._embedding_backend.embed( - texts=texts, - model=settings.embedding_model, - ) - return [v if v else None for v in vectors] - except Exception: - log.warning("memory.embed_batch_failed", count=len(texts), exc_info=True) - return [None] * len(texts) - - async def backfill_embeddings(self, batch_size: int = 8) -> int: - pool = await self._get_pool() - updated = 0 - async with pool.acquire() as conn: - while True: - rows = await conn.fetch( - "SELECT id, value FROM memory_facts WHERE embedding IS NULL LIMIT $1", - batch_size, - ) - if not rows: - break - ids = [r["id"] for r in rows] - texts = [r["value"] for r in rows] - embeddings = await self._generate_embeddings(texts) - args: list[tuple[str, str]] = [] - for fact_id, emb in zip(ids, embeddings): - if emb: - vec_str = _vector_to_str(emb) - if vec_str: - args.append((vec_str, fact_id)) - if args: - await conn.executemany( - "UPDATE memory_facts SET embedding = $1::vector WHERE id = $2", - args, - ) - updated += len(args) - # Rate-limit against Ollama Cloud (or any remote embed endpoint) - if len(rows) == batch_size: - await asyncio.sleep(2) - return updated - - # ── Facts ──────────────────────────────────────────────────────────────── - - async def upsert_fact( - self, - category: str, - key: str, - value: str, - source_session_id: str | None = None, - source: str = "conversation", - confidence: int = 70, - expires_at: datetime | None = None, - source_context: str = "", - ) -> None: - now = datetime.now(timezone.utc) - embedding = await self._generate_embedding(value) - vec_str = _vector_to_str(embedding) if embedding else None - pool = await self._get_pool() - async with pool.acquire() as conn: - if vec_str: - await conn.execute( - """INSERT INTO memory_facts - (id, category, key, value, created_at, updated_at, source_session_id, - embedding, source, confidence, expires_at, source_context) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8::vector, $9, $10, $11, $12) - ON CONFLICT(category, key) DO UPDATE SET - value = EXCLUDED.value, - updated_at = EXCLUDED.updated_at, - source_session_id = EXCLUDED.source_session_id, - embedding = EXCLUDED.embedding, - source = EXCLUDED.source, - confidence = EXCLUDED.confidence, - expires_at = EXCLUDED.expires_at, - source_context = EXCLUDED.source_context""", - str(uuid.uuid4()), category, key, value, now, now, - source_session_id, vec_str, source, confidence, expires_at, source_context, - ) - else: - await conn.execute( - """INSERT INTO memory_facts - (id, category, key, value, created_at, updated_at, source_session_id, - source, confidence, expires_at, source_context) - VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) - ON CONFLICT(category, key) DO UPDATE SET - value = EXCLUDED.value, - updated_at = EXCLUDED.updated_at, - source_session_id = EXCLUDED.source_session_id, - source = EXCLUDED.source, - confidence = EXCLUDED.confidence, - expires_at = EXCLUDED.expires_at, - source_context = EXCLUDED.source_context""", - str(uuid.uuid4()), category, key, value, now, now, - source_session_id, source, confidence, expires_at, source_context, - ) - - async def search_facts(self, query: str, limit: int = 15) -> list[dict]: - # 1. Try vector search if pgvector + embedding backend are available - if self._embedding_backend and await self._has_pgvector(): - query_embedding = await self._generate_embedding(query) - vec_str = _vector_to_str(query_embedding) if query_embedding else None - if vec_str: - try: - pool = await self._get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - """SELECT id, category, key, value, updated_at, - source, confidence, expires_at, source_context, - embedding <=> $1::vector AS distance - FROM memory_facts - WHERE (expires_at IS NULL OR expires_at > now()) - ORDER BY embedding <=> $1::vector - LIMIT $2""", - vec_str, limit, - ) - results = [ - _row_to_dict(r) - for r in rows - if r["distance"] < _VECTOR_CUTOFF_DISTANCE - ] - if results: - log.debug("memory.vector_search", hits=len(results), query=query[:40]) - return results - except Exception: - log.warning("memory.vector_search_failed", query=query[:40], exc_info=True) - else: - log.warning("memory.vector_fallback_to_ilike", reason="embed_failed", query=query[:40]) - else: - reason = "no_embedding_backend" if not self._embedding_backend else "no_pgvector" - log.warning("memory.vector_fallback_to_ilike", reason=reason, query=query[:40]) - - # 2. Text fallback (original ILIKE logic) - terms = _normalize_query(query) - if not terms: - return await self.get_all_facts(limit=limit) - - if await self.fact_count() <= _AUTO_DUMP_THRESHOLD: - return await self.get_all_facts(limit=limit) - - base_params: list = [] - term_conds: list[str] = [] - for term in terms: - like = f"%{term}%" - i = len(base_params) + 1 - term_conds.append( - f"(category ILIKE ${i} OR key ILIKE ${i + 1} OR value ILIKE ${i + 2})" - ) - base_params.extend([like, like, like]) - - limit_idx = len(base_params) + 1 - and_where = " AND ".join(term_conds) - or_where = " OR ".join(term_conds) - - pool = await self._get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch( - f"SELECT id, category, key, value, updated_at, source, confidence, " - f"expires_at, source_context FROM memory_facts " - f"WHERE {and_where} ORDER BY updated_at DESC LIMIT ${limit_idx}", - *base_params, limit, - ) - if rows: - return [_row_to_dict(r) for r in rows] - - score_expr = " + ".join( - f"CASE WHEN {c} THEN 1 ELSE 0 END" for c in term_conds - ) - rows = await conn.fetch( - f"SELECT id, category, key, value, updated_at, source, confidence, " - f"expires_at, source_context, ({score_expr}) AS score " - f"FROM memory_facts WHERE {or_where} " - f"ORDER BY score DESC, updated_at DESC LIMIT ${limit_idx}", - *base_params, limit, - ) - return [_row_to_dict(r) for r in rows] - - async def delete_fact(self, key: str, category: str | None = None) -> int: - pool = await self._get_pool() - async with pool.acquire() as conn: - if category: - result = await conn.execute( - "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1) AND LOWER(category)=LOWER($2)", - key, category, - ) - else: - result = await conn.execute( - "DELETE FROM memory_facts WHERE LOWER(key)=LOWER($1)", key - ) - return int(result.split()[1]) - - async def get_all_facts(self, limit: int | None = None) -> list[dict]: - q = ( - "SELECT id, category, key, value, updated_at, source, confidence, " - "expires_at, source_context FROM memory_facts ORDER BY category, updated_at DESC" - ) - params: list = [] - if limit: - q += f" LIMIT ${len(params) + 1}" - params.append(limit) - pool = await self._get_pool() - async with pool.acquire() as conn: - rows = await conn.fetch(q, *params) - return [_row_to_dict(r) for r in rows] - - async def fact_count(self) -> int: - pool = await self._get_pool() - async with pool.acquire() as conn: - return await conn.fetchval("SELECT COUNT(*) FROM memory_facts") or 0 - - # ── Summary ─────────────────────────────────────────────────────────────── - - async def get_summary(self) -> str | None: - pool = await self._get_pool() - async with pool.acquire() as conn: - return await conn.fetchval("SELECT content FROM memory_summary WHERE id=1") - - async def set_summary(self, content: str) -> None: - now = datetime.now(timezone.utc) - pool = await self._get_pool() - async with pool.acquire() as conn: - await conn.execute( - """INSERT INTO memory_summary (id, content, generated_at) VALUES (1, $1, $2) - ON CONFLICT(id) DO UPDATE SET - content = EXCLUDED.content, - generated_at = EXCLUDED.generated_at""", - content, now, - ) - - # ── Session extraction tracking ─────────────────────────────────────────── - - async def mark_session_extracted(self, session_id: str) -> None: - now = datetime.now(timezone.utc) - pool = await self._get_pool() - async with pool.acquire() as conn: - await conn.execute( - """INSERT INTO session_memory_state (session_id, extracted_at) VALUES ($1, $2) - ON CONFLICT(session_id) DO UPDATE SET extracted_at=EXCLUDED.extracted_at""", - session_id, now, - ) - - async def get_extracted_at(self, session_id: str) -> str | None: - pool = await self._get_pool() - async with pool.acquire() as conn: - row = await conn.fetchrow( - "SELECT extracted_at FROM session_memory_state WHERE session_id=$1", session_id - ) - if row is None: - return None - val = row["extracted_at"] - return val.isoformat() if hasattr(val, "isoformat") else str(val) - - -def _row_to_dict(row: asyncpg.Record) -> dict: - val = row["updated_at"] - updated_at = val.isoformat() if hasattr(val, "isoformat") else str(val) - expires_at = row.get("expires_at") - expires_str = expires_at.isoformat() if hasattr(expires_at, "isoformat") else None - return { - "id": row["id"], - "category": row["category"], - "key": row["key"], - "value": row["value"], - "updated_at": updated_at, - "source": row.get("source", "conversation"), - "confidence": row.get("confidence", 70), - "expires_at": expires_str, - "source_context": row.get("source_context", ""), - }