"""Unit tests for memory extractor."""
import pytest
from navi.llm.base import Message, ToolCallRequest
from navi.memory.extractor import extract_and_update, _extract_facts
from tests.conftest_factory import FakeConnection, FakeLLMBackend, FakeRecord, make_store_with_pool
class FakeSession:
    """Bare-bones session stand-in exposing ``id``, ``messages`` and ``user_id``."""

    def __init__(self, messages, session_id="sess-1", user_id=None):
        # Store the three constructor values as plain attributes; note that
        # ``session_id`` is exposed under the shorter attribute name ``id``.
        self.id, self.messages, self.user_id = session_id, messages, user_id
class TestExtractFacts:
    """Tests for ``_extract_facts`` using a scripted LLM backend and a fake connection.

    Each test scripts the backend's reply, enqueues the results the fake
    connection should hand back, runs ``_extract_facts`` and then checks the
    returned count and/or the calls recorded on the connection.
    """

    async def test_extracts_single_fact(self):
        """One well-formed fact yields a count of 1 and an ``execute`` mentioning ``memory_facts``."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene", "source": "conversation"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([
            Message(role="user", content="My name is Eugene"),
        ])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 1
        assert conn.calls[0][0] == "execute"
        assert "memory_facts" in conn.calls[0][1]

    async def test_invalid_json_returns_zero(self):
        """A reply that is not JSON yields a count of 0 and no recorded connection calls."""
        backend = FakeLLMBackend(responses=["not json"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 0
        assert not conn.calls

    async def test_empty_array_returns_zero(self):
        """An empty JSON array reply yields a count of 0."""
        backend = FakeLLMBackend(responses=["[]"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 0

    async def test_skips_invalid_fact_dict(self):
        """A non-dict array entry is not counted; the valid dict next to it still is."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene"}, "not a dict"]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 1

    async def test_ignores_unknown_category(self):
        """A single scripted fact yields a count of 1.

        NOTE(review): the scripted fact uses category "profile", the same
        category the other tests in this class use, so no unknown category is
        actually exercised here -- confirm the intended payload.
        """
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 1

    async def test_maps_tool_call_source(self):
        """A fact with source ``tool_call`` puts the value 95 among the first call's arguments."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "technical", "key": "ip", "value": "10.0.0.1", "source": "tool_call"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 1
        # Confidence for tool_call is expected to be 95; the third element of
        # the recorded call holds the arguments passed alongside the SQL.
        call_args = conn.calls[0][2]
        assert 95 in call_args  # confidence parameter

    async def test_truncates_long_tool_results(self):
        """A 1000-character tool result is handled and, with a ``[]`` reply, counts 0 facts.

        NOTE(review): despite its name this test does not verify truncation.
        Doing so needs the prompt the backend received, and no such accessor
        on the fake backend is visible from this file -- TODO confirm the
        fake's API and assert on the transcript length.
        """
        backend = FakeLLMBackend(responses=["[]"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        long_result = "x" * 1000
        session = FakeSession([
            Message(
                role="assistant",
                tool_calls=[ToolCallRequest(id="1", name="terminal", arguments={})],
            ),
            Message(role="tool", content=long_result, name="terminal", tool_call_id="1"),
        ])

        count = await _extract_facts(session, backend, "test-model", store)

        # The scripted reply is an empty array, so no facts should be counted.
        assert count == 0

    async def test_no_messages_returns_zero(self):
        """A session without messages yields a count of 0 and no recorded connection calls."""
        backend = FakeLLMBackend()
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([])

        count = await _extract_facts(session, backend, "test-model", store)

        assert count == 0
        assert not conn.calls
class TestExtractAndUpdate:
    """Tests for ``extract_and_update`` driven by scripted backend replies and a fake connection."""

    async def test_calls_mark_extracted(self):
        """With an empty extraction, some recorded call still mentions ``session_memory_state``."""
        chat = FakeSession([Message(role="user", content="hi")], user_id="user-1")
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")  # result for mark_session_extracted
        memory_store = make_store_with_pool(fake_conn)

        await extract_and_update(chat, llm, "test-model", memory_store)

        assert any("session_memory_state" in call[1] for call in fake_conn.calls)

    async def test_regenerates_summary_when_facts_added(self):
        """When a fact is extracted, some recorded call mentions ``memory_summary``."""
        chat = FakeSession(
            [Message(role="user", content="My name is Eugene")],
            user_id="user-1",
        )
        scripted_replies = [
            '[{"category": "profile", "key": "name", "value": "Eugene"}]',  # extraction reply
            "Summary: Eugene is the user.",  # summary regeneration reply
        ]
        llm = FakeLLMBackend(responses=scripted_replies)
        stored_fact = FakeRecord(
            id="1",
            category="profile",
            key="name",
            value="Eugene",
            updated_at=None,
            source="conversation",
            confidence=90,
            expires_at=None,
            source_context="",
        )
        fake_conn = FakeConnection()
        # Results are queued in the order the fake connection is expected to serve them.
        fake_conn.enqueue("INSERT 0 1")  # upsert_fact
        fake_conn.enqueue("OK")  # mark_session_extracted
        fake_conn.enqueue([stored_fact])  # get_all_facts in _regenerate_summary
        fake_conn.enqueue("OK")  # set_summary
        memory_store = make_store_with_pool(fake_conn)

        await extract_and_update(chat, llm, "test-model", memory_store)

        assert any("memory_summary" in call[1] for call in fake_conn.calls)

    async def test_no_summary_regeneration_when_no_facts(self):
        """With an empty extraction, no recorded call mentions ``memory_summary``."""
        chat = FakeSession([Message(role="user", content="hi")], user_id="user-1")
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")  # result for mark_session_extracted
        memory_store = make_store_with_pool(fake_conn)

        await extract_and_update(chat, llm, "test-model", memory_store)

        # Neither get_all_facts nor set_summary should have been issued.
        assert not any("memory_summary" in call[1] for call in fake_conn.calls)

    async def test_skips_legacy_sessions(self):
        """A session whose ``user_id`` is ``None`` results in zero recorded connection calls."""
        chat = FakeSession([Message(role="user", content="hi")], user_id=None)
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        memory_store = make_store_with_pool(fake_conn)

        await extract_and_update(chat, llm, "test-model", memory_store)

        # Legacy sessions (user_id=None) must not touch the database at all.
        assert len(fake_conn.calls) == 0