Newer
Older
navi-1 / tests / unit / memory / test_store.py
"""Unit tests for MemoryStore (mocked asyncpg)."""

import pytest

from tests.conftest_factory import FakeConnection, FakeRecord, FakePool, make_store_with_pool


class TestUpsertFact:
    """Checks on the SQL that ``upsert_fact`` sends through the fake connection."""

    async def test_calls_execute(self):
        """The first recorded call is an ``execute`` whose SQL names ``memory_facts``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("INSERT 0 1")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.upsert_fact(category="profile", key="name", value="Eugene")

        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        assert "memory_facts" in first_call[1]

    async def test_without_embedding(self):
        """With ``_embedding_backend`` cleared, inspect the SQL of the first call."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("INSERT 0 1")  # no embedding column branch
        memory_store = make_store_with_pool(fake_conn)
        memory_store._embedding_backend = None

        await memory_store.upsert_fact(category="profile", key="name", value="Eugene")

        sql = fake_conn.calls[0][1]
        # Expected to take the branch without an embedding: the check passes
        # when the SQL never mentions "embedding", or when it contains "$8".
        mentions_embedding = "embedding" in sql
        has_eighth_placeholder = "$8" in sql
        assert not mentions_embedding or has_eighth_placeholder  # no vector parameter


class TestSearchFacts:
    """Exercises ``search_facts`` with the pgvector flags switched on and off."""

    @staticmethod
    def _name_fact(**extra):
        """Build the ``profile``/``name`` row these tests queue as a query result.

        Any ``extra`` keyword arguments are appended after the standard columns.
        """
        return FakeRecord(
            id="1",
            category="profile",
            key="name",
            value="Eugene",
            updated_at=None,
            source="conversation",
            confidence=90,
            expires_at=None,
            source_context="",
            **extra,
        )

    async def test_vector_search_happy_path(self):
        """With pgvector flagged available, the queued row comes back as one hit."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([self._name_fact(distance=0.1)])
        memory_store = make_store_with_pool(fake_conn)
        memory_store._pgvector_checked = True
        memory_store._pgvector_available = True
        memory_store._embedding_backend = object()  # any truthy object

        # Replace _generate_embedding with a stub so no real backend is contacted.
        async def _stub_embedding(text: str):
            return [0.1] * 768

        memory_store._generate_embedding = _stub_embedding

        hits = await memory_store.search_facts("name", limit=5)
        assert len(hits) == 1
        assert hits[0]["key"] == "name"

    async def test_fallback_to_ilike_no_pgvector(self):
        """With pgvector flagged unavailable, the search still yields the queued row."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(0)  # COUNT(*)
        fake_conn.enqueue([self._name_fact()])
        memory_store = make_store_with_pool(fake_conn)
        memory_store._pgvector_checked = True
        memory_store._pgvector_available = False

        hits = await memory_store.search_facts("eugene", limit=5)
        assert len(hits) == 1

    async def test_fallback_auto_dump_below_threshold(self):
        """A small queued fact count leads to a second query ordered by category."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(5)  # fact_count <= threshold
        fake_conn.enqueue([self._name_fact()])
        memory_store = make_store_with_pool(fake_conn)
        memory_store._pgvector_checked = True
        memory_store._pgvector_available = False

        hits = await memory_store.search_facts("anything", limit=5)
        assert len(hits) == 1

        # The second call is expected to be the get_all_facts query, not ILIKE.
        second_sql = fake_conn.calls[1][1]
        assert "ORDER BY category" in second_sql

    async def test_fallback_no_terms(self):
        """A one-character query with an empty queued result returns nothing."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([])  # get_all_facts returns empty
        memory_store = make_store_with_pool(fake_conn)
        memory_store._pgvector_checked = True
        memory_store._pgvector_available = False

        # single-char query normalizes to empty -> get_all_facts
        hits = await memory_store.search_facts("a", limit=5)
        assert len(hits) == 0


class TestDeleteFact:
    """Tests for ``delete_fact`` driven by scripted DELETE status strings."""

    async def test_by_key(self):
        """A ``DELETE 1`` status yields a count of 1 and targets ``memory_facts``."""
        conn = FakeConnection()
        conn.enqueue("DELETE 1")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("name")
        assert count == 1
        assert "DELETE FROM memory_facts" in conn.calls[0][1]

    async def test_by_key_returns_zero(self):
        """A ``DELETE 0`` status yields a count of 0."""
        conn = FakeConnection()
        conn.enqueue("DELETE 0")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("missing")
        assert count == 0

    async def test_by_key_and_category(self):
        """Passing ``category`` adds it to the SQL and still reports the row count."""
        conn = FakeConnection()
        # PostgreSQL's DELETE command tag is "DELETE <rows>"; the three-token
        # "<command> 0 <rows>" shape is specific to INSERT.
        conn.enqueue("DELETE 1")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("name", category="profile")
        # Check the returned count as the sibling tests do, rather than discarding it.
        assert count == 1
        assert "category" in conn.calls[0][1]


class TestGetAllFacts:
    """Tests for ``get_all_facts``: returned rows and the shape of the SQL."""

    @staticmethod
    def _name_fact():
        """Build the single ``profile``/``name`` row queued as a query result."""
        return FakeRecord(
            id="1",
            category="profile",
            key="name",
            value="Eugene",
            updated_at=None,
            source="conversation",
            confidence=90,
            expires_at=None,
            source_context="",
        )

    async def test_returns_records(self):
        """The queued row is returned and exposes its ``key`` by subscript."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([self._name_fact()])
        memory_store = make_store_with_pool(fake_conn)

        facts = await memory_store.get_all_facts()

        assert len(facts) == 1
        assert facts[0]["key"] == "name"

    async def test_with_limit(self):
        """Supplying ``limit`` puts a ``LIMIT $1`` clause into the query text."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([])
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.get_all_facts(limit=5)

        query_text = fake_conn.calls[0][1]
        assert "LIMIT $1" in query_text

    async def test_all_users(self):
        """With ``all_users=True`` the query text carries no ``user_id`` at all."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([self._name_fact()])
        memory_store = make_store_with_pool(fake_conn)

        facts = await memory_store.get_all_facts(all_users=True)

        assert len(facts) == 1
        # No per-user filtering should appear anywhere in the SQL.
        query_text = fake_conn.calls[0][1]
        assert "user_id" not in query_text


class TestFactCount:
    """Tests for ``fact_count`` using a scripted scalar response."""

    async def test_returns_count(self):
        """The queued scalar is returned unchanged."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(42)
        memory_store = make_store_with_pool(fake_conn)

        total = await memory_store.fact_count()

        assert total == 42

    async def test_all_users(self):
        """With ``all_users=True`` the count query contains no ``WHERE``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(100)
        memory_store = make_store_with_pool(fake_conn)

        total = await memory_store.fact_count(all_users=True)

        assert total == 100
        query_text = fake_conn.calls[0][1]
        assert "WHERE" not in query_text


class TestSummary:
    """Tests for reading and writing the memory summary."""

    async def test_get_summary(self):
        """``get_summary`` hands back the queued summary text."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("User likes Python.")
        memory_store = make_store_with_pool(fake_conn)

        summary = await memory_store.get_summary()

        assert summary == "User likes Python."

    async def test_set_summary(self):
        """``set_summary`` issues SQL that names ``memory_summary``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.set_summary("New summary")

        query_text = fake_conn.calls[0][1]
        assert "memory_summary" in query_text


class TestSessionState:
    """Tests for the per-session extraction marker."""

    async def test_mark_extracted(self):
        """``mark_session_extracted`` issues SQL naming ``session_memory_state``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.mark_session_extracted("sess-1")

        query_text = fake_conn.calls[0][1]
        assert "session_memory_state" in query_text

    async def test_get_extracted_at(self):
        """A queued row's ``extracted_at`` comes back as its ISO-8601 string."""
        from datetime import datetime, timezone

        stamp = datetime.now(timezone.utc)
        fake_conn = FakeConnection()
        fake_conn.enqueue(FakeRecord(extracted_at=stamp))
        memory_store = make_store_with_pool(fake_conn)

        extracted_at = await memory_store.get_extracted_at("sess-1")

        assert extracted_at == stamp.isoformat()

    async def test_get_extracted_at_none(self):
        """A queued ``None`` response results in ``None``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(None)
        memory_store = make_store_with_pool(fake_conn)

        extracted_at = await memory_store.get_extracted_at("sess-1")

        assert extracted_at is None


class TestBackfillEmbeddings:
    """Tests for ``backfill_embeddings`` with a stubbed batch embedder."""

    async def test_updates_rows(self):
        """Two queued rows are reported as updated via an ``executemany`` UPDATE."""
        fake_conn = FakeConnection()
        # Responses are queued in this exact order.
        # NOTE(review): which call consumes which entry depends on
        # FakeConnection's queueing, which is not visible here — confirm.
        scripted_responses = (
            # First batch: 2 rows
            [FakeRecord(id="1", value="hello"), FakeRecord(id="2", value="world")],
            # No more rows
            [],
            # executemany response
            None,
        )
        for response in scripted_responses:
            fake_conn.enqueue(response)

        memory_store = make_store_with_pool(fake_conn)
        memory_store._pgvector_checked = True
        memory_store._pgvector_available = True
        memory_store._embedding_backend = object()

        # Stub the batch embedder: one 768-element vector per input text.
        async def _stub_embeddings(texts: list[str]):
            return [[0.1] * 768 for _ in texts]

        memory_store._generate_embeddings = _stub_embeddings

        updated_count = await memory_store.backfill_embeddings(batch_size=2)
        assert updated_count == 2

        # The bulk UPDATE is expected to be the second-to-last recorded call.
        penultimate_call = fake_conn.calls[-2]
        assert penultimate_call[0] == "executemany"
        assert "UPDATE memory_facts SET embedding" in penultimate_call[1]