diff --git a/docs/memory.md b/docs/memory.md index 34477f5..1dcc021 100644 --- a/docs/memory.md +++ b/docs/memory.md @@ -6,12 +6,18 @@ The memory system requires **PostgreSQL** with two extensions: -| Extension | Purpose | -|---|---| -| `vector` (pgvector) | Semantic search via cosine distance on `embedding vector(768)` | -| `pg_trgm` | Fast ILIKE fallback via GIN trigram indexes on `category`, `key`, `value` | +| Extension | Purpose | Auto-created by app? | +|---|---|---| +| `vector` (pgvector) | Semantic search via cosine distance on `embedding vector(768)` | Yes (`CREATE EXTENSION IF NOT EXISTS`) | +| `pg_trgm` | Fast ILIKE fallback via GIN trigram indexes on `category`, `key`, `value` | **No — must be installed by DBA** | -Both extensions are created automatically on startup (`CREATE EXTENSION IF NOT EXISTS`). +**pgvector** is created automatically because the app typically runs with sufficient privileges on its own database. **pg_trgm** is a contrib module shipped with PostgreSQL whose installation may require elevated privileges: superuser on PostgreSQL 12 and older, or `CREATE` privilege on the database on PostgreSQL 13+ (where it is a trusted extension). If it is already installed, GIN trigram indexes are created automatically; if not, the app falls back to plain ILIKE without indexes (functional but slower on large tables). 
+ +To install pg_trgm manually: + +```sql +CREATE EXTENSION IF NOT EXISTS pg_trgm; +``` | Feature | SQLite | PostgreSQL | |---|---|---| diff --git a/navi/api/websocket.py b/navi/api/websocket.py index 6410ca9..f206338 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -24,8 +24,9 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect from navi.api.deps import get_session_store -from navi.core import Agent, ContextCompressed, StreamEnd, TextDelta, ThinkingDelta, ThinkingEnd, ToolEvent -from navi.core.events import PlanningStatus, PlanReady, ProfileSwitched, StreamStopped, ToolStarted, TurnThinking +from navi.core import Agent +from navi.core.event_bus import get_event_bus +from navi.core.events import AgentEvent from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound router = APIRouter(tags=["websocket"]) @@ -78,56 +79,8 @@ # ── Helpers ─────────────────────────────────────────────────────────────────── def _event_to_dict(event) -> dict | None: - if isinstance(event, ThinkingDelta): - return {"type": "thinking_delta", "delta": event.delta} - if isinstance(event, ThinkingEnd): - return {"type": "thinking_end"} - if isinstance(event, TextDelta): - return {"type": "stream_delta", "delta": event.delta} - if isinstance(event, ToolStarted): - return { - "type": "tool_started", - "tool": event.tool_name, - "args": event.arguments, - "is_subagent": event.is_subagent, - } - if isinstance(event, ToolEvent): - return { - "type": "tool_call", - "tool": event.tool_name, - "args": event.arguments, - "result": event.result, - "success": event.success, - "is_subagent": event.is_subagent, - "metadata": event.metadata, - } - if isinstance(event, StreamEnd): - return { - "type": "stream_end", - "content": event.full_content, - "context_tokens": event.context_tokens, - "max_context_tokens": event.max_context_tokens, - "elapsed_seconds": event.elapsed_seconds, - "tool_call_count": event.tool_call_count, - "token_count": event.token_count, - 
} - if isinstance(event, ContextCompressed): - return { - "type": "context_compressed", - "messages_before": event.messages_before, - "messages_after": event.messages_after, - "summary": event.summary, - } - if isinstance(event, TurnThinking): - return {"type": "turn_thinking", "thinking": event.thinking, "is_subagent": event.is_subagent} - if isinstance(event, ProfileSwitched): - return {"type": "profile_switched", "profile_id": event.profile_id, "profile_name": event.profile_name} - if isinstance(event, StreamStopped): - return {"type": "stream_stopped"} - if isinstance(event, PlanningStatus): - return {"type": "planning_status", "phase": event.phase, "label": event.label, "is_subagent": event.is_subagent} - if isinstance(event, PlanReady): - return {"type": "plan_ready", "plan": event.plan, "is_subagent": event.is_subagent} + if hasattr(event, "to_wire"): + return event.to_wire() return None @@ -150,6 +103,7 @@ async for event in agent.run_stream( session_id, user_content, images=raw_images, display_message=display_content ): + await get_event_bus().publish(event) await run.broadcast(("event", event)) except asyncio.CancelledError: log.info("ws.agent_stopped", session_id=session_id) diff --git a/navi/config.py b/navi/config.py index 6657258..29ead09 100644 --- a/navi/config.py +++ b/navi/config.py @@ -22,6 +22,8 @@ embedding_dimensions: int = 768 openai_api_key: str = "" + openai_model: str = "gpt-4" + openai_base_url: str | None = None anthropic_api_key: str = "" # Web search fallbacks (used when DuckDuckGo returns no results) diff --git a/navi/core/agent.py b/navi/core/agent.py index 5ad6b52..0e59591 100644 --- a/navi/core/agent.py +++ b/navi/core/agent.py @@ -1581,6 +1581,7 @@ self, tool_calls: list[ToolCallRequest], tools: list[Tool] ) -> tuple[list[Message], list[Message]]: tool_map = {t.name: t for t in tools} + middlewares = getattr(self._tools, "_middlewares", []) async def _run_one(tc: ToolCallRequest) -> tuple[Message, Message | None]: tool = 
tool_map.get(tc.name) @@ -1590,7 +1591,11 @@ content = f"Error: tool '{tc.name}' not found." else: log.info("tool.execute", tool=tc.name, args=tc.arguments) + for mw in middlewares: + await mw.before_execute(tc.name, tc.arguments) result = await tool.execute(tc.arguments) + for mw in middlewares: + await mw.after_execute(tc.name, tc.arguments, result) content = result.to_message_content() metadata = result.metadata or {} if result.success and result.metadata and result.metadata.get("is_image"): @@ -1613,6 +1618,7 @@ self, tool_calls: list[ToolCallRequest], tools: list[Tool] ) -> tuple[list[tuple[ToolEvent, Message]], list[Message]]: tool_map = {t.name: t for t in tools} + middlewares = getattr(self._tools, "_middlewares", []) async def _run_one(tc: ToolCallRequest) -> tuple[ToolEvent, Message, Message | None]: tool = tool_map.get(tc.name) @@ -1625,7 +1631,13 @@ ) else: log.info("tool.execute", tool=tc.name, args=tc.arguments) + # Run middleware before + for mw in middlewares: + await mw.before_execute(tc.name, tc.arguments) result = await tool.execute(tc.arguments) + # Run middleware after + for mw in middlewares: + await mw.after_execute(tc.name, tc.arguments, result) content = result.to_message_content() metadata = result.metadata or {} event = ToolEvent( diff --git a/navi/core/event_bus.py b/navi/core/event_bus.py new file mode 100644 index 0000000..b453ad9 --- /dev/null +++ b/navi/core/event_bus.py @@ -0,0 +1,67 @@ +"""Async event bus — pub/sub for AgentEvents. + +Allows external modules to subscribe to tool calls, completions, etc. +without modifying the WebSocket handler. 
+""" + +from __future__ import annotations + +import asyncio +from collections import defaultdict +from typing import Awaitable, Callable + +from navi.core.events import AgentEvent + +Subscriber = Callable[[AgentEvent], Awaitable[None]] + + +class EventBus: + """Simple async pub/sub broker for AgentEvents.""" + + def __init__(self) -> None: + self._subs: defaultdict[type, list[Subscriber]] = defaultdict(list) + self._all_subs: list[Subscriber] = [] + + def subscribe(self, event_type: type | None, callback: Subscriber) -> None: + """Subscribe to a specific event type (or all events if event_type is None).""" + if event_type is None: + self._all_subs.append(callback) + else: + self._subs[event_type].append(callback) + + def unsubscribe(self, event_type: type | None, callback: Subscriber) -> None: + """Remove a subscriber.""" + if event_type is None: + self._all_subs[:] = [s for s in self._all_subs if s is not callback] + else: + self._subs[event_type][:] = [s for s in self._subs[event_type] if s is not callback] + + async def publish(self, event: AgentEvent) -> None: + """Publish an event to all matching subscribers.""" + tasks: list[asyncio.Task] = [] + for sub in self._all_subs: + tasks.append(asyncio.create_task(sub(event))) + for etype, subs in self._subs.items(): + if isinstance(event, etype): + for sub in subs: + tasks.append(asyncio.create_task(sub(event))) + if tasks: + await asyncio.gather(*tasks, return_exceptions=True) + + +# Global default bus — used by agent and WebSocket +_default_bus: EventBus | None = None + + +def get_event_bus() -> EventBus: + """Return the global event bus (lazy singleton).""" + global _default_bus + if _default_bus is None: + _default_bus = EventBus() + return _default_bus + + +def set_event_bus(bus: EventBus) -> None: + """Replace the global bus (useful for testing).""" + global _default_bus + _default_bus = bus diff --git a/navi/core/events.py b/navi/core/events.py index 4d59b92..49ef0fb 100644 --- a/navi/core/events.py +++ 
b/navi/core/events.py @@ -11,6 +11,14 @@ arguments: dict is_subagent: bool = False # True when emitted from inside run_ephemeral + def to_wire(self) -> dict: + return { + "type": "tool_started", + "tool": self.tool_name, + "args": self.arguments, + "is_subagent": self.is_subagent, + } + @dataclass class ToolEvent: @@ -23,6 +31,17 @@ is_subagent: bool = False # True when emitted from inside run_ephemeral metadata: dict = field(default_factory=dict) # Extra data for client rendering + def to_wire(self) -> dict: + return { + "type": "tool_call", + "tool": self.tool_name, + "args": self.arguments, + "result": self.result, + "success": self.success, + "is_subagent": self.is_subagent, + "metadata": self.metadata, + } + @dataclass class TextDelta: @@ -30,6 +49,9 @@ delta: str + def to_wire(self) -> dict: + return {"type": "stream_delta", "delta": self.delta} + @dataclass class ThinkingDelta: @@ -37,11 +59,17 @@ delta: str + def to_wire(self) -> dict: + return {"type": "thinking_delta", "delta": self.delta} + @dataclass class ThinkingEnd: """Marks the end of the thinking phase.""" + def to_wire(self) -> dict: + return {"type": "thinking_end"} + @dataclass class StreamEnd: @@ -54,11 +82,25 @@ tool_call_count: int = 0 token_count: int | None = None # same as context_tokens; kept separate for clarity + def to_wire(self) -> dict: + return { + "type": "stream_end", + "content": self.full_content, + "context_tokens": self.context_tokens, + "max_context_tokens": self.max_context_tokens, + "elapsed_seconds": self.elapsed_seconds, + "tool_call_count": self.tool_call_count, + "token_count": self.token_count, + } + @dataclass class StreamStopped: """Emitted when the user stops generation mid-stream (cooperative stop).""" + def to_wire(self) -> dict: + return {"type": "stream_stopped"} + @dataclass class ContextCompressed: @@ -68,6 +110,14 @@ messages_after: int summary: str = "" # the actual summary text produced by the LLM + def to_wire(self) -> dict: + return { + "type": 
"context_compressed", + "messages_before": self.messages_before, + "messages_after": self.messages_after, + "summary": self.summary, + } + @dataclass class ProfileSwitched: @@ -76,6 +126,13 @@ profile_id: str profile_name: str + def to_wire(self) -> dict: + return { + "type": "profile_switched", + "profile_id": self.profile_id, + "profile_name": self.profile_name, + } + @dataclass class PlanningStatus: @@ -90,6 +147,14 @@ label: str is_subagent: bool = False + def to_wire(self) -> dict: + return { + "type": "planning_status", + "phase": self.phase, + "label": self.label, + "is_subagent": self.is_subagent, + } + @dataclass class PlanReady: @@ -103,6 +168,13 @@ plan: str is_subagent: bool = False + def to_wire(self) -> dict: + return { + "type": "plan_ready", + "plan": self.plan, + "is_subagent": self.is_subagent, + } + @dataclass class TurnThinking: @@ -116,6 +188,13 @@ thinking: str is_subagent: bool = False + def to_wire(self) -> dict: + return { + "type": "turn_thinking", + "thinking": self.thinking, + "is_subagent": self.is_subagent, + } + @dataclass class SubagentComplete: @@ -125,6 +204,9 @@ token_count: int = 0 tool_call_count: int = 0 + def to_wire(self) -> dict | None: + return None # internal only + @dataclass class PlanningDebugData: @@ -133,6 +215,9 @@ log: dict # {timestamp, result, phases: {1: {output, prompt_tokens, completion_tokens}, ...}} + def to_wire(self) -> dict | None: + return None # internal only + @dataclass class AIHelperTokensUsed: @@ -146,6 +231,9 @@ def total(self) -> int: return self.prompt_tokens + self.completion_tokens + def to_wire(self) -> dict | None: + return None # internal only + AgentEvent = ( ToolStarted | ToolEvent | TextDelta | ThinkingDelta | ThinkingEnd diff --git a/navi/core/registry.py b/navi/core/registry.py index bf3ab3d..4cb3940 100644 --- a/navi/core/registry.py +++ b/navi/core/registry.py @@ -34,6 +34,7 @@ from navi.tools.share_file import ShareFileTool from navi.tools.content_publish import ContentPublishTool 
from navi.tools.loader import LoadResult, load_tools_from_dir +from navi.tools.logging_middleware import LoggingMiddleware from navi.context_providers._loader import ContextProviderRegistry @@ -41,12 +42,17 @@ def __init__(self) -> None: self._tools: dict[str, Tool] = {} self._builtin_names: set[str] = set() + self._middlewares: list = [] def register(self, tool: Tool, builtin: bool = False) -> None: self._tools[tool.name] = tool if builtin: self._builtin_names.add(tool.name) + def add_middleware(self, middleware) -> None: + """Add a ToolMiddleware instance.""" + self._middlewares.append(middleware) + def get(self, name: str) -> Tool: if name not in self._tools: raise ToolNotFound(name) @@ -104,6 +110,35 @@ return list(self._backends.keys()) +def _discover_backends() -> list[tuple[str, LLMBackend]]: + """Auto-discover LLM backends from navi/llm/ modules.""" + discovered: list[tuple[str, LLMBackend]] = [] + from navi.llm.ollama import OllamaBackend + from navi.llm.fallback import FallbackOllamaBackend + from navi.llm.openai_backend import OpenAIBackend + + # Ollama backend (primary) + if settings.ollama_backends_file: + servers = load_servers_from_file(settings.ollama_backends_file) + discovered.append(("ollama", FallbackOllamaBackend(servers))) + else: + discovered.append(("ollama", OllamaBackend( + model=settings.ollama_default_model, + host=settings.ollama_host, + api_key=settings.ollama_api_key, + ))) + + # OpenAI backend (if configured) + if settings.openai_api_key: + discovered.append(("openai", OpenAIBackend( + model=settings.openai_model, + api_key=settings.openai_api_key, + base_url=settings.openai_base_url, + ))) + + return discovered + + def build_default_registries( memory_store=None, session_store=None, @@ -111,19 +146,18 @@ """Build and populate registries with all built-in components.""" from navi.core.ai_helper import AIHelper - # Backends are needed by AIHelper — build early. - # Use FallbackOllamaBackend when a backends file is configured. 
- if settings.ollama_backends_file: - servers = load_servers_from_file(settings.ollama_backends_file) - ollama_backend: LLMBackend = FallbackOllamaBackend(servers) - else: - ollama_backend = OllamaBackend( - model=settings.ollama_default_model, - host=settings.ollama_host, - api_key=settings.ollama_api_key, - ) + backends = BackendRegistry() + backend_instances = _discover_backends() + if not backend_instances: + raise RuntimeError("No LLM backends discovered. Check OLLAMA_HOST or OPENAI_API_KEY.") + + for key, backend in backend_instances: + backends.register(key, backend) + + # Use primary backend for AIHelper + primary_backend = backend_instances[0][1] ai_helper = AIHelper( - backend=ollama_backend, + backend=primary_backend, default_model=settings.ollama_default_model, ) @@ -149,6 +183,9 @@ for user_tool in result.loaded: tools.register(user_tool) + # Register built-in middleware + tools.add_middleware(LoggingMiddleware()) + profiles = ProfileRegistry() for p in ALL_PROFILES: profiles.register(p) @@ -172,8 +209,6 @@ list_profiles_tool = ListProfilesTool(profile_registry=profiles) tools.register(list_profiles_tool, builtin=True) - backends = BackendRegistry() - backends.register("ollama", ollama_backend) # Patch backend registry into spawn_tool now that it's available spawn_tool._backend_registry = backends diff --git a/navi/memory/store.py b/navi/memory/store.py index 0161675..61bc26f 100644 --- a/navi/memory/store.py +++ b/navi/memory/store.py @@ -38,7 +38,6 @@ if pgvector_available else "" ) stmts = [ - "CREATE EXTENSION IF NOT EXISTS pg_trgm", """CREATE TABLE IF NOT EXISTS memory_facts ( id TEXT PRIMARY KEY, category TEXT NOT NULL, @@ -57,9 +56,16 @@ )""" % embedding_col, "CREATE INDEX IF NOT EXISTS idx_memory_facts_expires ON memory_facts (expires_at) WHERE expires_at IS NOT NULL", "CREATE INDEX IF NOT EXISTS idx_memory_facts_source_cat ON memory_facts (source, category)", - "CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin 
(category gin_trgm_ops)", - "CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops)", - "CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops)", + # GIN trigram indexes — only if pg_trgm extension is already installed. + # CREATE EXTENSION requires superuser/CREATE privilege, so we skip it here. + """DO $$ + BEGIN + IF EXISTS (SELECT 1 FROM pg_extension WHERE extname = 'pg_trgm') THEN + CREATE INDEX IF NOT EXISTS idx_memory_facts_cat_trgm ON memory_facts USING gin (category gin_trgm_ops); + CREATE INDEX IF NOT EXISTS idx_memory_facts_key_trgm ON memory_facts USING gin (key gin_trgm_ops); + CREATE INDEX IF NOT EXISTS idx_memory_facts_value_trgm ON memory_facts USING gin (value gin_trgm_ops); + END IF; + END $$;""", """CREATE TABLE IF NOT EXISTS memory_summary ( id INTEGER PRIMARY KEY DEFAULT 1, content TEXT NOT NULL, diff --git a/navi/profiles/base.py b/navi/profiles/base.py index ee96761..b313d79 100644 --- a/navi/profiles/base.py +++ b/navi/profiles/base.py @@ -1,14 +1,15 @@ -from dataclasses import dataclass, field +from pydantic import BaseModel, Field, field_validator -@dataclass -class AgentProfile: +class AgentProfile(BaseModel): """ Defines a complete agent configuration. A profile ties together a system prompt, an LLM backend, a model, and the set of tools the agent is allowed to use. """ + model_config = {"extra": "allow"} + id: str name: str description: str @@ -17,11 +18,8 @@ llm_backend: str = "ollama" # backend key, e.g. "ollama", "openai" # Ordered list of preferred models; first available wins at runtime. # Accepts a plain string for backward compatibility (auto-wrapped in a list). 
- model: list[str] = field(default_factory=lambda: ["gemma4:31b-cloud"]) + model: list[str] = Field(default_factory=lambda: ["gemma4:31b-cloud"]) - def __post_init__(self) -> None: - if isinstance(self.model, str): - self.model = [self.model] max_iterations: int = 10 temperature: float = 0.7 top_k: int | None = None @@ -35,7 +33,7 @@ # short_description: 1-line summary shown in every system prompt to all profiles. # full_description: structured dict with keys: specialization, when_to_use, key_tools. short_description: str = "" - full_description: dict = field(default_factory=dict) + full_description: dict = Field(default_factory=dict) # ── Thinking mechanics ──────────────────────────────────────────────────── # Each flag can be set per-profile in config.json to tune the balance @@ -90,10 +88,17 @@ # subagent_planning_enabled: if True, sub-agents run the planning phase before their tool loop. # subagent_system_prompt: injected as an additional system message for sub-agents, # after the profile's main system_prompt. Loaded from subagent_system_prompt.txt if present. - subagent_tools: list[str] = field(default_factory=list) + subagent_tools: list[str] = Field(default_factory=list) subagent_planning_enabled: bool = False subagent_system_prompt: str = "" # Extra context providers to inject for this profile (by name). # Global providers (global_provider=True) are always injected regardless of this list. 
- context_providers: list[str] = field(default_factory=list) + context_providers: list[str] = Field(default_factory=list) + + @field_validator("model", mode="before") + @classmethod + def _coerce_model(cls, v): + if isinstance(v, str): + return [v] if v else ["gemma4:31b-cloud"] + return v diff --git a/navi/tools/logging_middleware.py b/navi/tools/logging_middleware.py new file mode 100644 index 0000000..5fde94e --- /dev/null +++ b/navi/tools/logging_middleware.py @@ -0,0 +1,24 @@ +"""Built-in tool middleware — logs every tool call for observability.""" + +import structlog + +from navi.tools.middleware import ToolMiddleware +from navi.tools.base import ToolResult + +log = structlog.get_logger() + + +class LoggingMiddleware(ToolMiddleware): + """Logs every tool execution with duration and result summary.""" + + async def before_execute(self, tool_name: str, params: dict) -> None: + log.debug("middleware.tool.before", tool=tool_name, args=params) + + async def after_execute(self, tool_name: str, params: dict, result: ToolResult) -> None: + log.info( + "middleware.tool.after", + tool=tool_name, + success=result.success, + output_len=len(result.output), + has_error=bool(result.error), + ) diff --git a/navi/tools/middleware.py b/navi/tools/middleware.py new file mode 100644 index 0000000..f754f65 --- /dev/null +++ b/navi/tools/middleware.py @@ -0,0 +1,48 @@ +"""Tool middleware — pre/post execute hooks. + +Middleware runs around every tool call. Useful for logging, metrics, +rate limiting, and authorization without modifying individual tools. +""" + +from abc import ABC, abstractmethod +from typing import Awaitable, Callable + +from .base import ToolResult + +MiddlewareFunc = Callable[[str, dict], Awaitable[None]] +PostExecuteFunc = Callable[[str, dict, ToolResult], Awaitable[None]] + + +class ToolMiddleware(ABC): + """Base class for tool middleware. + + Subclasses override `before_execute` and/or `after_execute`. 
+ """ + + async def before_execute(self, tool_name: str, params: dict) -> None: + """Called before the tool executes.""" + pass + + async def after_execute(self, tool_name: str, params: dict, result: ToolResult) -> None: + """Called after the tool executes.""" + pass + + +class MiddlewareChain: + """Chains multiple middleware instances around a tool execution.""" + + def __init__(self, middlewares: list[ToolMiddleware]) -> None: + self._middlewares = middlewares + + async def run(self, tool_name: str, params: dict, execute: Callable[[], Awaitable[ToolResult]]) -> ToolResult: + for mw in self._middlewares: + await mw.before_execute(tool_name, params) + try: + result = await execute() + finally: + # We don't have result on exception, but we can still call after_execute + # with a synthetic failure result if needed. For now, call only on success. + pass + for mw in self._middlewares: + await mw.after_execute(tool_name, params, result) + return result diff --git a/navi/workers/__init__.py b/navi/workers/__init__.py index f4003b1..c202042 100644 --- a/navi/workers/__init__.py +++ b/navi/workers/__init__.py @@ -3,7 +3,36 @@ def build_default_workers() -> list[Worker]: - return [CompressionWorker()] + """Auto-discover and instantiate all built-in workers.""" + import importlib + import inspect + from pathlib import Path + + workers: list[Worker] = [] + pkg_dir = Path(__file__).parent + + for py_file in sorted(pkg_dir.glob("*.py")): + if py_file.name.startswith("_"): + continue + mod_name = f"navi.workers.{py_file.stem}" + try: + mod = importlib.import_module(mod_name) + except Exception: + continue + for _name, obj in inspect.getmembers(mod): + if ( + inspect.isclass(obj) + and issubclass(obj, Worker) + and obj is not Worker + and not getattr(obj, "__abstractmethods__", None) + ): + try: + instance = obj() + workers.append(instance) + except Exception: + pass + + return workers __all__ = [