Newer
Older
navi-1 / navi / api / deps.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 15 May 5 KB Add self-recall (scheduled callback) system
"""FastAPI dependency injection — provides shared singletons to route handlers."""

import asyncio
import logging
from typing import Annotated

from fastapi import Depends

from navi.config import settings
from navi.context_providers._loader import ContextProviderRegistry
from navi.core import (
    Agent,
    BackendRegistry,
    PgSessionStore,
    ProfileRegistry,
    SessionStore,
    ToolRegistry,
    build_default_registries,
)
from navi.core.scheduler import RecallScheduler
from navi.llm.ollama import OllamaBackend
from navi.auth.deps import (
    get_current_user,
    get_current_user_ws,
    require_admin,
    require_permission,
    require_user,
)
from navi.memory import MemoryStore
from navi.workers import Worker, build_default_workers
from navi.mcp import McpManager, load_mcp_servers
from navi.mcp.tools import McpTool


def _make_session_store() -> SessionStore:
    """Build the PostgreSQL-backed session store from ``settings.database_url``.

    Raises:
        RuntimeError: if no database URL is configured.
    """
    database_url = settings.database_url
    if database_url:
        return PgSessionStore(database_url)
    raise RuntimeError("DATABASE_URL is required. SQLite support has been removed.")


def _make_memory_store() -> MemoryStore:
    """Build the MemoryStore backed by ``settings.database_url``.

    Raises:
        RuntimeError: if no database URL is configured.
    """
    database_url = settings.database_url
    if database_url:
        return MemoryStore(database_url)
    raise RuntimeError("DATABASE_URL is required. SQLite support has been removed.")


# Lazily-created module-level singletons. Each starts as None and is filled in
# on first use by its accessor (get_memory_store, get_registries,
# get_mcp_manager, get_scheduler), then reused on later calls.
_scheduler: RecallScheduler | None = None
_mcp_manager: McpManager | None = None
_memory_store: MemoryStore | None = None
# (tools, profiles, backends, context providers) as built by get_registries().
_registries: tuple[ToolRegistry, ProfileRegistry, BackendRegistry, ContextProviderRegistry] | None = None


def get_memory_store() -> MemoryStore:
    """Return the shared MemoryStore, creating it on the first call."""
    global _memory_store
    if _memory_store is not None:
        return _memory_store
    _memory_store = _make_memory_store()
    return _memory_store


def get_scheduler() -> RecallScheduler:
    """Return the shared RecallScheduler, creating it on the first call.

    Raises:
        RuntimeError: if no database URL is configured when the scheduler
            has to be created.
    """
    global _scheduler
    if _scheduler is not None:
        return _scheduler
    database_url = settings.database_url
    if not database_url:
        raise RuntimeError("DATABASE_URL is required for RecallScheduler.")
    _scheduler = RecallScheduler(database_url)
    return _scheduler


def _build_embedding_backend(backend_registry: BackendRegistry):
    """Pick the backend used to embed memories for vector search.

    Uses a dedicated Ollama endpoint when ``settings.embedding_ollama_host`` is
    set; otherwise falls back to the ``"ollama"`` entry of *backend_registry*
    (the main chat backend).
    """
    if settings.embedding_ollama_host:
        return OllamaBackend(
            model=settings.embedding_model,
            host=settings.embedding_ollama_host,
            api_key=settings.embedding_ollama_api_key,
            timeout=settings.ollama_request_timeout,
        )
    return backend_registry.get("ollama")


def get_registries() -> tuple[ToolRegistry, ProfileRegistry, BackendRegistry, ContextProviderRegistry]:
    """Return the shared (tools, profiles, backends, context providers) tuple.

    On the first call the registries are built via ``build_default_registries``
    using the shared memory store, session store and scheduler. The memory
    store is then wired to an embedding backend for vector search. That wiring
    is best-effort: a failure is logged and the registries are still returned.
    """
    global _registries
    if _registries is not None:
        return _registries

    memory_store = get_memory_store()
    registries = build_default_registries(
        memory_store=memory_store,
        session_store=get_session_store(),
        scheduler=get_scheduler(),
    )
    # Best-effort: a misconfigured or unavailable embedding backend must not
    # prevent the registries from being served, but the failure is logged
    # (with traceback) rather than silently discarded.
    try:
        if hasattr(memory_store, "set_embedding_backend"):
            memory_store.set_embedding_backend(_build_embedding_backend(registries[2]))
    except Exception:
        logging.getLogger(__name__).warning(
            "Could not wire embedding backend into memory store",
            exc_info=True,
        )
    # Publish only once fully set up, so no caller sees half-initialised state.
    _registries = registries
    return _registries


def get_tool_registry() -> ToolRegistry:
    """FastAPI dependency: the shared ToolRegistry (first registries slot)."""
    tools, *_ = get_registries()
    return tools


def get_profile_registry() -> ProfileRegistry:
    """FastAPI dependency: the shared ProfileRegistry (second registries slot)."""
    _, profiles, *_ = get_registries()
    return profiles


def get_backend_registry() -> BackendRegistry:
    """FastAPI dependency: the shared BackendRegistry (third registries slot)."""
    _, _, backends, *_ = get_registries()
    return backends


def get_cp_registry() -> ContextProviderRegistry:
    """FastAPI dependency: the shared ContextProviderRegistry (fourth registries slot)."""
    _, _, _, providers, *_ = get_registries()
    return providers


# Serialises first-time creation of the MCP manager across concurrent coroutines.
_mcp_manager_lock = asyncio.Lock()


async def get_mcp_manager() -> McpManager:
    """Return the shared McpManager, creating and loading it on first use.

    The manager is published to the module global only after ``load_all()``
    completes. As a result:

    * concurrent callers never receive a manager whose servers are still
      loading (the original code assigned the global before awaiting);
    * if ``load_all()`` raises, nothing is cached and the next call retries
      instead of returning a half-initialised manager forever.
    """
    global _mcp_manager
    if _mcp_manager is not None:
        return _mcp_manager
    async with _mcp_manager_lock:
        # Re-check: another coroutine may have finished loading while we waited.
        if _mcp_manager is None:
            manager = McpManager()
            await manager.load_all()
            _mcp_manager = manager
    return _mcp_manager


async def register_mcp_tools(registry: ToolRegistry, manager: McpManager) -> None:
    """Re-sync MCP-provided tools into *registry* as external tools.

    Every external tool whose name starts with ``"mcp:"`` is unregistered
    first. Then each ``(server_name, tool)`` pair reported by
    ``manager.get_all_tools()`` is wrapped in an ``McpTool`` and registered.
    """
    # Take a snapshot of stale names before unregistering; unregister_external
    # presumably mutates the collection being scanned.
    stale_names = [
        name for name in registry._external_names if name.startswith("mcp:")
    ]
    for stale in stale_names:
        registry.unregister_external(stale)

    for server_name, tool in await manager.get_all_tools():
        registry.register_external(
            McpTool(
                server_name=server_name,
                tool_name=tool.name,
                description=tool.description or "",
                parameters=tool.inputSchema,
                manager=manager,
            )
        )


# Lazily-created singletons, filled in on the first call to
# get_workers() / get_session_store() respectively.
_workers: list[Worker] | None = None
_session_store: SessionStore | None = None


def get_session_store() -> SessionStore:
    """Return the shared SessionStore, creating it on the first call."""
    global _session_store
    if _session_store is not None:
        return _session_store
    _session_store = _make_session_store()
    return _session_store


def get_workers() -> list[Worker]:
    """Return the shared worker list, building the defaults on the first call."""
    global _workers
    if _workers is not None:
        return _workers
    _workers = build_default_workers()
    return _workers


def get_agent(
    session_store: Annotated[SessionStore, Depends(get_session_store)],
    profile_registry: Annotated[ProfileRegistry, Depends(get_profile_registry)],
    tool_registry: Annotated[ToolRegistry, Depends(get_tool_registry)],
    backend_registry: Annotated[BackendRegistry, Depends(get_backend_registry)],
    cp_registry: Annotated[ContextProviderRegistry, Depends(get_cp_registry)],
) -> Agent:
    """FastAPI dependency: construct an Agent wired to the shared singletons.

    A new Agent is created on every call. The session store and registries
    come from the injected dependencies. Workers and the memory store are
    fetched from their module accessors. The MCP manager is read directly
    from the module global, so it may be ``None`` if ``get_mcp_manager()``
    has not been awaited yet.
    """
    workers = get_workers()
    memory_store = get_memory_store()
    return Agent(
        session_store,
        profile_registry,
        tool_registry,
        backend_registry,
        workers=workers,
        memory_store=memory_store,
        cp_registry=cp_registry,
        mcp_manager=_mcp_manager,
    )