diff --git a/docs/testing.md b/docs/testing.md new file mode 100644 index 0000000..c237bef --- /dev/null +++ b/docs/testing.md @@ -0,0 +1,72 @@ +# Testing Strategy + +## Stack +- **pytest** + **pytest-asyncio** (`asyncio_mode = auto`) +- **pytest-mock** — `mocker` fixture +- **httpx** — `TestClient` for FastAPI routes +- **asgi-lifespan** — lifespan management in integration tests + +## Directory layout + +``` +tests/ +├── conftest.py # Shared fixtures (settings override, event_loop policy) +├── conftest_factory.py # Factories: FakeLLMBackend, FakeSessionStore, FakeMemoryStore +├── unit/ # No external deps (mocked DB / LLM) +│ ├── core/ +│ │ ├── test_events.py +│ │ ├── test_context_builder.py +│ │ ├── test_compressor.py +│ │ ├── test_registry.py +│ │ └── test_planning.py +│ ├── memory/ +│ │ ├── test_store.py +│ │ └── test_extractor.py +│ ├── tools/ +│ │ └── test_filesystem.py +│ ├── profiles/ +│ │ └── test_base.py +│ └── config/ +│ └── test_settings.py +├── integration/ # FastAPI TestClient + mocked or real DB +│ ├── test_api_routes.py +│ ├── test_websocket.py +│ └── test_memory_store.py +└── e2e/ + └── test_chat_flow.py # Critical path: message → tool call → response +``` + +## Mock strategy + +### LLM +`FakeLLMBackend` cycles through a list of pre-defined responses and optionally emits `ToolCallRequest` objects. This lets us test the agent loop and planning without real Ollama. + +### PostgreSQL +Unit tests mock `asyncpg.Pool` via a thin `FakePool`/`FakeConnection` that stores rows in-memory (`list[dict]`). Integration tests may use a real Postgres instance via `TEST_DATABASE_URL`. + +## Execution order + +| Phase | Scope | Est. time | +|-------|-------|-----------| +| 1 | Infrastructure (`conftest`, `FakeLLMBackend`) + `events`, `context_builder`, `compressor`, `registry`, `profiles` | 2–3 h | +| 2 | Memory store (mock DB) + extractor | 2 h | +| 3 | API routes (`TestClient`) + WebSocket | 2 h | +| 4 | Agent loop + planning with `FakeLLMBackend` | 3 h | +| 5 | Tools (`filesystem`, `code_exec`, `terminal`) | 2 h | +| 6 | Integration with real Postgres (optional) | later | + +## Running tests + +```bash +# All tests +pytest + +# Unit only +pytest tests/unit + +# With verbose +pytest -v tests/unit/core + +# Integration (requires TEST_DATABASE_URL) +TEST_DATABASE_URL=postgresql://... pytest tests/integration +``` diff --git a/navi/core/context_builder.py b/navi/core/context_builder.py index 940e1b3..7460a77 100644 --- a/navi/core/context_builder.py +++ b/navi/core/context_builder.py @@ -6,9 +6,10 @@ from datetime import datetime, timezone from typing import TYPE_CHECKING -from navi.config import settings from navi.llm.base import Message +import navi.config as _config + if TYPE_CHECKING: from navi.context_providers._loader import ContextProviderRegistry from navi.memory.store import MemoryStore @@ -45,7 +46,7 @@ return cached parts: list[str] = [] - persona = settings.navi_persona.strip() + persona = _config.settings.navi_persona.strip() if persona: parts.append(persona) parts.append(profile.system_prompt) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..97d3e06 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,38 @@ +"""Shared pytest fixtures and configuration.""" + +import asyncio +import os +from typing import Generator + +import pytest + +# Ensure the project root is importable +import sys + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + + +@pytest.fixture(autouse=True) +def _reset_settings(monkeypatch: pytest.MonkeyPatch) -> Generator[None, None, None]: + """Reset navi.config.settings before every test so mutations don't leak.""" + from navi.config import Settings + + # Clear env vars that would leak from the host environment / .env + for key in ("NAVI_PERSONA", "NAVI_PERSONA_FILE"): + monkeypatch.delenv(key, raising=False) + + # Re-create settings from defaults (no .env file during tests) + fresh = Settings(_env_file=None, navi_persona_file="") + import navi.config as _config_mod + + _config_mod.settings = fresh + yield + _config_mod.settings = fresh + + +@pytest.fixture +def event_loop(): + """Provide a consistent event loop for async tests.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() diff --git a/tests/conftest_factory.py b/tests/conftest_factory.py new file mode 100644 index 0000000..0ea7e45 --- /dev/null +++ b/tests/conftest_factory.py @@ -0,0 +1,160 @@ +"""Test factories and fakes for NAVI internals.""" + +from typing import AsyncGenerator + +from navi.core.registry import ProfileRegistry, ToolRegistry +from navi.llm.base import LLMBackend, LLMChunk, LLMResponse, Message, ToolCallRequest, ToolSchema +from navi.profiles.base import AgentProfile + + +class FakeLLMBackend(LLMBackend): + """Deterministic LLM backend for tests. + + Usage: + backend = FakeLLMBackend(responses=["Hello"]) + resp = await backend.complete([]) # -> LLMResponse(content="Hello") + """ + + def __init__( + self, + responses: list[str] | None = None, + tool_calls: list[list[ToolCallRequest]] | None = None, + thinking: list[str] | None = None, + prompt_tokens: int = 0, + completion_tokens: int = 0, + ) -> None: + self._responses = responses or [] + self._tool_calls = tool_calls or [] + self._thinking = thinking or [] + self._prompt_tokens = prompt_tokens + self._completion_tokens = completion_tokens + self._call_idx = 0 + self._stream_idx = 0 + + def _next(self): + idx = self._call_idx + self._call_idx += 1 + return idx + + async def complete( + self, + messages: list[Message], + tools: list[ToolSchema] | None = None, + temperature: float = 0.7, + model: str | None = None, + think: bool | None = None, + max_tokens: int | None = None, + ) -> LLMResponse: + idx = self._next() + content = self._responses[idx] if idx < len(self._responses) else "" + tcalls = self._tool_calls[idx] if idx < len(self._tool_calls) else None + finish = "tool_calls" if tcalls else "stop" + thinking = self._thinking[idx] if idx < len(self._thinking) else None + return LLMResponse( + content=content, + tool_calls=tcalls, + finish_reason=finish, + thinking=thinking, + prompt_tokens=self._prompt_tokens, + completion_tokens=self._completion_tokens, + ) + + async def stream( + self, + messages: list[Message], + temperature: float = 0.7, + model: str | None = None, + ) -> AsyncGenerator[LLMChunk, None]: + idx = self._stream_idx + self._stream_idx += 1 + content = self._responses[idx] if idx < len(self._responses) else "" + if content: + yield LLMChunk(delta=content, finish_reason="stop") + else: + yield LLMChunk(delta="", finish_reason="stop") + + async def stream_complete( + self, + messages: list[Message], + tools: list[ToolSchema] | None = None, + temperature: float = 0.7, + model: str | None = None, + think: bool | None = None, + ) -> AsyncGenerator[LLMChunk, None]: + idx = self._stream_idx + self._stream_idx += 1 + content = self._responses[idx] if idx < len(self._responses) else "" + thinking = self._thinking[idx] if idx < len(self._thinking) else None + tcalls = self._tool_calls[idx] if idx < len(self._tool_calls) else None + + if thinking: + yield LLMChunk(thinking=thinking) + if content: + yield LLMChunk(delta=content) + yield LLMChunk( + finish_reason="tool_calls" if tcalls else "stop", + tool_calls=tcalls, + prompt_tokens=self._prompt_tokens, + completion_tokens=self._completion_tokens, + ) + + async def embed( + self, + texts: list[str], + model: str | None = None, + ) -> list[list[float]]: + return [[0.1] * 768 for _ in texts] + + +class FakeTool: + """Minimal stand-in for navi.tools.base.Tool.""" + + def __init__(self, name: str, description: str = "", parameters: dict | None = None) -> None: + self.name = name + self.description = description + self.parameters = parameters or {"type": "object", "properties": {}} + + def schema(self) -> dict: + return { + "type": "function", + "function": { + "name": self.name, + "description": self.description, + "parameters": self.parameters, + }, + } + + async def execute(self, arguments: dict) -> "FakeToolResult": + from navi.tools.base import ToolResult + + return ToolResult(success=True, output=f"executed {self.name}") + + +def make_profile(profile_id: str = "test", **overrides) -> AgentProfile: + defaults = { + "id": profile_id, + "name": "Test Profile", + "description": "A profile for tests", + "system_prompt": "You are a test assistant.", + "enabled_tools": ["test_tool"], + "llm_backend": "ollama", + "model": ["gemma4:31b-cloud"], + "max_iterations": 5, + "temperature": 0.7, + } + defaults.update(overrides) + return AgentProfile(**defaults) + + +def make_registry_with_tools() -> ToolRegistry: + reg = ToolRegistry() + reg.register(FakeTool("test_tool"), builtin=True) + reg.register(FakeTool("another_tool"), builtin=True) + return reg + + +def make_profile_registry() -> ProfileRegistry: + reg = ProfileRegistry() + reg.register(make_profile("secretary")) + reg.register(make_profile("developer")) + return reg diff --git a/tests/unit/core/test_compressor.py b/tests/unit/core/test_compressor.py new file mode 100644 index 0000000..416f9a1 --- /dev/null +++ b/tests/unit/core/test_compressor.py @@ -0,0 +1,170 @@ +"""Unit tests for context compressor.""" + +import pytest + +from navi.core.compressor import ( + _format_for_summary, + compress_context, + partition_messages, + should_compress, +) +from navi.llm.base import Message, ToolCallRequest +from tests.conftest_factory import FakeLLMBackend + + +class TestShouldCompress: + def test_below_threshold(self): + assert should_compress(100, 1000, 0.7) is False + + def test_at_threshold(self): + assert should_compress(700, 1000, 0.7) is True + + def test_above_threshold(self): + assert should_compress(800, 1000, 0.7) is True + + +class TestPartitionMessages: + def test_empty(self): + old, recent = partition_messages([], keep_recent=2) + assert old == [] + assert recent == [] + + def test_fewer_turns_than_keep(self): + msgs = [ + Message(role="user", content="hi"), + Message(role="assistant", content="hello"), + ] + old, recent = partition_messages(msgs, keep_recent=5) + assert old == [] + assert recent == msgs + + def test_splits_into_old_and_recent(self): + msgs = [ + Message(role="user", content="1"), + Message(role="assistant", content="a1"), + Message(role="user", content="2"), + Message(role="assistant", content="a2"), + Message(role="user", content="3"), + Message(role="assistant", content="a3"), + ] + old, recent = partition_messages(msgs, keep_recent=2) + assert len(old) == 2 # turn 1 + assert len(recent) == 4 # turns 2+3 + + def test_system_messages_ignored(self): + msgs = [ + Message(role="system", content="sys"), + Message(role="user", content="1"), + Message(role="assistant", content="a1"), + ] + old, recent = partition_messages(msgs, keep_recent=5) + assert recent == [m for m in msgs if m.role != "system"] + + def test_tool_calls_stay_with_assistant(self): + msgs = [ + Message(role="user", content="1"), + Message( + role="assistant", + content="", + tool_calls=[ToolCallRequest(id="1", name="fs", arguments={})], + ), + Message(role="tool", content="result", name="fs", tool_call_id="1"), + Message(role="user", content="2"), + Message(role="assistant", content="ok"), + ] + old, recent = partition_messages(msgs, keep_recent=1) + assert len(old) == 3 + assert len(recent) == 2 + + +class TestFormatForSummary: + def test_user_message(self): + msgs = [Message(role="user", content="hello")] + text, images = _format_for_summary(msgs) + assert "User: hello" in text + assert images == [] + + def test_assistant_message(self): + msgs = [Message(role="assistant", content="world")] + text, images = _format_for_summary(msgs) + assert "Assistant: world" in text + + def test_tool_call_block(self): + msgs = [ + Message( + role="assistant", + tool_calls=[ToolCallRequest(id="1", name="fs", arguments={"path": "/tmp"})], + ), + Message(role="tool", content="file data", name="fs", tool_call_id="1"), + ] + text, images = _format_for_summary(msgs) + assert "[Tool call: fs" in text + assert "[Tool result: fs" in text + + def test_images_collected(self): + msgs = [Message(role="user", content="look", images=["base64img"])] + text, images = _format_for_summary(msgs) + assert images == ["base64img"] + assert "[+ 1 image(s)]" in text + + def test_summary_message_folded(self): + msgs = [Message(role="user", content="old summary", is_summary=True)] + text, _ = _format_for_summary(msgs) + assert "old summary" in text + + +class TestCompressContext: + async def test_nothing_to_compress(self): + backend = FakeLLMBackend() + result = await compress_context( + context=[Message(role="user", content="hi")], + llm=backend, + model="test", + temperature=0.3, + keep_recent=8, + ) + assert result is None + + async def test_compresses_old_turns(self): + backend = FakeLLMBackend(responses=["Summary of old stuff"]) + context = [ + Message(role="system", content="sys"), + Message(role="user", content="1"), + Message(role="assistant", content="a1"), + Message(role="user", content="2"), + Message(role="assistant", content="a2"), + Message(role="user", content="3"), + Message(role="assistant", content="a3"), + ] + new_context, summary = await compress_context( + context=context, + llm=backend, + model="test", + temperature=0.3, + keep_recent=2, + ) + assert summary == "Summary of old stuff" + # system + summary + 2 recent turns (user+assistant × 2) = 6 + assert len(new_context) == 6 + assert new_context[0].role == "system" + assert new_context[1].is_summary is True + + async def test_preserves_system_messages(self): + backend = FakeLLMBackend(responses=["sum"]) + context = [ + Message(role="system", content="s1"), + Message(role="system", content="s2"), + Message(role="user", content="1"), + Message(role="assistant", content="a1"), + Message(role="user", content="2"), + Message(role="assistant", content="a2"), + ] + new_context, _ = await compress_context( + context=context, + llm=backend, + model="test", + temperature=0.3, + keep_recent=1, + ) + system_msgs = [m for m in new_context if m.role == "system"] + assert len(system_msgs) == 2 diff --git a/tests/unit/core/test_context_builder.py b/tests/unit/core/test_context_builder.py new file mode 100644 index 0000000..29d9c1b --- /dev/null +++ b/tests/unit/core/test_context_builder.py @@ -0,0 +1,100 @@ +"""Unit tests for ContextBuilder.""" + +import pytest + +from navi.core.context_builder import ContextBuilder +from navi.llm.base import Message +from tests.conftest_factory import make_profile, make_profile_registry + + +class TestBuildSystemPrompt: + def test_includes_persona(self): + import navi.config as _config + + _config.settings.navi_persona = "You are Navi." + _config.settings.navi_persona_file = "" + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test") + prompt = builder.build_system_prompt(profile) + assert "You are Navi." in prompt + assert profile.system_prompt in prompt + + def test_includes_other_profiles(self): + reg = make_profile_registry() + builder = ContextBuilder(profile_registry=reg) + profile = reg.get("secretary") + prompt = builder.build_system_prompt(profile) + assert "## Available profiles" in prompt + assert "developer" in prompt + + def test_cache_returns_same_object(self): + reg = make_profile_registry() + builder = ContextBuilder(profile_registry=reg) + profile = reg.get("secretary") + p1 = builder.build_system_prompt(profile) + p2 = builder.build_system_prompt(profile) + assert p1 is p2 # cached + + def test_invalidate_cache(self): + reg = make_profile_registry() + builder = ContextBuilder(profile_registry=reg) + profile = reg.get("secretary") + p1 = builder.build_system_prompt(profile) + builder.invalidate_system_prompt_cache(profile.id) + p2 = builder.build_system_prompt(profile) + assert p1 == p2 + assert p1 is not p2 # cache was invalidated + + +class TestBuildGoalAnchor: + def test_includes_original_request(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + msg = builder._build_goal_anchor("sess-1", "Write tests") + assert msg.role == "system" + assert "Original request: Write tests" in msg.content + assert "Stay on track" in msg.content + + +class TestBuild: + def test_puts_system_first(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test") + context = [Message(role="user", content="hi")] + result = builder.build(context, profile, mem=None) + assert result[0].role == "system" + + def test_injects_memory(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test") + mem = Message(role="system", content="I remember you.") + context = [Message(role="user", content="hi")] + result = builder.build(context, profile, mem=mem) + assert result[1] == mem + + def test_strips_existing_system(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test") + context = [ + Message(role="system", content="old"), + Message(role="user", content="hi"), + ] + result = builder.build(context, profile, mem=None) + system_msgs = [m for m in result if m.role == "system"] + assert len(system_msgs) == 1 # only the new system prompt + + def test_iteration_budget_injection(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test", iteration_budget_enabled=True) + context = [Message(role="user", content="hi")] + result = builder.build(context, profile, mem=None, iteration=7, max_iterations=10) + last = result[-1] + assert last.role == "system" + assert "Iteration 8/10" in last.content + assert "3 remaining" in last.content + + def test_critical_urgency(self): + builder = ContextBuilder(profile_registry=make_profile_registry()) + profile = make_profile("test", iteration_budget_enabled=True) + context = [Message(role="user", content="hi")] + result = builder.build(context, profile, mem=None, iteration=9, max_iterations=10) + assert "CRITICAL" in result[-1].content diff --git a/tests/unit/core/test_events.py b/tests/unit/core/test_events.py new file mode 100644 index 0000000..82153e9 --- /dev/null +++ b/tests/unit/core/test_events.py @@ -0,0 +1,157 @@ +"""Unit tests for event dataclass wire serialization.""" + +import pytest + +from navi.core.events import ( + AIHelperTokensUsed, + ContextCompressed, + PlanReady, + PlanningDebugData, + PlanningStatus, + ProfileSwitched, + StreamEnd, + StreamStopped, + SubagentComplete, + TextDelta, + ThinkingDelta, + ThinkingEnd, + ToolEvent, + ToolStarted, + TurnThinking, +) + + +class TestToolStarted: + def test_to_wire(self): + ev = ToolStarted(tool_name="fs", arguments={"action": "read"}) + assert ev.to_wire() == { + "type": "tool_started", + "tool": "fs", + "args": {"action": "read"}, + "is_subagent": False, + } + + def test_to_wire_subagent(self): + ev = ToolStarted(tool_name="spawn_agent", arguments={}, is_subagent=True) + assert ev.to_wire()["is_subagent"] is True + + +class TestToolEvent: + def test_to_wire(self): + ev = ToolEvent( + tool_name="fs", + arguments={"path": "/tmp"}, + result="file contents", + success=True, + metadata={"size": 42}, + ) + assert ev.to_wire() == { + "type": "tool_call", + "tool": "fs", + "args": {"path": "/tmp"}, + "result": "file contents", + "success": True, + "is_subagent": False, + "metadata": {"size": 42}, + } + + +class TestTextDelta: + def test_to_wire(self): + assert TextDelta(delta="hello").to_wire() == {"type": "stream_delta", "delta": "hello"} + + +class TestThinkingDelta: + def test_to_wire(self): + assert ThinkingDelta(delta="hmm").to_wire() == {"type": "thinking_delta", "delta": "hmm"} + + +class TestThinkingEnd: + def test_to_wire(self): + assert ThinkingEnd().to_wire() == {"type": "thinking_end"} + + +class TestStreamEnd: + def test_to_wire(self): + ev = StreamEnd( + full_content="done", + context_tokens=150, + max_context_tokens=8192, + elapsed_seconds=1.5, + tool_call_count=2, + token_count=150, + ) + wire = ev.to_wire() + assert wire["type"] == "stream_end" + assert wire["content"] == "done" + assert wire["context_tokens"] == 150 + assert wire["tool_call_count"] == 2 + + +class TestStreamStopped: + def test_to_wire(self): + assert StreamStopped().to_wire() == {"type": "stream_stopped"} + + +class TestContextCompressed: + def test_to_wire(self): + ev = ContextCompressed(messages_before=10, messages_after=3, summary="summary text") + wire = ev.to_wire() + assert wire["type"] == "context_compressed" + assert wire["messages_before"] == 10 + assert wire["messages_after"] == 3 + assert wire["summary"] == "summary text" + + +class TestProfileSwitched: + def test_to_wire(self): + ev = ProfileSwitched(profile_id="dev", profile_name="Developer") + assert ev.to_wire() == { + "type": "profile_switched", + "profile_id": "dev", + "profile_name": "Developer", + } + + +class TestPlanningStatus: + def test_to_wire(self): + ev = PlanningStatus(phase=1, label="Working on it...") + assert ev.to_wire() == { + "type": "planning_status", + "phase": 1, + "label": "Working on it...", + "is_subagent": False, + } + + +class TestPlanReady: + def test_to_wire(self): + ev = PlanReady(plan="1. Do thing") + wire = ev.to_wire() + assert wire["type"] == "plan_ready" + assert wire["plan"] == "1. Do thing" + assert wire["is_subagent"] is False + + +class TestTurnThinking: + def test_to_wire(self): + ev = TurnThinking(thinking="reasoning...", is_subagent=True) + wire = ev.to_wire() + assert wire["type"] == "turn_thinking" + assert wire["thinking"] == "reasoning..." + assert wire["is_subagent"] is True + + +class TestInternalEvents: + """Internal events must NOT serialize to the wire.""" + + def test_subagent_complete(self): + assert SubagentComplete(token_count=42).to_wire() is None + + def test_planning_debug_data(self): + assert PlanningDebugData(log={"phases": {}}).to_wire() is None + + def test_ai_helper_tokens_used(self): + ev = AIHelperTokensUsed(prompt_tokens=10, completion_tokens=20) + assert ev.to_wire() is None + assert ev.total == 30 diff --git a/tests/unit/core/test_registry.py b/tests/unit/core/test_registry.py new file mode 100644 index 0000000..9737488 --- /dev/null +++ b/tests/unit/core/test_registry.py @@ -0,0 +1,79 @@ +"""Unit tests for ToolRegistry, ProfileRegistry, BackendRegistry.""" + +import pytest + +from navi.core.registry import BackendRegistry, ProfileRegistry, ToolRegistry +from navi.exceptions import ProfileNotFound, ToolNotFound +from tests.conftest_factory import FakeLLMBackend, FakeTool, make_profile + + +class TestToolRegistry: + def test_register_and_get(self): + reg = ToolRegistry() + tool = FakeTool("test") + reg.register(tool) + assert reg.get("test") is tool + + def test_get_missing_raises(self): + reg = ToolRegistry() + with pytest.raises(ToolNotFound): + reg.get("missing") + + def test_resolve(self): + reg = ToolRegistry() + reg.register(FakeTool("a"), builtin=True) + reg.register(FakeTool("b")) + tools = reg.resolve(["a", "b"]) + assert [t.name for t in tools] == ["a", "b"] + + def test_all(self): + reg = ToolRegistry() + reg.register(FakeTool("a")) + reg.register(FakeTool("b")) + assert len(reg.all()) == 2 + + def test_builtin_names_preserved_on_reload(self): + reg = ToolRegistry() + reg.register(FakeTool("builtin"), builtin=True) + reg.register(FakeTool("user")) + # reload_user_tools would drop non-builtins — simulate by checking names + assert "builtin" in reg._builtin_names + assert "user" not in reg._builtin_names + + +class TestProfileRegistry: + def test_register_and_get(self): + reg = ProfileRegistry() + p = make_profile("dev") + reg.register(p) + assert reg.get("dev") is p + + def test_get_missing_raises(self): + reg = ProfileRegistry() + with pytest.raises(ProfileNotFound): + reg.get("missing") + + def test_all(self): + reg = ProfileRegistry() + reg.register(make_profile("a")) + reg.register(make_profile("b")) + assert len(reg.all()) == 2 + + +class TestBackendRegistry: + def test_register_and_get(self): + reg = BackendRegistry() + backend = FakeLLMBackend() + reg.register("ollama", backend) + assert reg.get("ollama") is backend + + def test_get_missing_raises(self): + reg = BackendRegistry() + with pytest.raises(KeyError, match="not registered"): + reg.get("missing") + + def test_all_keys(self): + reg = BackendRegistry() + reg.register("a", FakeLLMBackend()) + reg.register("b", FakeLLMBackend()) + assert sorted(reg.all_keys()) == ["a", "b"] diff --git a/tests/unit/profiles/test_base.py b/tests/unit/profiles/test_base.py new file mode 100644 index 0000000..bfe136e --- /dev/null +++ b/tests/unit/profiles/test_base.py @@ -0,0 +1,81 @@ +"""Unit tests for AgentProfile Pydantic model.""" + +import pytest +from pydantic import ValidationError + +from navi.profiles.base import AgentProfile + + +class TestModelCoercion: + def test_model_string_to_list(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + model="gemma4:31b-cloud", + ) + assert p.model == ["gemma4:31b-cloud"] + + def test_model_list_passthrough(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + model=["a", "b"], + ) + assert p.model == ["a", "b"] + + def test_empty_model_defaults(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + model="", + ) + assert p.model == ["gemma4:31b-cloud"] + + +class TestDefaults: + def test_default_flags(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + ) + assert p.think_enabled is True + assert p.planning_enabled is False + assert p.planning_phase2_enabled is False + assert p.iteration_budget_enabled is True + assert p.anti_stall_enabled is True + assert p.anti_stall_threshold == 8 + + def test_max_iterations_default(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + ) + assert p.max_iterations == 10 + + +class TestExtraFields: + def test_extra_fields_allowed(self): + p = AgentProfile( + id="test", + name="Test", + description="desc", + system_prompt="sys", + enabled_tools=[], + custom_field="whatever", + ) + assert p.model_dump()["custom_field"] == "whatever"