# navi-1/tests/conftest_factory.py
# Last change: "Bootstrap test suite — Phase 1 unit tests" (Eugene Sukhodolskiy, 29 Apr).
"""Test factories and fakes for NAVI internals."""

from typing import TYPE_CHECKING, AsyncGenerator

from navi.core.registry import ProfileRegistry, ToolRegistry
from navi.llm.base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema
from navi.profiles.base import AgentProfile

if TYPE_CHECKING:
    from navi.tools.base import ToolResult


class FakeLLMBackend(LLMBackend):
    """Deterministic, scripted LLM backend for tests.

    Entry *i* of each scripted list (``responses``, ``tool_calls``,
    ``thinking``) is returned by the *i*-th call. Once a list is exhausted,
    later calls fall back to ``""`` content, no tool calls and no thinking.

    ``complete()`` advances its own cursor, while ``stream()`` and
    ``stream_complete()`` share a second, independent cursor. A
    ``complete()`` call therefore does not consume a script entry for the
    streaming methods, and vice versa.

    Usage:
        backend = FakeLLMBackend(responses=["Hello"])
        resp = await backend.complete([])  # -> LLMResponse(content="Hello")
    """

    def __init__(
        self,
        responses: list[str] | None = None,
        tool_calls: list[list[ToolCallRequest]] | None = None,
        thinking: list[str] | None = None,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        embedding_dim: int = 768,
    ) -> None:
        """
        Args:
            responses: Text content for successive calls.
            tool_calls: Tool-call requests for successive calls; a non-empty
                entry makes that call finish with ``"tool_calls"``.
            thinking: Values passed through as the ``thinking`` field of
                successive calls.
            prompt_tokens: Prompt-token count reported by ``complete()`` and
                by the final ``stream_complete()`` chunk.
            completion_tokens: Completion-token count reported the same way.
            embedding_dim: Length of every vector returned by ``embed()``.
        """
        self._responses = responses or []
        self._tool_calls = tool_calls or []
        self._thinking = thinking or []
        self._prompt_tokens = prompt_tokens
        self._completion_tokens = completion_tokens
        self._embedding_dim = embedding_dim
        self._call_idx = 0  # cursor for complete()
        self._stream_idx = 0  # cursor shared by stream() and stream_complete()

    @staticmethod
    def _scripted(items: list, idx: int, default=None):
        """Return ``items[idx]``, or *default* once the script is exhausted."""
        return items[idx] if idx < len(items) else default

    def _next(self) -> int:
        """Return the current ``complete()`` script index and advance it."""
        idx = self._call_idx
        self._call_idx += 1
        return idx

    def _next_stream(self) -> int:
        """Return the current streaming script index and advance it."""
        idx = self._stream_idx
        self._stream_idx += 1
        return idx

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: str | None = None,
        think: bool | None = None,
        max_tokens: int | None = None,
    ) -> LLMResponse:
        """Return the next scripted response; all call arguments are ignored."""
        idx = self._next()
        tcalls = self._scripted(self._tool_calls, idx)
        return LLMResponse(
            content=self._scripted(self._responses, idx, ""),
            tool_calls=tcalls,
            finish_reason="tool_calls" if tcalls else "stop",
            thinking=self._scripted(self._thinking, idx),
            prompt_tokens=self._prompt_tokens,
            completion_tokens=self._completion_tokens,
        )

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        model: str | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Yield the next scripted response as one final ``"stop"`` chunk.

        No thinking, tool calls or token counts are emitted here; use
        ``stream_complete()`` for those.
        """
        # This is an async generator, so the cursor advances when iteration
        # starts, not when the method is called.
        content = self._scripted(self._responses, self._next_stream(), "")
        yield LLMChunk(delta=content, finish_reason="stop")

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: str | None = None,
        think: bool | None = None,
        max_tokens: int | None = None,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Stream the next scripted response as up to three chunks.

        The chunks are, in order:

        1. A thinking chunk, only if the scripted thinking is non-empty.
        2. A content delta, only if the scripted response is non-empty.
        3. A final chunk with ``finish_reason``, tool calls and token counts.

        ``max_tokens`` is accepted for signature parity with ``complete()``
        and, like the other call arguments, is ignored.
        """
        idx = self._next_stream()
        content = self._scripted(self._responses, idx, "")
        thinking = self._scripted(self._thinking, idx)
        tcalls = self._scripted(self._tool_calls, idx)

        if thinking:
            yield LLMChunk(thinking=thinking)
        if content:
            yield LLMChunk(delta=content)
        yield LLMChunk(
            finish_reason="tool_calls" if tcalls else "stop",
            tool_calls=tcalls,
            prompt_tokens=self._prompt_tokens,
            completion_tokens=self._completion_tokens,
        )

    async def embed(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[list[float]]:
        """Return one constant vector (all ``0.1``) of ``embedding_dim`` floats per text."""
        return [[0.1] * self._embedding_dim for _ in texts]


class FakeTool:
    """Minimal stand-in for navi.tools.base.Tool."""

    def __init__(self, name: str, description: str = "", parameters: dict | None = None) -> None:
        self.name = name
        self.description = description
        self.parameters = parameters or {"type": "object", "properties": {}}

    def schema(self) -> dict:
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

    async def execute(self, arguments: dict) -> "FakeToolResult":
        from navi.tools.base import ToolResult

        return ToolResult(success=True, output=f"executed {self.name}")


def make_profile(profile_id: str = "test", **overrides) -> AgentProfile:
    """Build an ``AgentProfile`` pre-filled with test defaults.

    Args:
        profile_id: Value used for the profile's ``id`` field.
        **overrides: Keyword arguments merged over the defaults. They take
            precedence (including over ``id``), and any extra keys are
            passed through to ``AgentProfile`` as well.

    Returns:
        The constructed ``AgentProfile``.
    """
    base_fields = dict(
        id=profile_id,
        name="Test Profile",
        description="A profile for tests",
        system_prompt="You are a test assistant.",
        enabled_tools=["test_tool"],
        llm_backend="ollama",
        model=["gemma4:31b-cloud"],
        max_iterations=5,
        temperature=0.7,
    )
    return AgentProfile(**{**base_fields, **overrides})


def make_registry_with_tools() -> ToolRegistry:
    """Return a ToolRegistry holding two builtin FakeTools.

    The tools are ``test_tool`` and ``another_tool``, registered in that order.
    """
    registry = ToolRegistry()
    for tool_name in ("test_tool", "another_tool"):
        registry.register(FakeTool(tool_name), builtin=True)
    return registry


def make_profile_registry() -> ProfileRegistry:
    """Return a ProfileRegistry with default test profiles for ``secretary`` and ``developer``."""
    registry = ProfileRegistry()
    for profile_id in ("secretary", "developer"):
        registry.register(make_profile(profile_id))
    return registry