"""Unit tests for RecallScheduler (mocked asyncpg)."""

from datetime import datetime, timedelta, timezone

import pytest

from navi.core.scheduler import Recall, RecallExistsError, RecallScheduler
from tests.conftest_factory import FakeConnection, FakeRecord, FakePool, make_scheduler_with_pool


def _recall_row(**overrides: object) -> FakeRecord:
    """Build a fake ``session_recalls`` row with sensible defaults.

    Each keyword argument replaces the default value of the column with the
    same name (or adds a new column). All default timestamps share a single
    ``now`` so the fake row is internally consistent instead of drifting by a
    few microseconds between fields.
    """
    now = datetime.now(timezone.utc)
    defaults = {
        "id": "r1",
        "session_id": "s1",
        "call_type": "once",
        "trigger_at": now,
        "interval_seconds": None,
        "internal_comment": "test",
        "additional_context_message": "ctx",
        "status": "pending",
        "created_at": now,
        "updated_at": now,
    }
    defaults.update(overrides)
    return FakeRecord(**defaults)


class TestScheduleRecall:
    """schedule_recall: the INSERT path and the duplicate-detection path."""

    async def test_schedule_recall(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue(_recall_row(id="r1", session_id="s1"))
        sched = make_scheduler_with_pool(fake_conn)

        one_hour_later = datetime.now(timezone.utc) + timedelta(hours=1)
        result = await sched.schedule_recall(
            session_id="s1",
            call_type="once",
            trigger_at=one_hour_later,
            additional_context_message="hello",
        )

        first_call = fake_conn.calls[0]
        assert first_call[0] == "fetchrow"
        assert "INSERT INTO session_recalls" in first_call[1]
        # The returned row must be mapped into a Recall model.
        assert isinstance(result, Recall)
        assert result.id == "r1"
        assert result.session_id == "s1"

    async def test_schedule_duplicate_raises(self):
        fake_conn = FakeConnection()
        # Queue an error whose text mimics Postgres' unique-violation message
        # for idx_recalls_active_pending; the fake is expected to raise it on
        # the next query, and the scheduler must translate it.
        unique_violation = Exception(
            'duplicate key value violates unique constraint "idx_recalls_active_pending"'
        )
        fake_conn.enqueue(unique_violation)
        sched = make_scheduler_with_pool(fake_conn)

        one_hour_later = datetime.now(timezone.utc) + timedelta(hours=1)
        with pytest.raises(RecallExistsError):
            await sched.schedule_recall(
                session_id="s1",
                call_type="once",
                trigger_at=one_hour_later,
                additional_context_message="hello",
            )


class TestCancelRecall:
    """cancel_recall: turns the UPDATE status tag into a boolean."""

    async def test_cancel_recall(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 1")
        sched = make_scheduler_with_pool(fake_conn)

        cancelled = await sched.cancel_recall("s1")

        assert cancelled is True
        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        assert "UPDATE session_recalls" in first_call[1]
        # The statement must move the row into the cancelled state.
        assert "cancelled" in first_call[1]

    async def test_cancel_no_pending(self):
        fake_conn = FakeConnection()
        # Zero affected rows: nothing was pending for this session.
        fake_conn.enqueue("UPDATE 0")
        sched = make_scheduler_with_pool(fake_conn)

        cancelled = await sched.cancel_recall("s1")

        assert cancelled is False


class TestSkipNextRecall:
    """skip_next_recall: advances a recurring recall; False when none matched."""

    async def test_skip_next_recall(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 1")
        sched = make_scheduler_with_pool(fake_conn)

        skipped = await sched.skip_next_recall("s1")

        assert skipped is True
        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        # Skipping is computed from the recurrence interval column.
        assert "interval_seconds" in first_call[1]

    async def test_skip_no_recurring(self):
        fake_conn = FakeConnection()
        # Zero affected rows: no recurring recall exists for this session.
        fake_conn.enqueue("UPDATE 0")
        sched = make_scheduler_with_pool(fake_conn)

        skipped = await sched.skip_next_recall("s1")

        assert skipped is False


class TestListRecalls:
    """list_recalls: session-scoped listing and the admin variant."""

    async def test_list_recalls_by_session(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([_recall_row(id="r1", session_id="s1")])
        sched = make_scheduler_with_pool(fake_conn)

        found = await sched.list_recalls(session_id="s1")

        assert len(found) == 1
        assert found[0].id == "r1"
        assert fake_conn.calls[0][0] == "fetch"

    async def test_list_recalls_admin_bypass(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([_recall_row(id="r1", session_id="s1")])
        sched = make_scheduler_with_pool(fake_conn)

        found = await sched.list_recalls(session_id="s1", is_admin=True)

        assert len(found) == 1
        # Admin queries must not add a user_id filter. The check only fails
        # when the SQL both mentions user_id AND binds a second parameter.
        # NOTE(review): loose substring heuristic; asserting on the bound
        # arguments would be a stricter check.
        sql = fake_conn.calls[0][1]
        assert not ("user_id" in sql and "$2" in sql)


class TestGetPendingRecalls:
    """get_pending_recalls: due recalls, returned earliest trigger first."""

    async def test_get_pending_recalls_ordered(self):
        conn = FakeConnection()
        now = datetime.now(timezone.utc)
        conn.enqueue([
            _recall_row(id="r1", trigger_at=now - timedelta(hours=1)),
            _recall_row(id="r2", trigger_at=now - timedelta(minutes=30)),
        ])
        scheduler = make_scheduler_with_pool(conn)

        recalls = await scheduler.get_pending_recalls(before=now)

        assert len(recalls) == 2
        assert conn.calls[0][0] == "fetch"
        # Ordering is delegated to SQL...
        assert "ORDER BY trigger_at ASC" in conn.calls[0][1]
        # ...and the scheduler must not reshuffle the rows it gets back. The
        # fake presumably returns rows in the order they were queued (r1 is
        # earlier than r2), so the result should keep that order.
        assert [r.id for r in recalls] == ["r1", "r2"]


class TestGetNextTriggerAt:
    """get_next_trigger_at: the next queued trigger time, or None when empty."""

    async def test_get_next_trigger_at(self):
        conn = FakeConnection()
        expected = datetime.now(timezone.utc) + timedelta(hours=1)
        conn.enqueue(_recall_row(trigger_at=expected))
        scheduler = make_scheduler_with_pool(conn)

        dt = await scheduler.get_next_trigger_at()

        # Pin the actual value, not just non-None-ness: reading the wrong
        # column or leaking the raw row would otherwise still pass.
        # NOTE(review): assumes the scheduler returns row["trigger_at"] from a
        # fetchrow result; adjust if it uses fetchval instead.
        assert dt == expected

    async def test_get_next_trigger_at_empty(self):
        conn = FakeConnection()
        # No row at all means nothing is scheduled.
        conn.enqueue(None)
        scheduler = make_scheduler_with_pool(conn)

        dt = await scheduler.get_next_trigger_at()

        assert dt is None


class TestGetPendingSessionIds:
    """get_pending_session_ids: batch lookup over a list of session ids."""

    async def test_get_pending_session_ids(self):
        fake_conn = FakeConnection()
        fake_conn.enqueue([FakeRecord(session_id=sid) for sid in ("s1", "s2")])
        sched = make_scheduler_with_pool(fake_conn)

        pending = await sched.get_pending_session_ids(["s1", "s2", "s3"])

        # Only the ids the query returned are reported, as a set.
        assert pending == {"s1", "s2"}
        first_call = fake_conn.calls[0]
        assert first_call[0] == "fetch"
        # The whole id list is bound as one array parameter.
        assert "ANY($1)" in first_call[1]

    async def test_get_pending_session_ids_empty(self):
        fake_conn = FakeConnection()
        sched = make_scheduler_with_pool(fake_conn)

        pending = await sched.get_pending_session_ids([])

        # An empty input short-circuits without issuing any query.
        assert pending == set()
        assert not fake_conn.calls
