Newer
Older
navi-1 / tests / unit / tools / test_todo.py
"""Unit tests for navi.tools.todo."""

import json

import pytest

from navi.llm.base import Message
from navi.tools.todo import (
    TodoTool,
    get_failed_steps,
    get_progress_message,
    get_task_snapshot,
    set_tasks,
    _kv_store,
)
from navi.tools._internal.base import ToolContext
from navi.store import KvStore
from tests.conftest_factory import FakeConnection, FakePool


class FakeKvStore(KvStore):
    """Dictionary-backed stand-in for ``KvStore`` used by the todo tests.

    Every entry lives in ``self._data`` under a 4-tuple key
    ``(user_id, session_id, scope, key)``. A falsy ``user_id`` (e.g. ``None``)
    is normalised to ``""``, so all such callers share one namespace.

    ``__init__`` does not call ``KvStore.__init__``.
    NOTE(review): presumably this skips real connection setup — confirm
    against ``KvStore``.
    """

    def __init__(self):
        self._data: dict[tuple, str] = {}

    @staticmethod
    def _prefix(user_id, session_id, scope) -> tuple:
        """Build the first three components of a storage key."""
        return (user_id or "", session_id, scope)

    async def _get_pool(self):
        # No database behind this fake; hand back the test pool stub.
        return FakePool()

    async def get(self, user_id, session_id, scope, key):
        """Return the stored value, or ``None`` when the key is absent."""
        full_key = self._prefix(user_id, session_id, scope) + (key,)
        return self._data.get(full_key)

    async def set(self, user_id, session_id, scope, key, value):
        """Store ``value``, overwriting any previous entry."""
        full_key = self._prefix(user_id, session_id, scope) + (key,)
        self._data[full_key] = value

    async def get_all(self, user_id, session_id, scope):
        """Return ``{key: value}`` for every entry in the given scope."""
        wanted = self._prefix(user_id, session_id, scope)
        matches = {}
        for full_key, value in self._data.items():
            if full_key[:3] == wanted:
                matches[full_key[3]] = value
        return matches

    async def delete(self, user_id, session_id, scope, key):
        """Remove one entry; a missing key is silently ignored."""
        full_key = self._prefix(user_id, session_id, scope) + (key,)
        self._data.pop(full_key, None)

    async def clear_scope(self, user_id, session_id, scope):
        """Remove every entry in the given scope, mutating ``_data`` in place."""
        wanted = self._prefix(user_id, session_id, scope)
        # Snapshot the doomed keys first: deleting while iterating a dict raises.
        doomed = [full_key for full_key in self._data if full_key[:3] == wanted]
        for full_key in doomed:
            del self._data[full_key]


@pytest.fixture(autouse=True)
def _fake_kv(monkeypatch):
    """Install a fresh ``FakeKvStore`` as ``navi.tools.todo._kv_store``.

    This fixture is autouse, so every test in this module runs against an
    isolated in-memory store. The store is also returned, so tests that
    request ``_fake_kv`` explicitly can inspect it.

    ``monkeypatch`` records the attribute's previous value and restores it on
    teardown, even when the test fails. The previous teardown forced the
    attribute to ``None``, which could clobber a value installed elsewhere.
    """
    from navi.tools import todo as _mod

    store = FakeKvStore()
    monkeypatch.setattr(_mod, "_kv_store", store)
    return store


# ── TodoTool execute tests ────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_set_tasks(_fake_kv):
    """Setting a plan succeeds and the output lists every task given."""
    tool = TodoTool()
    payload = {"op": "set", "tasks": ["task A", "task B"]}
    ctx = ToolContext(session_id="sess1", user_id="user1")

    outcome = await tool.execute(payload, ctx=ctx)

    assert outcome.success is True
    for task in ("task A", "task B"):
        assert task in outcome.output


@pytest.mark.asyncio
async def test_view_empty(_fake_kv):
    """Viewing before any plan exists reports that no plan is set."""
    tool = TodoTool()
    ctx = ToolContext(session_id="sess1", user_id="user1")

    response = await tool.execute({"op": "view"}, ctx=ctx)

    assert response.success is True
    assert "No plan set" in response.output


@pytest.mark.asyncio
async def test_update_status(_fake_kv):
    """Marking task 1 done with a validation note succeeds and reports 'done'."""
    tool = TodoTool()
    plan = {"op": "set", "tasks": ["task A"]}
    mark_done = {"op": "update", "index": 1, "status": "done", "validation": "tested"}

    await tool.execute(plan, ctx=ToolContext(session_id="sess1", user_id="user1"))
    outcome = await tool.execute(
        mark_done, ctx=ToolContext(session_id="sess1", user_id="user1")
    )

    assert outcome.success is True
    assert "done" in outcome.output


@pytest.mark.asyncio
async def test_done_requires_validation(_fake_kv):
    """Marking a task done without validation fails with a validation error."""
    tool = TodoTool()
    plan = {"op": "set", "tasks": ["task A"]}
    mark_done = {"op": "update", "index": 1, "status": "done"}

    await tool.execute(plan, ctx=ToolContext(session_id="sess1", user_id="user1"))
    outcome = await tool.execute(
        mark_done, ctx=ToolContext(session_id="sess1", user_id="user1")
    )

    assert outcome.success is False
    assert "validation" in outcome.error.lower()


@pytest.mark.asyncio
async def test_failed_without_validation_warns(_fake_kv):
    """Marking a task failed without validation is accepted but adds a 'Tip'."""
    tool = TodoTool()
    plan = {"op": "set", "tasks": ["task A"]}
    mark_failed = {"op": "update", "index": 1, "status": "failed"}

    await tool.execute(plan, ctx=ToolContext(session_id="sess1", user_id="user1"))
    outcome = await tool.execute(
        mark_failed, ctx=ToolContext(session_id="sess1", user_id="user1")
    )

    assert outcome.success is True
    assert "Tip" in outcome.output


@pytest.mark.asyncio
async def test_clear(_fake_kv):
    """Clearing an existing plan succeeds and says the plan was cleared."""
    tool = TodoTool()
    plan = {"op": "set", "tasks": ["task A"]}

    await tool.execute(plan, ctx=ToolContext(session_id="sess1", user_id="user1"))
    outcome = await tool.execute(
        {"op": "clear"}, ctx=ToolContext(session_id="sess1", user_id="user1")
    )

    assert outcome.success is True
    assert "cleared" in outcome.output.lower()


# ── Public API tests ─────────────────────────────────────────────────────────


@pytest.mark.asyncio
async def test_get_task_snapshot(_fake_kv):
    """A freshly set plan snapshots as (title, "pending") pairs."""
    await set_tasks("sess1", ["t1", "t2"])
    expected = frozenset([("t1", "pending"), ("t2", "pending")])

    assert await get_task_snapshot("sess1") == expected


@pytest.mark.asyncio
async def test_get_progress_message(_fake_kv):
    """On the first iteration the progress note is a system Message."""
    await set_tasks("sess1", ["t1", "t2"])

    progress = await get_progress_message("sess1", first_iteration=True)

    assert isinstance(progress, Message)
    assert progress.role == "system"
    assert "TODO progress" in progress.content