# navi-1/tests/integration/test_scheduler_loop.py
"""Integration tests for the scheduler background loop."""

import asyncio
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock, MagicMock

import pytest

from navi.core.events import StreamEnd
from navi.core.orchestrator import AgentSessionOrchestrator
from navi.core.scheduler import Recall, _fire_recall, recall_scheduler_loop


@pytest.fixture
def fake_orchestrator():
    """Build a real ``AgentSessionOrchestrator`` on top of a mocked container.

    Every registry/manager attribute the container exposes is stubbed out:
    ``None`` for the registries, stores and MCP manager, and an empty list for
    ``workers``.
    """
    empty_container = MagicMock(
        profile_registry=None,
        tool_registry=None,
        backend_registry=None,
        cp_registry=None,
        workers=[],
        memory_store=None,
        mcp_manager=None,
    )
    return AgentSessionOrchestrator(empty_container)


@pytest.fixture(autouse=True)
def patch_scheduler_deps(monkeypatch):
    """Stub the ``navi.api.deps`` getters for every test in this module.

    The real ``_fire_recall`` may reach for workers, the memory store and the
    MCP manager; replacing them with trivial stand-ins keeps the heavy
    dependency initialization out of these tests.
    """
    stubs = {
        "navi.api.deps.get_workers": lambda: [],
        "navi.api.deps.get_memory_store": lambda: None,
        # get_mcp_manager is awaited by callers, so its stub must be async.
        "navi.api.deps.get_mcp_manager": AsyncMock(return_value=None),
    }
    for target, replacement in stubs.items():
        monkeypatch.setattr(target, replacement)


def _make_recall(
    recall_id="r1",
    session_id="s1",
    *,
    call_type="once",
    overdue=timedelta(minutes=5),
    interval_seconds=None,
):
    """Build a pending ``Recall`` whose ``trigger_at`` lies ``overdue`` in the past."""
    now = datetime.now(timezone.utc)
    return Recall(
        id=recall_id, session_id=session_id, call_type=call_type,
        trigger_at=now - overdue,
        interval_seconds=interval_seconds, internal_comment=None,
        additional_context_message="ctx", status="pending",
        created_at=now,
        updated_at=now,
    )


def _make_scheduler(*recalls):
    """Return an ``AsyncMock`` scheduler that reports ``recalls`` as pending.

    ``get_next_trigger_at`` returns ``None``, i.e. nothing else is queued.
    """
    scheduler = AsyncMock()
    scheduler.get_pending_recalls.return_value = list(recalls)
    scheduler.get_next_trigger_at.return_value = None
    return scheduler


def _make_store_with_session():
    """Return an ``AsyncMock`` store whose ``get`` yields a session with no history."""
    session = AsyncMock()
    session.messages = []
    session.context = []
    store = AsyncMock()
    store.get.return_value = session
    return store


class _SucceedingAgent:
    """Stand-in ``Agent`` whose stream finishes immediately with a ``StreamEnd``."""

    async def run_stream(self, *args, **kwargs):
        yield StreamEnd(full_content="done")


class _FailingAgent:
    """Stand-in ``Agent`` whose stream raises before producing any event."""

    async def run_stream(self, *args, **kwargs):
        raise RuntimeError("boom")
        yield  # unreachable; its presence makes this an async generator


async def _run_loop(scheduler, store, orchestrator, *, until=None, timeout=1.0):
    """Run ``recall_scheduler_loop`` in the background, then cancel it.

    The loop is cancelled as soon as ``until()`` returns true, when the loop
    exits on its own, or once ``timeout`` seconds have elapsed — whichever
    comes first. Polling a condition instead of sleeping a fixed interval
    keeps the tests quick on a fast machine and tolerant of a slow CI runner.

    Any exception raised by the loop other than the cancellation requested
    here propagates to the caller, so a crashing loop fails the test loudly.
    """
    task = asyncio.create_task(recall_scheduler_loop(scheduler, store, orchestrator))
    event_loop = asyncio.get_running_loop()
    deadline = event_loop.time() + timeout
    try:
        while not task.done() and event_loop.time() < deadline:
            if until is not None and until():
                break
            await asyncio.sleep(0.01)
    finally:
        task.cancel()
        try:
            await task
        except asyncio.CancelledError:
            pass


class TestSchedulerLoop:
    """Drive ``recall_scheduler_loop`` against mocked scheduler/store backends."""

    @pytest.mark.anyio
    async def test_loop_fires_past_due_recalls(self, monkeypatch, fake_orchestrator):
        """A recall whose trigger time has passed is handed to ``_fire_recall``."""
        scheduler = _make_scheduler(_make_recall())
        scheduler.mark_fired.return_value = None
        store = _make_store_with_session()

        # Patch _fire_recall to avoid full Agent construction; just record the call.
        fire_calls = []

        async def _fake_fire(semaphore, recall, scheduler, store, orchestrator):
            fire_calls.append(recall.id)
            await scheduler.mark_fired(recall.id)

        monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire)

        await _run_loop(scheduler, store, fake_orchestrator, until=lambda: "r1" in fire_calls)

        assert "r1" in fire_calls

    @pytest.mark.anyio
    async def test_loop_respects_semaphore(self, monkeypatch, fake_orchestrator):
        """No more than three recalls run concurrently through the shared semaphore."""
        recalls = [_make_recall(f"r{i}", f"s{i}") for i in range(5)]
        scheduler = _make_scheduler(*recalls)
        store = AsyncMock()

        max_concurrent = 0
        current = 0

        async def _slow_fire(semaphore, recall, scheduler, store, orchestrator):
            nonlocal max_concurrent, current
            async with semaphore:
                current += 1
                max_concurrent = max(max_concurrent, current)
                await asyncio.sleep(0.1)
                current -= 1
                await scheduler.mark_fired(recall.id)

        monkeypatch.setattr("navi.core.scheduler._fire_recall", _slow_fire)

        await _run_loop(
            scheduler, store, fake_orchestrator,
            until=lambda: scheduler.mark_fired.call_count >= len(recalls),
        )

        # The lower bound guards against a vacuous pass when nothing fired at all.
        assert 1 <= max_concurrent <= 3

    @pytest.mark.anyio
    async def test_loop_defers_when_session_busy(self, fake_orchestrator):
        """A recall for a session with an active run is deferred via ``reschedule``."""
        scheduler = _make_scheduler(_make_recall())
        scheduler.reschedule.return_value = None
        store = AsyncMock()

        # Simulate an active websocket run in the orchestrator; the real
        # _fire_recall runs here.
        fake_orchestrator.create_run("s1")
        try:
            await _run_loop(
                scheduler, store, fake_orchestrator,
                until=lambda: scheduler.reschedule.called,
            )
        finally:
            fake_orchestrator._sessions.pop("s1", None)

        scheduler.reschedule.assert_called_once()

    @pytest.mark.anyio
    async def test_loop_cancels_when_session_missing(self, fake_orchestrator):
        """A recall whose session no longer exists in the store is marked cancelled."""
        scheduler = _make_scheduler(_make_recall())
        scheduler.mark_cancelled.return_value = None
        store = AsyncMock()
        store.get.return_value = None

        await _run_loop(
            scheduler, store, fake_orchestrator,
            until=lambda: scheduler.mark_cancelled.called,
        )

        scheduler.mark_cancelled.assert_called_once_with("r1")

    @pytest.mark.anyio
    async def test_recurring_rescheduled_on_success(self, monkeypatch, fake_orchestrator):
        """A recurring recall is rescheduled after its agent run completes."""
        scheduler = _make_scheduler(
            _make_recall(call_type="recurring", interval_seconds=3600)
        )
        scheduler.reschedule.return_value = None
        store = _make_store_with_session()

        monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: _SucceedingAgent())

        await _run_loop(
            scheduler, store, fake_orchestrator,
            until=lambda: scheduler.reschedule.called,
        )

        scheduler.reschedule.assert_called_once()

    @pytest.mark.anyio
    async def test_recurring_rescheduled_on_failure(self, monkeypatch, fake_orchestrator):
        """A recurring recall is still rescheduled when its agent run raises."""
        scheduler = _make_scheduler(
            _make_recall(call_type="recurring", interval_seconds=3600)
        )
        scheduler.reschedule.return_value = None
        store = _make_store_with_session()

        monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: _FailingAgent())

        await _run_loop(
            scheduler, store, fake_orchestrator,
            until=lambda: scheduler.reschedule.called,
        )

        scheduler.reschedule.assert_called_once()

    @pytest.mark.anyio
    async def test_one_time_cancelled_on_failure(self, monkeypatch, fake_orchestrator):
        """A one-time recall whose agent run raises is marked cancelled."""
        scheduler = _make_scheduler(_make_recall())
        scheduler.mark_cancelled.return_value = None
        store = _make_store_with_session()

        monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: _FailingAgent())

        await _run_loop(
            scheduler, store, fake_orchestrator,
            until=lambda: scheduler.mark_cancelled.called,
        )

        scheduler.mark_cancelled.assert_called_once_with("r1")

    @pytest.mark.anyio
    async def test_loop_picks_up_after_restart(self, monkeypatch, fake_orchestrator):
        """A recall that became due long ago (e.g. while the process was down) still fires."""
        scheduler = _make_scheduler(_make_recall(overdue=timedelta(hours=2)))
        scheduler.mark_fired.return_value = None
        store = AsyncMock()

        fire_calls = []

        async def _fake_fire(semaphore, recall, scheduler, store, orchestrator):
            fire_calls.append(recall.id)
            await scheduler.mark_fired(recall.id)

        monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire)

        await _run_loop(scheduler, store, fake_orchestrator, until=lambda: "r1" in fire_calls)

        assert "r1" in fire_calls