"""Unit tests for memory extractor."""

import pytest

from navi.llm.base import Message, ToolCallRequest
from navi.memory.extractor import extract_and_update, _extract_facts
from tests.conftest_factory import FakeConnection, FakeLLMBackend, FakeRecord, make_store_with_pool


class FakeSession:
    """Minimal stand-in for a chat session passed to the extractor.

    Only exposes the three attributes the extractor tests rely on:
    ``id``, ``messages`` and ``user_id``.
    """

    def __init__(self, messages, session_id="sess-1", user_id=None):
        # Owner of the session; ``None`` models a legacy session without a user.
        self.user_id = user_id
        # Ordered conversation messages handed to the extractor.
        self.messages = messages
        # Session identifier.
        self.id = session_id


class TestExtractFacts:
    """Tests for ``_extract_facts``: parsing the LLM reply and persisting facts."""

    async def test_extracts_single_fact(self):
        """A single well-formed fact is written with one ``execute`` on ``memory_facts``."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene", "source": "conversation"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([
            Message(role="user", content="My name is Eugene"),
        ])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 1
        assert conn.calls[0][0] == "execute"
        assert "memory_facts" in conn.calls[0][1]

    async def test_invalid_json_returns_zero(self):
        """A non-JSON LLM reply yields zero facts and touches no database."""
        backend = FakeLLMBackend(responses=["not json"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 0
        assert len(conn.calls) == 0

    async def test_empty_array_returns_zero(self):
        """An empty JSON array from the LLM yields zero facts."""
        backend = FakeLLMBackend(responses=["[]"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 0

    async def test_skips_invalid_fact_dict(self):
        """Non-dict entries in the array are skipped; valid entries still count."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene"}, "not a dict"]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 1

    async def test_ignores_unknown_category(self):
        """A fact with a recognised category is stored."""
        # NOTE(review): despite its name, this test only feeds the known
        # "profile" category and never exercises unknown-category handling.
        # TODO: add a fact with an unknown category once the expected
        # behaviour (skip vs. remap) is confirmed against the extractor.
        backend = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 1

    async def test_maps_tool_call_source(self):
        """A ``tool_call``-sourced fact is stored with confidence 95."""
        backend = FakeLLMBackend(responses=[
            '[{"category": "technical", "key": "ip", "value": "10.0.0.1", "source": "tool_call"}]'
        ])
        conn = FakeConnection()
        conn.enqueue("INSERT 0 1")
        store = make_store_with_pool(conn)
        session = FakeSession([Message(role="user", content="hi")])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 1
        # Confidence for tool_call is 95
        call_args = conn.calls[0][2]
        assert 95 in call_args  # confidence parameter

    async def test_truncates_long_tool_results(self):
        """An oversized tool result does not break extraction."""
        backend = FakeLLMBackend(responses=["[]"])
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        long_result = "x" * 1000
        session = FakeSession([
            Message(
                role="assistant",
                tool_calls=[ToolCallRequest(id="1", name="terminal", arguments={})],
            ),
            Message(role="tool", content=long_result, name="terminal", tool_call_id="1"),
        ])
        count = await _extract_facts(session, backend, "test-model", store)
        # The "[]" reply yields no facts, so this mirrors
        # test_empty_array_returns_zero while proving the long transcript
        # was processed without error.
        assert count == 0
        # TODO: assert that the prompt sent to the backend contains a
        # truncated tool result. This needs FakeLLMBackend to record the
        # messages it receives; reading its private response queue cannot
        # show what was sent.

    async def test_no_messages_returns_zero(self):
        """A session with no messages yields zero facts and no DB calls."""
        backend = FakeLLMBackend()
        conn = FakeConnection()
        store = make_store_with_pool(conn)
        session = FakeSession([])
        count = await _extract_facts(session, backend, "test-model", store)
        assert count == 0
        assert len(conn.calls) == 0


class TestExtractAndUpdate:
    """Tests for ``extract_and_update``: extraction, bookkeeping and summary refresh."""

    @staticmethod
    def _queried(fake_conn, fragment):
        """Return True if any recorded DB call's SQL text contains *fragment*."""
        return any(fragment in recorded[1] for recorded in fake_conn.calls)

    async def test_calls_mark_extracted(self):
        """The session is marked as extracted even when no facts are found."""
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")  # mark_session_extracted
        memory_store = make_store_with_pool(fake_conn)
        user_session = FakeSession(
            [Message(role="user", content="hi")],
            user_id="user-1",
        )
        await extract_and_update(user_session, llm, "test-model", memory_store)
        assert self._queried(fake_conn, "session_memory_state")

    async def test_regenerates_summary_when_facts_added(self):
        """New facts trigger a summary rewrite that touches ``memory_summary``."""
        llm = FakeLLMBackend(responses=[
            '[{"category": "profile", "key": "name", "value": "Eugene"}]',  # extract
            "Summary: Eugene is the user.",  # regenerate summary
        ])
        fake_conn = FakeConnection()
        # Queued DB results are consumed in call order, so this sequence matters.
        fake_conn.enqueue("INSERT 0 1")  # upsert_fact
        fake_conn.enqueue("OK")  # mark_session_extracted
        stored_fact = FakeRecord(
            id="1",
            category="profile",
            key="name",
            value="Eugene",
            updated_at=None,
            source="conversation",
            confidence=90,
            expires_at=None,
            source_context="",
        )
        fake_conn.enqueue([stored_fact])  # get_all_facts in _regenerate_summary
        fake_conn.enqueue("OK")  # set_summary
        memory_store = make_store_with_pool(fake_conn)
        user_session = FakeSession(
            [Message(role="user", content="My name is Eugene")],
            user_id="user-1",
        )
        await extract_and_update(user_session, llm, "test-model", memory_store)
        assert self._queried(fake_conn, "memory_summary")

    async def test_no_summary_regeneration_when_no_facts(self):
        """With no new facts, the summary is left untouched."""
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        fake_conn.enqueue("OK")  # mark_session_extracted
        memory_store = make_store_with_pool(fake_conn)
        user_session = FakeSession(
            [Message(role="user", content="hi")],
            user_id="user-1",
        )
        await extract_and_update(user_session, llm, "test-model", memory_store)
        # Should NOT call get_all_facts or set_summary
        assert not self._queried(fake_conn, "memory_summary")

    async def test_skips_legacy_sessions(self):
        """Sessions without a ``user_id`` (legacy) never reach the database."""
        llm = FakeLLMBackend(responses=["[]"])
        fake_conn = FakeConnection()
        memory_store = make_store_with_pool(fake_conn)
        legacy_session = FakeSession(
            [Message(role="user", content="hi")],
            user_id=None,
        )
        await extract_and_update(legacy_session, llm, "test-model", memory_store)
        # No DB calls for legacy (user_id=None) sessions
        assert len(fake_conn.calls) == 0
