# navi-1 / tests / conftest_factory.py
"""Test factories and fakes for NAVI internals."""

from typing import TYPE_CHECKING, AsyncGenerator

from navi.core.registry import ProfileRegistry, ToolRegistry
from navi.llm.base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema
from navi.profiles.base import AgentProfile

if TYPE_CHECKING:
    from navi.tools._internal.base import ToolResult


class FakeLLMBackend(LLMBackend):
    """Deterministic LLM backend for tests.

    Scripted values are consumed by call position: the n-th call reads entry
    ``n`` of ``responses``, ``tool_calls`` and ``thinking``. Once a script is
    exhausted, content falls back to ``""`` and thinking/tool calls to ``None``.

    ``complete`` keeps its own call counter; ``stream`` and
    ``stream_complete`` share a second, independent counter.

    Usage:
        backend = FakeLLMBackend(responses=["Hello"])
        resp = await backend.complete([])  # -> LLMResponse(content="Hello")

    Args:
        responses: Assistant text per call.
        tool_calls: Tool-call requests per call; a non-empty entry makes the
            finish reason ``"tool_calls"`` instead of ``"stop"``.
        thinking: Reasoning text per call.
        prompt_tokens: Reported by ``complete`` and on the final chunk of
            ``stream_complete``.
        completion_tokens: Reported by ``complete`` and on the final chunk of
            ``stream_complete``.
        embedding_dim: Length of every vector returned by ``embed``
            (default 768).
    """

    def __init__(
        self,
        responses: list[str] | None = None,
        tool_calls: list[list[ToolCallRequest]] | None = None,
        thinking: list[str] | None = None,
        prompt_tokens: int = 0,
        completion_tokens: int = 0,
        embedding_dim: int = 768,
    ) -> None:
        self._responses = responses or []
        self._tool_calls = tool_calls or []
        self._thinking = thinking or []
        self._prompt_tokens = prompt_tokens
        self._completion_tokens = completion_tokens
        self._embedding_dim = embedding_dim
        self._call_idx = 0  # advanced by complete()
        self._stream_idx = 0  # shared by stream() and stream_complete()

    def _next(self) -> int:
        """Return the current ``complete`` call index and advance it."""
        idx = self._call_idx
        self._call_idx += 1
        return idx

    def _next_stream(self) -> int:
        """Return the current streaming call index and advance it."""
        idx = self._stream_idx
        self._stream_idx += 1
        return idx

    @staticmethod
    def _pick(script: list, idx: int, default=None):
        """Return ``script[idx]``, or *default* once *script* has run out."""
        return script[idx] if idx < len(script) else default

    async def complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: str | None = None,
        think: bool | None = None,
        max_tokens: int | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Return the next scripted response; all request arguments are ignored."""
        idx = self._next()
        tcalls = self._pick(self._tool_calls, idx)
        return LLMResponse(
            content=self._pick(self._responses, idx, ""),
            tool_calls=tcalls,
            finish_reason="tool_calls" if tcalls else "stop",
            thinking=self._pick(self._thinking, idx),
            prompt_tokens=self._prompt_tokens,
            completion_tokens=self._completion_tokens,
        )

    async def stream(
        self,
        messages: list[Message],
        temperature: float = 0.7,
        model: str | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Yield one terminal chunk carrying the whole scripted response."""
        content = self._pick(self._responses, self._next_stream(), "")
        # Falsy script entries (e.g. None) are normalised to an empty delta.
        yield LLMChunk(delta=content or "", finish_reason="stop")

    async def stream_complete(
        self,
        messages: list[Message],
        tools: list[ToolSchema] | None = None,
        temperature: float = 0.7,
        model: str | None = None,
        think: bool | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMChunk, None]:
        """Yield optional thinking and content chunks, then a final chunk.

        The final chunk carries the finish reason, any scripted tool calls
        and the configured token counts.
        """
        idx = self._next_stream()
        content = self._pick(self._responses, idx, "")
        thinking = self._pick(self._thinking, idx)
        tcalls = self._pick(self._tool_calls, idx)

        if thinking:
            yield LLMChunk(thinking=thinking)
        if content:
            yield LLMChunk(delta=content)
        yield LLMChunk(
            finish_reason="tool_calls" if tcalls else "stop",
            tool_calls=tcalls,
            prompt_tokens=self._prompt_tokens,
            completion_tokens=self._completion_tokens,
        )

    async def embed(
        self,
        texts: list[str],
        model: str | None = None,
    ) -> list[list[float]]:
        """Return one constant vector (all 0.1) of ``embedding_dim`` per text."""
        return [[0.1] * self._embedding_dim for _ in texts]


class FakeTool:
    """Minimal stand-in for navi.tools.base.Tool."""

    def __init__(self, name: str, description: str = "", parameters: dict | None = None) -> None:
        self.name = name
        self.description = description
        self.parameters = parameters or {"type": "object", "properties": {}}

    def schema(self) -> dict:
        return {
            "type": "function",
            "function": {
                "name": self.name,
                "description": self.description,
                "parameters": self.parameters,
            },
        }

    async def execute(self, arguments: dict) -> "FakeToolResult":
        from navi.tools._internal.base import ToolResult

        return ToolResult(success=True, output=f"executed {self.name}")


def make_profile(profile_id: str = "test", **overrides) -> AgentProfile:
    """Build an ``AgentProfile`` populated with test-friendly defaults.

    Every keyword in *overrides* replaces the matching default field,
    including ``id`` (which otherwise comes from *profile_id*).
    """
    baseline = dict(
        id=profile_id,
        name="Test Profile",
        description="A profile for tests",
        system_prompt="You are a test assistant.",
        enabled_tools=["test_tool"],
        llm_backend="ollama",
        model=["gemma4:31b-cloud"],
        max_iterations=5,
        temperature=0.7,
    )
    return AgentProfile(**{**baseline, **overrides})


def make_registry_with_tools() -> ToolRegistry:
    """Return a ToolRegistry holding builtin FakeTools ``test_tool`` and ``another_tool``."""
    registry = ToolRegistry()
    for tool_name in ("test_tool", "another_tool"):
        registry.register(FakeTool(tool_name), builtin=True)
    return registry


def make_profile_registry() -> ProfileRegistry:
    """Return a ProfileRegistry with default test profiles ``secretary`` and ``developer``."""
    registry = ProfileRegistry()
    for profile_id in ("secretary", "developer"):
        registry.register(make_profile(profile_id))
    return registry


class FakeRecord:
    """Stand-in for asyncpg.Record — supports dict-like and index access.

    Fields keep the order of the keyword arguments, so ``record[0]`` is the
    first keyword passed. Supported access:

    * ``record["col"]``, ``record[i]`` and ``record.get("col", default)``
    * ``record.col`` (attribute access, as a convenience)
    * ``keys()`` / ``values()`` / ``items()``, so ``dict(record)`` works
    * ``len(record)``, iteration over values, and ``"col" in record``
      (membership tests field names)
    """

    def __init__(self, **kwargs) -> None:
        self._data = kwargs

    def __getitem__(self, key):
        """Look up a field by name, or by position when *key* is an int."""
        if isinstance(key, int):
            return list(self._data.values())[key]
        return self._data[key]

    def __getattr__(self, name):
        # Read ``_data`` through ``__dict__``: copy/pickle reconstruct the
        # instance without running ``__init__``. A plain ``self._data`` would
        # then re-enter ``__getattr__`` forever (RecursionError) instead of
        # raising AttributeError.
        data = self.__dict__.get("_data")
        if data is None or name not in data:
            raise AttributeError(name)
        return data[name]

    def get(self, name, default=None):
        """Return field *name*, or *default* when the field is absent."""
        return self._data.get(name, default)

    def keys(self):
        """Return an iterator over field names."""
        return iter(self._data.keys())

    def values(self):
        """Return an iterator over field values."""
        return iter(self._data.values())

    def items(self):
        """Return an iterator over ``(name, value)`` pairs."""
        return iter(self._data.items())

    def __iter__(self):
        # Iterates values, matching the positional (int-index) view.
        return iter(self._data.values())

    def __len__(self) -> int:
        return len(self._data)

    def __contains__(self, name) -> bool:
        return name in self._data

    def __repr__(self) -> str:
        fields = "".join(f" {key}={value!r}" for key, value in self._data.items())
        return f"<{type(self).__name__}{fields}>"


class FakeConnection:
    """In-memory asyncpg connection for unit tests.

    Results are returned from a FIFO queue set up by the test via
    ``enqueue``; each ``execute``/``fetch``/``fetchval``/``fetchrow`` call
    consumes one entry. An enqueued ``Exception`` instance is raised instead
    of returned. ``executemany`` is recorded but does not consume the queue.

    All calls are logged in `.calls` for assertions.
    """

    def __init__(self) -> None:
        self._results: list = []
        # (method name, SQL, args): ``args`` is the positional-args tuple for
        # execute/fetch*, and the caller's ``args_list`` for executemany.
        self.calls: list[tuple[str, str, tuple | list]] = []

    def enqueue(self, result) -> None:
        """Queue a result for the next async operation."""
        self._results.append(result)

    def _next(self):
        """Pop the oldest queued result, raising it if it is an Exception.

        Returns None when the queue is empty.
        """
        if not self._results:
            return None
        result = self._results.pop(0)
        if isinstance(result, Exception):
            raise result
        return result

    async def execute(self, query: str, *args) -> str:
        """Record the call; return the queued result, or "OK" if none is queued."""
        self.calls.append(("execute", query, args))
        result = self._next()
        # Only an empty queue falls back to "OK"; falsy values that a test
        # queued on purpose (e.g. "" or 0) are returned unchanged.
        return "OK" if result is None else result

    async def fetch(self, query: str, *args) -> list[FakeRecord]:
        """Record the call; return the queued rows as a list.

        An empty queue yields ``[]``. A list is returned as-is, and any other
        single value is wrapped in a one-element list.
        """
        self.calls.append(("fetch", query, args))
        result = self._next()
        if result is None:
            return []
        return result if isinstance(result, list) else [result]

    async def fetchval(self, query: str, *args):
        """Record the call; return the queued value (None if the queue is empty)."""
        self.calls.append(("fetchval", query, args))
        return self._next()

    async def fetchrow(self, query: str, *args) -> FakeRecord | None:
        """Record the call; return the queued row (None if the queue is empty)."""
        self.calls.append(("fetchrow", query, args))
        return self._next()

    async def executemany(self, query: str, args_list: list) -> None:
        """Record the call only; the result queue is left untouched."""
        self.calls.append(("executemany", query, args_list))

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        # Returning None lets exceptions from the ``async with`` body propagate.
        return None


class FakePool:
    """In-memory asyncpg pool that always returns the same FakeConnection.

    Supports both acquisition styles:

    * ``async with pool.acquire() as conn: ...``
    * ``conn = await pool.acquire()`` followed by ``await pool.release(conn)``

    Also provides pool-level ``execute``/``executemany``/``fetch``/
    ``fetchrow``/``fetchval`` shortcuts that forward to the shared connection.
    """

    class _AcquireContext:
        """Returned by ``acquire()``; usable with ``async with`` or ``await``."""

        def __init__(self, pool: "FakePool") -> None:
            self._pool = pool

        async def _resolve(self):
            # Read the pool's connection lazily, at await/enter time.
            return self._pool._conn

        def __await__(self):
            return self._resolve().__await__()

        async def __aenter__(self):
            return self._pool._conn

        async def __aexit__(self, *exc):
            # Returning None lets exceptions from the body propagate.
            return None

    def __init__(self, conn: FakeConnection | None = None) -> None:
        self._conn = conn if conn is not None else FakeConnection()

    def acquire(self):
        """Return a context that yields the shared connection."""
        return self._AcquireContext(self)

    async def release(self, conn, *, timeout=None) -> None:
        """No-op counterpart to ``await pool.acquire()``; the connection is shared."""

    async def execute(self, query: str, *args):
        """Forward to the shared connection's ``execute``."""
        return await self._conn.execute(query, *args)

    async def executemany(self, query: str, args_list: list):
        """Forward to the shared connection's ``executemany``."""
        return await self._conn.executemany(query, args_list)

    async def fetch(self, query: str, *args):
        """Forward to the shared connection's ``fetch``."""
        return await self._conn.fetch(query, *args)

    async def fetchrow(self, query: str, *args):
        """Forward to the shared connection's ``fetchrow``."""
        return await self._conn.fetchrow(query, *args)

    async def fetchval(self, query: str, *args):
        """Forward to the shared connection's ``fetchval``."""
        return await self._conn.fetchval(query, *args)

    async def close(self):
        """No-op; there is nothing to tear down."""

    async def __aenter__(self):
        return self

    async def __aexit__(self, *exc):
        return None


def make_store_with_pool(conn: FakeConnection | None = None):
    """Return a MemoryStore backed by a FakePool around *conn*.

    The store is flagged as already initialized.
    # NOTE(review): presumably this skips MemoryStore's own setup queries —
    # confirm against MemoryStore.
    """
    from navi.memory.store import MemoryStore

    memory_store = MemoryStore(pool=FakePool(conn))
    memory_store._initialized = True
    return memory_store


def make_scheduler_with_pool(conn: FakeConnection | None = None):
    """Return a RecallScheduler backed by a FakePool around *conn*.

    The scheduler is flagged as already initialized.
    # NOTE(review): presumably this skips RecallScheduler's own setup —
    # confirm against RecallScheduler.
    """
    from navi.core.scheduler import RecallScheduler

    recall_scheduler = RecallScheduler(pool=FakePool(conn))
    recall_scheduler._initialized = True
    return recall_scheduler