# navi-1 / tests / unit / core / test_agent.py
"""Unit tests for navi.core.agent.Agent.

Uses InMemorySessionStore, FakeLLMBackend, and FakeTool so tests run
without a real database or LLM server.
"""

import asyncio

import pytest
import pytest_asyncio

from navi.core.agent import Agent
from navi.core.events import (
    StreamEnd,
    StreamStopped,
    SubagentComplete,
    TextDelta,
    ToolEvent,
    ToolStarted,
)
from navi.core.registry import BackendRegistry, ProfileRegistry, ToolRegistry
from navi.core.session import InMemorySessionStore
from navi.exceptions import MaxIterationsReached, SessionNotFound
from navi.llm.base import LLMChunk, Message, ToolCallRequest
from navi.tools._internal.base import ToolResult
from tests.conftest_factory import FakeLLMBackend, FakeTool, make_profile, make_registry_with_tools


@pytest.fixture
def agent():
    """Return an Agent wired entirely to in-memory fakes.

    The agent gets a single profile named "test" with all three planning
    phases switched off, the tool registry produced by
    ``make_registry_with_tools()``, and one backend registered under
    "ollama" that replies "hello".
    """
    test_profile = make_profile("test")
    # Turn off every planning phase on the profile before registering it.
    test_profile.planning_phase1_enabled = False
    test_profile.planning_phase2_enabled = False
    test_profile.planning_phase3_enabled = False

    profile_registry = ProfileRegistry()
    profile_registry.register(test_profile)

    backend_registry = BackendRegistry()
    backend_registry.register("ollama", FakeLLMBackend(responses=["hello"]))

    return Agent(
        session_store=InMemorySessionStore(),
        profile_registry=profile_registry,
        tool_registry=make_registry_with_tools(),
        backend_registry=backend_registry,
    )


@pytest_asyncio.fixture
async def session(agent):
    """Create and return a new session for the "test" profile in the agent's store."""
    store = agent._sessions
    created = await store.create(profile_id="test")
    return created


# ─── run() tests ───────────────────────────────────────────────────────────


class TestAgentRun:
    """Tests for the non-streaming ``Agent.run`` entry point.

    Every test re-registers the "ollama" backend with a FakeLLMBackend
    scripted for the scenario under test.
    """

    @pytest.mark.asyncio
    async def test_run_single_iteration(self, agent, session):
        """A plain reply is returned and stored as three flagged messages."""
        agent._backends.register("ollama", FakeLLMBackend(responses=["hello"]))

        reply = await agent.run(session.id, "hi")
        assert reply == "hello"

        stored = await agent._sessions.get(session.id)
        history = stored.messages
        # Expected layout: user (display) + user (context) + assistant.
        assert len(history) == 3
        shown_user, context_user, assistant = history
        assert shown_user.role == "user"
        assert context_user.role == "user"
        assert assistant.role == "assistant"
        assert assistant.content == "hello"

        # The first user message is display-only, the second context-only.
        assert shown_user.is_display is True
        assert shown_user.is_context is False
        assert context_user.is_display is False
        assert context_user.is_context is True
        # The assistant reply carries both flags.
        assert assistant.is_display is True
        assert assistant.is_context is True

    @pytest.mark.asyncio
    async def test_run_session_not_found(self, agent):
        """Running against an unknown session id raises SessionNotFound."""
        with pytest.raises(SessionNotFound):
            await agent.run("nonexistent-id", "hi")

    @pytest.mark.asyncio
    async def test_run_tool_calls_then_stop(self, agent, session):
        """One tool-calling turn, then a final turn that returns text."""
        tool_turn = [ToolCallRequest(id="1", name="test_tool", arguments={})]
        scripted = FakeLLMBackend(
            responses=["", "done"],
            tool_calls=[tool_turn, None],
        )
        agent._backends.register("ollama", scripted)

        reply = await agent.run(session.id, "do something")
        assert reply == "done"

        stored = await agent._sessions.get(session.id)
        history = stored.messages
        # user display + user context + assistant(tool) + tool_result + assistant(final)
        assert len(history) == 5
        assert history[3].role == "tool"
        assert history[4].content == "done"

    @pytest.mark.asyncio
    async def test_run_token_accumulation(self, agent, session):
        """Completion tokens from both iterations add up on the final message."""
        tool_turn = [ToolCallRequest(id="1", name="test_tool", arguments={})]
        scripted = FakeLLMBackend(
            responses=["", "done"],
            tool_calls=[tool_turn, None],
            prompt_tokens=10,
            completion_tokens=5,
        )
        agent._backends.register("ollama", scripted)

        await agent.run(session.id, "do something")

        stored = await agent._sessions.get(session.id)
        last_message = stored.messages[-1]
        # 2 iterations x 5 completion tokens each = 10.
        assert last_message.token_count == 10

    @pytest.mark.asyncio
    async def test_run_max_iterations(self, agent, session):
        """With max_iterations=2 and two tool-only turns, the run raises."""
        agent._profiles.get("test").max_iterations = 2

        first_turn = [ToolCallRequest(id="1", name="test_tool", arguments={})]
        second_turn = [ToolCallRequest(id="2", name="test_tool", arguments={})]
        scripted = FakeLLMBackend(
            responses=["", ""],
            tool_calls=[first_turn, second_turn],
        )
        agent._backends.register("ollama", scripted)

        with pytest.raises(MaxIterationsReached):
            await agent.run(session.id, "loop forever")


# ─── run_stream() tests ──────────────────────────────────────────────────────


class TestAgentRunStream:
    """Tests for the streaming ``Agent.run_stream`` entry point.

    Each test re-registers the "ollama" backend with a scripted
    FakeLLMBackend and consumes the async event stream to completion.
    """

    @pytest.mark.asyncio
    async def test_run_stream_single_iteration(self, agent, session):
        """A single text reply ends with StreamEnd and is persisted."""
        backend = FakeLLMBackend(responses=["streamed hello"])
        agent._backends.register("ollama", backend)

        # Only the event class names are recorded; payloads are not inspected.
        events = []
        async for ev in agent.run_stream(session.id, "hi"):
            events.append(type(ev).__name__)

        assert events[-1] == "StreamEnd"
        saved = await agent._sessions.get(session.id)
        assert saved.messages[-1].content == "streamed hello"

    @pytest.mark.asyncio
    async def test_run_stream_tool_calls(self, agent, session):
        """A tool-calling turn surfaces ToolStarted and ToolEvent before StreamEnd."""
        backend = FakeLLMBackend(
            responses=["", "final"],
            tool_calls=[
                [ToolCallRequest(id="1", name="test_tool", arguments={})],
                None,
            ],
        )
        agent._backends.register("ollama", backend)

        events = []
        async for ev in agent.run_stream(session.id, "do something"):
            events.append(type(ev).__name__)

        assert "ToolStarted" in events
        assert "ToolEvent" in events
        assert events[-1] == "StreamEnd"

    @pytest.mark.asyncio
    async def test_run_stream_stop_event(self, agent, session):
        """A stop event that is already set when streaming starts yields StreamStopped."""
        from navi.tools._internal.base import current_stop_event

        stop = asyncio.Event()
        token = current_stop_event.set(stop)
        try:
            # The replacement is stored as an *instance* attribute, so Python
            # does not bind it: no implicit ``self`` is passed. Accepting
            # ``*args, **kwargs`` keeps it callable however the agent invokes
            # ``stream_complete`` (positional or keyword-only arguments).
            async def _slow_stream(*args, **kwargs):
                yield LLMChunk(delta="a")
                # Long pause: if the stop were not honoured, the stream would
                # sit here instead of finishing promptly.
                await asyncio.sleep(10)
                yield LLMChunk(delta="b")

            backend = FakeLLMBackend()
            # Monkey-patch stream_complete to be slow
            backend.stream_complete = _slow_stream
            agent._backends.register("ollama", backend)

            # Request the stop before the stream is consumed.
            stop.set()
            events = []
            async for ev in agent.run_stream(session.id, "hi"):
                events.append(type(ev).__name__)

            assert "StreamStopped" in events
        finally:
            # Always restore the context variable so other tests are unaffected.
            current_stop_event.reset(token)

    @pytest.mark.asyncio
    async def test_run_stream_token_count(self, agent, session):
        """StreamEnd and the saved message both report the completion tokens."""
        backend = FakeLLMBackend(
            responses=["final"],
            prompt_tokens=100,
            completion_tokens=50,
        )
        agent._backends.register("ollama", backend)

        end_events = []
        async for ev in agent.run_stream(session.id, "hi"):
            if isinstance(ev, StreamEnd):
                end_events.append(ev)

        # Fail with a readable message rather than an IndexError below.
        assert end_events, "run_stream did not yield a StreamEnd event"
        assert end_events[0].token_count == 50
        saved = await agent._sessions.get(session.id)
        assert saved.messages[-1].token_count == 50


# ─── run_ephemeral() tests ───────────────────────────────────────────────────


class TestAgentRunEphemeral:
    """Tests for ``Agent.run_ephemeral``, which returns a ``(result, ok)`` pair.

    Each test re-registers the "ollama" backend with a FakeLLMBackend suited
    to the scenario under test.
    """

    @pytest.mark.asyncio
    async def test_run_ephemeral_complete(self, agent):
        """A plain reply is returned with the completed marker and ok=True."""
        backend = FakeLLMBackend(responses=["subagent result"])
        agent._backends.register("ollama", backend)

        result, ok = await agent.run_ephemeral("task", profile_id="test")
        assert "subagent result" in result
        assert "[Sub-agent stopped: completed]" in result
        assert ok is True

    @pytest.mark.asyncio
    async def test_run_ephemeral_max_iterations(self, agent):
        """A tool-only turn with max_iterations=1 reports the iteration limit."""
        backend = FakeLLMBackend(
            responses=[""],
            tool_calls=[
                [ToolCallRequest(id="1", name="test_tool", arguments={})],
            ],
        )
        agent._backends.register("ollama", backend)

        result, ok = await agent.run_ephemeral(
            "task", profile_id="test", max_iterations=1
        )
        assert ok is False
        assert "iteration limit" in result.lower()

    @pytest.mark.skip(
        reason="run_ephemeral uses 'import time as _time' inside the function; CPython LOAD_GLOBAL caching makes module-level mock replacement unreliable in pytest-asyncio."
    )
    @pytest.mark.asyncio
    async def test_run_ephemeral_timeout(self, agent):
        """Placeholder for the timeout scenario; skipped (see the skip reason)."""
        pass

    @pytest.mark.asyncio
    async def test_run_ephemeral_planning_tokens_accumulated(self, agent):
        """Planning phase AIHelperTokensUsed contributes to SubagentComplete."""
        from navi.core.events import AIHelperTokensUsed
        from navi.tools._internal.base import current_event_sink

        backend = FakeLLMBackend(responses=["final"])
        agent._backends.register("ollama", backend)

        # Force planning by setting subagent_planning_enabled on profile
        profile = agent._profiles.get("test")
        profile.subagent_planning_enabled = True

        async def _mock_planning(*args, **kwargs):
            yield AIHelperTokensUsed(prompt_tokens=5, completion_tokens=10)
            yield AIHelperTokensUsed(prompt_tokens=3, completion_tokens=7)

        # Capture the original before entering the try block so the finally
        # clause can always restore it, whatever fails inside.
        original_planning_run = agent._planning.run
        sink = asyncio.Queue()
        token = current_event_sink.set(sink)
        try:
            # Mock planning to emit AIHelperTokensUsed
            agent._planning.run = _mock_planning

            result, ok = await agent.run_ephemeral("task", profile_id="test")
            assert ok is True

            # Drain sink for SubagentComplete; the last one seen wins.
            subagent_complete = None
            while not sink.empty():
                item = sink.get_nowait()
                if isinstance(item, SubagentComplete):
                    subagent_complete = item

            # Planning completion tokens: 10 + 7 = 17
            # Final LLM call: 0 (no tokens in FakeLLMBackend default)
            assert subagent_complete is not None
            assert subagent_complete.token_count == 17
        finally:
            agent._planning.run = original_planning_run
            current_event_sink.reset(token)

    @pytest.mark.asyncio
    async def test_run_ephemeral_thinking_stall(self, agent):
        """Subagent that produces only thinking for too long is aborted."""
        # Stored as an instance attribute, so no ``self`` is bound on call.
        # ``*args, **kwargs`` keeps the fake callable for any call signature.
        # NOTE(review): with the previous ``(self, **kwargs)`` signature a
        # keyword-only call would raise a TypeError naming ``_thinking_only``;
        # if run_ephemeral reports such errors in its result, the "thinking"
        # check below could pass by accident - confirm against the agent.
        async def _thinking_only(*args, **kwargs):
            for _ in range(200):
                yield LLMChunk(thinking="thinking " * 100)
            yield LLMChunk(delta="done", finish_reason="stop")

        backend = FakeLLMBackend()
        backend.stream_complete = _thinking_only
        agent._backends.register("ollama", backend)

        result, ok = await agent.run_ephemeral("task", profile_id="test")
        assert ok is False
        assert "thinking" in result.lower() or "stall" in result.lower()