diff --git a/navi/context_providers/_loader.py b/navi/context_providers/_loader.py index e5e67c1..03f53b3 100644 --- a/navi/context_providers/_loader.py +++ b/navi/context_providers/_loader.py @@ -45,6 +45,9 @@ def all_names(self) -> list[str]: return list(self._providers) + def all(self) -> list[Any]: + return list(self._providers.values()) + def get_globals(self) -> list[Any]: return [p for p in self._providers.values() if getattr(p, "global_provider", False)] diff --git a/tests/unit/core/test_registry.py b/tests/unit/core/test_registry.py index 9737488..0e65b52 100644 --- a/tests/unit/core/test_registry.py +++ b/tests/unit/core/test_registry.py @@ -3,6 +3,7 @@ import pytest from navi.core.registry import BackendRegistry, ProfileRegistry, ToolRegistry +from navi.context_providers import ContextProviderRegistry from navi.exceptions import ProfileNotFound, ToolNotFound from tests.conftest_factory import FakeLLMBackend, FakeTool, make_profile @@ -77,3 +78,19 @@ reg.register("a", FakeLLMBackend()) reg.register("b", FakeLLMBackend()) assert sorted(reg.all_keys()) == ["a", "b"] + + +class TestContextProviderRegistry: + def test_all_returns_registered_providers(self): + reg = ContextProviderRegistry() + + class Provider: + name = "provider" + description = "Provider" + + async def get_context(self): + return None + + reg.register(Provider) + + assert reg.all() == [Provider]