"""Unit tests for MemoryStore (mocked asyncpg)."""
import pytest
from tests.conftest_factory import FakeConnection, FakeRecord, FakePool, make_store_with_pool
class TestUpsertFact:
    """upsert_fact writes to memory_facts through a single execute() call."""

    async def test_calls_execute(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("INSERT 0 1")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.upsert_fact(category="profile", key="name", value="Eugene")

        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        assert "memory_facts" in first_call[1]

    async def test_without_embedding(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("INSERT 0 1")  # status for the INSERT that omits the vector column
        memory_store = make_store_with_pool(fake_conn)
        # With no embedding backend the store should take the non-vector branch.
        memory_store._embedding_backend = None

        await memory_store.upsert_fact(category="profile", key="name", value="Eugene")

        sql = fake_conn.calls[0][1]
        # NOTE(review): permissive check -- it also passes whenever the statement
        # has an 8th placeholder. Tighten once the exact non-vector SQL is pinned.
        assert "$8" in sql or "embedding" not in sql
class TestSearchFacts:
    """search_facts: pgvector similarity search plus its non-vector fallbacks."""

    @staticmethod
    def _profile_record(**extra):
        """Build the fake profile/name=Eugene memory_facts row, plus any extra columns."""
        return FakeRecord(id="1", category="profile", key="name", value="Eugene",
                          updated_at=None, source="conversation", confidence=90,
                          expires_at=None, source_context="", **extra)

    async def test_vector_search_happy_path(self):
        conn = FakeConnection()
        conn.enqueue([self._profile_record(distance=0.1)])
        store = make_store_with_pool(conn)
        store._pgvector_checked = True
        store._pgvector_available = True
        store._embedding_backend = object()  # any truthy object

        # Replace _generate_embedding so the test never touches a real backend.
        async def _fake_embed(text: str):
            return [0.1] * 768

        store._generate_embedding = _fake_embed
        results = await store.search_facts("name", limit=5)
        assert len(results) == 1
        assert results[0]["key"] == "name"

    async def test_fallback_to_ilike_no_pgvector(self):
        conn = FakeConnection()
        # The COUNT(*) must exceed the auto-dump threshold. Otherwise search_facts
        # short-circuits to get_all_facts and the ILIKE term search is never run.
        # test_fallback_auto_dump_below_threshold shows a count of 5 already takes
        # that path, so a count of 0 would too.
        conn.enqueue(10_000)  # COUNT(*)
        conn.enqueue([self._profile_record()])
        store = make_store_with_pool(conn)
        store._pgvector_checked = True
        store._pgvector_available = False
        results = await store.search_facts("eugene", limit=5)
        assert len(results) == 1
        assert results[0]["key"] == "name"
        # Some statement after the COUNT(*) must be the ILIKE term search.
        assert any("ILIKE" in call[1].upper() for call in conn.calls[1:])

    async def test_fallback_auto_dump_below_threshold(self):
        conn = FakeConnection()
        conn.enqueue(5)  # fact_count <= threshold
        conn.enqueue([self._profile_record()])
        store = make_store_with_pool(conn)
        store._pgvector_checked = True
        store._pgvector_available = False
        results = await store.search_facts("anything", limit=5)
        assert len(results) == 1
        # Should have done get_all_facts instead of ILIKE.
        assert "ORDER BY category" in conn.calls[1][1]

    async def test_fallback_no_terms(self):
        conn = FakeConnection()
        conn.enqueue([])  # get_all_facts returns empty
        store = make_store_with_pool(conn)
        store._pgvector_checked = True
        store._pgvector_available = False
        results = await store.search_facts("a", limit=5)
        # A single-char query normalizes to no terms -> get_all_facts.
        assert len(results) == 0
class TestDeleteFact:
    """delete_fact: DELETE against memory_facts, returning the affected-row count."""

    async def test_by_key(self):
        conn = FakeConnection()
        conn.enqueue("DELETE 1")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("name")
        assert count == 1
        assert "DELETE FROM memory_facts" in conn.calls[0][1]

    async def test_by_key_returns_zero(self):
        conn = FakeConnection()
        conn.enqueue("DELETE 0")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("missing")
        assert count == 0

    async def test_by_key_and_category(self):
        conn = FakeConnection()
        # Postgres reports deletions as "DELETE <rows>". The two-number form
        # ("<tag> <oid> <rows>") is used only by INSERT.
        conn.enqueue("DELETE 1")
        store = make_store_with_pool(conn)
        count = await store.delete_fact("name", category="profile")
        assert count == 1
        sql = conn.calls[0][1]
        assert "DELETE FROM memory_facts" in sql
        assert "category" in sql
class TestGetAllFacts:
    """get_all_facts: row mapping, optional LIMIT, and the all-users scope."""

    @staticmethod
    def _eugene_row():
        """Return one fake memory_facts row (profile/name = Eugene)."""
        return FakeRecord(id="1", category="profile", key="name", value="Eugene",
                          updated_at=None, source="conversation", confidence=90,
                          expires_at=None, source_context="")

    async def test_returns_records(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([self._eugene_row()])
        memory_store = make_store_with_pool(fake_conn)

        facts = await memory_store.get_all_facts()

        assert len(facts) == 1
        assert facts[0]["key"] == "name"

    async def test_with_limit(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([])
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.get_all_facts(limit=5)

        issued_sql = fake_conn.calls[0][1]
        assert "LIMIT $1" in issued_sql

    async def test_all_users(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([self._eugene_row()])
        memory_store = make_store_with_pool(fake_conn)

        facts = await memory_store.get_all_facts(all_users=True)

        assert len(facts) == 1
        # The all-users query must not be scoped by user_id in any way.
        issued_sql = fake_conn.calls[0][1]
        assert "user_id" not in issued_sql
class TestFactCount:
    """fact_count: a scalar count, scoped per user unless all_users=True."""

    async def test_returns_count(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue(42)
        memory_store = make_store_with_pool(fake_conn)

        total = await memory_store.fact_count()

        assert total == 42

    async def test_all_users(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue(100)
        memory_store = make_store_with_pool(fake_conn)

        total = await memory_store.fact_count(all_users=True)

        assert total == 100
        # An unscoped count carries no WHERE clause at all.
        assert "WHERE" not in fake_conn.calls[0][1]
class TestSummary:
    """get_summary / set_summary round-trip through the memory_summary table."""

    async def test_get_summary(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("User likes Python.")
        memory_store = make_store_with_pool(fake_conn)

        summary = await memory_store.get_summary()

        assert summary == "User likes Python."

    async def test_set_summary(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.set_summary("New summary")

        issued_sql = fake_conn.calls[0][1]
        assert "memory_summary" in issued_sql
class TestSessionState:
    """Per-session extraction bookkeeping stored in session_memory_state."""

    async def test_mark_extracted(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")
        memory_store = make_store_with_pool(fake_conn)

        await memory_store.mark_session_extracted("sess-1")

        assert "session_memory_state" in fake_conn.calls[0][1]

    async def test_get_extracted_at(self):
        from datetime import datetime, timezone

        stamp = datetime.now(timezone.utc)
        fake_conn = FakeConnection()
        fake_conn.enqueue(FakeRecord(extracted_at=stamp))
        memory_store = make_store_with_pool(fake_conn)

        extracted = await memory_store.get_extracted_at("sess-1")

        # The timestamp is expected back as an ISO-8601 string.
        assert extracted == stamp.isoformat()

    async def test_get_extracted_at_none(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue(None)
        memory_store = make_store_with_pool(fake_conn)

        extracted = await memory_store.get_extracted_at("sess-1")

        assert extracted is None
class TestBackfillEmbeddings:
    """backfill_embeddings: fill missing embeddings in batches via executemany."""

    async def test_updates_rows(self):
        conn = FakeConnection()
        # Queue responses in the order the store issues its calls:
        # fetch batch 1 -> executemany UPDATE -> fetch batch 2. The ordering
        # fetch/executemany/fetch is the same one the calls[-2] assertion
        # below relies on. This assumes FakeConnection replays queued
        # responses FIFO across all methods; the upsert tests queue an
        # execute() status the same way.
        conn.enqueue([
            FakeRecord(id="1", value="hello"),
            FakeRecord(id="2", value="world"),
        ])  # batch 1: two rows to backfill
        conn.enqueue(None)  # executemany UPDATE for batch 1
        conn.enqueue([])  # batch 2: nothing left, loop ends
        store = make_store_with_pool(conn)
        store._pgvector_checked = True
        store._pgvector_available = True
        store._embedding_backend = object()

        async def _fake_embeds(texts: list[str]):
            return [[0.1] * 768 for _ in texts]

        store._generate_embeddings = _fake_embeds
        updated = await store.backfill_embeddings(batch_size=2)
        assert updated == 2
        # The UPDATE immediately precedes the final (empty) batch fetch.
        update_call = conn.calls[-2]
        assert update_call[0] == "executemany"
        assert "UPDATE memory_facts SET embedding" in update_call[1]