Newer
Older
navi-1 / tests / unit / tools / test_recall_tools.py
"""Unit tests for ScheduleRecallTool and ManageRecallTool."""

from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock

import pytest

from navi.core.scheduler import Recall, RecallExistsError
from navi.tools.schedule_recall import ScheduleRecallTool
from navi.tools.manage_recall import ManageRecallTool
from navi.tools._internal.base import ToolContext


class TestScheduleRecallTool:
    """Tests for ScheduleRecallTool driven by a mocked (AsyncMock) scheduler.

    The coroutine tests below carry no explicit asyncio marker.
    NOTE(review): this assumes the project runs async tests automatically
    (e.g. pytest-asyncio configured with auto mode) — confirm in the pytest config.
    """

    @staticmethod
    def _recall(call_type, interval_seconds=None):
        """Build a pending Recall for session "s1" with id "r1" and context "ctx".

        Args:
            call_type: value stored in ``Recall.call_type``.
            interval_seconds: value stored in ``Recall.interval_seconds``.
        """
        # A single timestamp keeps trigger/created/updated consistent with each other.
        now = datetime.now(timezone.utc)
        return Recall(
            id="r1", session_id="s1", call_type=call_type,
            trigger_at=now, interval_seconds=interval_seconds,
            internal_comment=None, additional_context_message="ctx",
            status="pending", created_at=now, updated_at=now,
        )

    async def test_schedule_once(self):
        scheduler = AsyncMock()
        scheduler.schedule_recall.return_value = self._recall("once")
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "once",
            "when": "1h",
            "additional_context_message": "ctx",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is True
        assert "r1" in result.output
        scheduler.schedule_recall.assert_called_once()

    async def test_schedule_immediate(self):
        scheduler = AsyncMock()
        scheduler.schedule_recall.return_value = self._recall("immediate")
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "immediate",
            "additional_context_message": "ctx",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is True
        assert "immediate" in result.output

    async def test_schedule_recurring(self):
        scheduler = AsyncMock()
        scheduler.schedule_recall.return_value = self._recall(
            "recurring", interval_seconds=3600,
        )
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "recurring",
            "when": "1h",
            "interval_seconds": 3600,
            "additional_context_message": "ctx",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is True
        assert "3600s" in result.output

    async def test_missing_context_message(self):
        scheduler = AsyncMock()
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "once",
            "when": "1h",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is False
        assert "missing context" in result.error
        # Input validation must reject the request before touching the scheduler.
        scheduler.schedule_recall.assert_not_called()

    async def test_recall_exists_error(self):
        scheduler = AsyncMock()
        scheduler.schedule_recall.side_effect = RecallExistsError("exists")
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "once",
            "when": "1h",
            "additional_context_message": "ctx",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is False
        assert result.error == "recall_exists"
        # The error must originate from the scheduler call, not from earlier validation.
        scheduler.schedule_recall.assert_called_once()

    async def test_invalid_call_type(self):
        scheduler = AsyncMock()
        tool = ScheduleRecallTool(scheduler)

        result = await tool.execute({
            "call_type": "invalid",
            "when": "1h",
            "additional_context_message": "ctx",
        }, ctx=ToolContext(session_id="s1"))

        assert result.success is False
        assert "bad_call_type" in result.error
        # An unknown call_type must be rejected before touching the scheduler.
        scheduler.schedule_recall.assert_not_called()


class TestManageRecallTool:
    """Tests for ManageRecallTool actions (cancel / skip / list) with a mocked scheduler.

    NOTE(review): the async tests carry no asyncio marker; this assumes the
    project runs coroutine tests automatically (e.g. pytest-asyncio auto mode).
    """

    async def test_cancel(self):
        scheduler = AsyncMock()
        scheduler.cancel_recall.return_value = True
        tool = ManageRecallTool(scheduler)

        result = await tool.execute({"action": "cancel"}, ctx=ToolContext(session_id="s1"))

        assert result.success is True
        assert "cancelled" in result.output
        scheduler.cancel_recall.assert_called_once_with("s1")

    async def test_skip(self):
        scheduler = AsyncMock()
        scheduler.skip_next_recall.return_value = True
        tool = ManageRecallTool(scheduler)

        result = await tool.execute({"action": "skip"}, ctx=ToolContext(session_id="s1"))

        assert result.success is True
        assert "skipped" in result.output
        scheduler.skip_next_recall.assert_called_once_with("s1")

    async def test_list(self):
        scheduler = AsyncMock()
        # One timestamp so trigger/created/updated are mutually consistent.
        now = datetime.now(timezone.utc)
        scheduler.list_recalls.return_value = [
            Recall(
                id="r1", session_id="s1", call_type="once",
                trigger_at=now, interval_seconds=None,
                internal_comment="test", additional_context_message="ctx",
                status="pending", created_at=now, updated_at=now,
            )
        ]
        tool = ManageRecallTool(scheduler)

        result = await tool.execute({"action": "list"}, ctx=ToolContext(session_id="s1", user_role="admin"))

        assert result.success is True
        assert "Recalls for session" in result.output
        assert "PENDING" in result.output

    async def test_bad_action(self):
        scheduler = AsyncMock()
        tool = ManageRecallTool(scheduler)

        result = await tool.execute({"action": "unknown"}, ctx=ToolContext(session_id="s1"))

        assert result.success is False
        assert result.error == "bad_action"
        # An unknown action must not trigger any scheduler operation.
        scheduler.cancel_recall.assert_not_called()
        scheduler.skip_next_recall.assert_not_called()
        scheduler.list_recalls.assert_not_called()

    async def test_cancel_no_pending(self):
        scheduler = AsyncMock()
        scheduler.cancel_recall.return_value = False
        tool = ManageRecallTool(scheduler)

        result = await tool.execute({"action": "cancel"}, ctx=ToolContext(session_id="s1"))

        assert result.success is False
        assert "no_pending_recall" in result.error
        # The failure must come from the scheduler's answer for this session.
        scheduler.cancel_recall.assert_called_once_with("s1")