"""Unit tests for RecallScheduler (mocked asyncpg)."""
from datetime import datetime, timedelta, timezone
import pytest
from navi.core.scheduler import Recall, RecallExistsError, RecallScheduler
from tests.conftest_factory import FakeConnection, FakeRecord, FakePool, make_scheduler_with_pool
def _recall_row(**overrides) -> FakeRecord:
    """Build a ``FakeRecord`` holding the recall columns used by these tests.

    Every column has a default value; keyword arguments replace the matching
    default, and keys that have no default are passed through unchanged.
    """
    columns = {
        "id": "r1",
        "session_id": "s1",
        "call_type": "once",
        "trigger_at": datetime.now(timezone.utc),
        "interval_seconds": None,
        "internal_comment": "test",
        "additional_context_message": "ctx",
        "status": "pending",
        "created_at": datetime.now(timezone.utc),
        "updated_at": datetime.now(timezone.utc),
        # Caller-supplied values win over the defaults above.
        **overrides,
    }
    return FakeRecord(**columns)
class TestScheduleRecall:
    """Tests for ``schedule_recall``."""

    async def test_schedule_recall(self):
        """Scheduling runs an INSERT through ``fetchrow`` and returns a ``Recall``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(_recall_row(id="r1", session_id="s1"))
        sched = make_scheduler_with_pool(fake_conn)

        one_hour_ahead = datetime.now(timezone.utc) + timedelta(hours=1)
        created = await sched.schedule_recall(
            session_id="s1",
            call_type="once",
            trigger_at=one_hour_ahead,
            additional_context_message="hello",
        )

        # The first recorded call carries the method name and the SQL text.
        first_call = fake_conn.calls[0]
        assert first_call[0] == "fetchrow"
        assert "INSERT INTO session_recalls" in first_call[1]

        assert isinstance(created, Recall)
        assert created.id == "r1"
        assert created.session_id == "s1"

    async def test_schedule_duplicate_raises(self):
        """A unique-constraint failure surfaces as ``RecallExistsError``."""
        fake_conn = FakeConnection()
        # Simulate a unique violation by queueing an exception instead of a row
        # (presumably the fake raises it when the query runs).
        duplicate_error = Exception(
            "duplicate key value violates unique constraint \"idx_recalls_active_pending\""
        )
        fake_conn.enqueue(duplicate_error)
        sched = make_scheduler_with_pool(fake_conn)

        request = {
            "session_id": "s1",
            "call_type": "once",
            "trigger_at": datetime.now(timezone.utc) + timedelta(hours=1),
            "additional_context_message": "hello",
        }
        with pytest.raises(RecallExistsError):
            await sched.schedule_recall(**request)
class TestCancelRecall:
    """Tests for ``cancel_recall``."""

    async def test_cancel_recall(self):
        """An ``UPDATE 1`` status yields ``True`` and the SQL marks the row cancelled."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 1")
        sched = make_scheduler_with_pool(fake_conn)

        cancelled = await sched.cancel_recall("s1")

        assert cancelled is True
        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        statement = first_call[1]
        assert "UPDATE session_recalls" in statement
        assert "cancelled" in statement

    async def test_cancel_no_pending(self):
        """An ``UPDATE 0`` status (nothing matched) yields ``False``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 0")
        sched = make_scheduler_with_pool(fake_conn)

        cancelled = await sched.cancel_recall("s1")

        assert cancelled is False
class TestSkipNextRecall:
    """Tests for ``skip_next_recall``."""

    async def test_skip_next_recall(self):
        """An ``UPDATE 1`` status yields ``True``; the SQL references ``interval_seconds``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 1")
        sched = make_scheduler_with_pool(fake_conn)

        skipped = await sched.skip_next_recall("s1")

        assert skipped is True
        first_call = fake_conn.calls[0]
        assert first_call[0] == "execute"
        assert "interval_seconds" in first_call[1]

    async def test_skip_no_recurring(self):
        """An ``UPDATE 0`` status (nothing matched) yields ``False``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue("UPDATE 0")
        sched = make_scheduler_with_pool(fake_conn)

        skipped = await sched.skip_next_recall("s1")

        assert skipped is False
class TestListRecalls:
    """Tests for ``list_recalls``."""

    async def test_list_recalls_by_session(self):
        """Listing by session returns the queued row and goes through ``fetch``."""
        fake_conn = FakeConnection()
        queued_rows = [_recall_row(id="r1", session_id="s1")]
        fake_conn.enqueue(queued_rows)
        sched = make_scheduler_with_pool(fake_conn)

        found = await sched.list_recalls(session_id="s1")

        assert len(found) == 1
        assert found[0].id == "r1"
        assert fake_conn.calls[0][0] == "fetch"

    async def test_list_recalls_admin_bypass(self):
        """With ``is_admin=True`` the query must not pair ``user_id`` with ``$2``."""
        fake_conn = FakeConnection()
        queued_rows = [_recall_row(id="r1", session_id="s1")]
        fake_conn.enqueue(queued_rows)
        sched = make_scheduler_with_pool(fake_conn)

        found = await sched.list_recalls(session_id="s1", is_admin=True)

        assert len(found) == 1
        # Admin should not add user_id filter: the SQL may contain "user_id" or
        # "$2", but not both at once.
        sql = fake_conn.calls[0][1]
        assert not ("user_id" in sql and "$2" in sql)
class TestGetPendingRecalls:
    """Tests for ``get_pending_recalls``."""

    async def test_get_pending_recalls_ordered(self):
        """Due recalls are fetched with an ascending ``trigger_at`` ordering clause."""
        fake_conn = FakeConnection()
        cutoff = datetime.now(timezone.utc)
        due_rows = [
            _recall_row(id="r1", trigger_at=cutoff - timedelta(hours=1)),
            _recall_row(id="r2", trigger_at=cutoff - timedelta(minutes=30)),
        ]
        fake_conn.enqueue(due_rows)
        sched = make_scheduler_with_pool(fake_conn)

        pending = await sched.get_pending_recalls(before=cutoff)

        assert len(pending) == 2
        # Ordering is checked on the SQL text, not on the returned list.
        first_call = fake_conn.calls[0]
        assert first_call[0] == "fetch"
        assert "ORDER BY trigger_at ASC" in first_call[1]
class TestGetNextTriggerAt:
    """Tests for ``get_next_trigger_at``."""

    async def test_get_next_trigger_at(self):
        """A queued row with a future ``trigger_at`` produces a non-``None`` result."""
        fake_conn = FakeConnection()
        upcoming = datetime.now(timezone.utc) + timedelta(hours=1)
        fake_conn.enqueue(_recall_row(trigger_at=upcoming))
        sched = make_scheduler_with_pool(fake_conn)

        next_trigger = await sched.get_next_trigger_at()

        assert next_trigger is not None

    async def test_get_next_trigger_at_empty(self):
        """A queued ``None`` result produces ``None``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue(None)
        sched = make_scheduler_with_pool(fake_conn)

        next_trigger = await sched.get_next_trigger_at()

        assert next_trigger is None
class TestGetPendingSessionIds:
    """Tests for ``get_pending_session_ids``."""

    async def test_get_pending_session_ids(self):
        """Returned rows become a set of session ids; the query uses ``ANY($1)``."""
        fake_conn = FakeConnection()
        fake_conn.enqueue([FakeRecord(session_id=sid) for sid in ("s1", "s2")])
        sched = make_scheduler_with_pool(fake_conn)

        pending_ids = await sched.get_pending_session_ids(["s1", "s2", "s3"])

        assert pending_ids == {"s1", "s2"}
        first_call = fake_conn.calls[0]
        assert first_call[0] == "fetch"
        assert "ANY($1)" in first_call[1]

    async def test_get_pending_session_ids_empty(self):
        """An empty id list returns an empty set and records no connection calls."""
        fake_conn = FakeConnection()
        sched = make_scheduler_with_pool(fake_conn)

        pending_ids = await sched.get_pending_session_ids([])

        assert pending_ids == set()
        assert len(fake_conn.calls) == 0