Newer
Older
navi-1 / navi / core / event_bus.py
"""Async event bus — pub/sub for AgentEvents.

Allows external modules to subscribe to tool calls, completions, etc.
without modifying the WebSocket handler.
"""

from __future__ import annotations

import asyncio
import logging
from collections import defaultdict
from typing import Awaitable, Callable

from navi.core.events import AgentEvent

Subscriber = Callable[[AgentEvent], Awaitable[None]]


class EventBus:
    """Simple async pub/sub broker for AgentEvents.

    Subscribers are callables that take an event and return an awaitable.
    A subscriber registered with ``event_type=None`` receives every published
    event. A subscriber registered with a type receives events for which
    ``isinstance(event, event_type)`` holds, so subclasses match too.
    A callback registered under several matching keys (or registered more
    than once) is invoked once per registration.
    """

    def __init__(self) -> None:
        # Typed subscribers, keyed by the event class they listen for.
        self._subs: defaultdict[type, list[Subscriber]] = defaultdict(list)
        # Wildcard subscribers that receive every event.
        self._all_subs: list[Subscriber] = []

    def subscribe(self, event_type: type | None, callback: Subscriber) -> None:
        """Subscribe to a specific event type (or all events if event_type is None).

        No de-duplication is done: registering the same callback twice makes
        it fire twice per matching event.
        """
        if event_type is None:
            self._all_subs.append(callback)
        else:
            self._subs[event_type].append(callback)

    def unsubscribe(self, event_type: type | None, callback: Subscriber) -> None:
        """Remove every registration of *callback* under *event_type*.

        Matching uses ``==`` rather than identity so that bound methods work:
        each ``obj.handler`` access creates a new bound-method object, so
        ``obj.handler is obj.handler`` is False, but the two compare equal.
        Removing a callback or type that was never registered is a no-op.
        """
        if event_type is None:
            # Slice-assign to filter in place, keeping the same list object.
            self._all_subs[:] = [s for s in self._all_subs if s != callback]
            return
        # Use .get() rather than [] so an unknown type does not get an empty
        # bucket inserted into the defaultdict.
        subs = self._subs.get(event_type)
        if subs is None:
            return
        subs[:] = [s for s in subs if s != callback]
        if not subs:
            # Drop empty buckets so publish() stops isinstance-checking them.
            del self._subs[event_type]

    async def publish(self, event: AgentEvent) -> None:
        """Publish *event* to all matching subscribers and wait for them.

        Wildcard subscribers are scheduled first, then typed subscribers
        whose registered type matches *event*. Each callback runs in its own
        task and all tasks are awaited together. A subscriber that raises,
        whether synchronously or while being awaited, does not affect the
        others or the caller. Its exception is logged and publish() returns
        normally.
        """
        # Collect targets first so they pair 1:1 with gather()'s results.
        targets: list[Subscriber] = list(self._all_subs)
        for etype, subs in self._subs.items():
            if isinstance(event, etype):
                targets.extend(subs)
        if not targets:
            return
        tasks = [asyncio.create_task(self._deliver(sub, event)) for sub in targets]
        results = await asyncio.gather(*tasks, return_exceptions=True)
        for sub, result in zip(targets, results):
            # CancelledError is a BaseException (not Exception), so it is not
            # reported as a subscriber failure here.
            if isinstance(result, Exception):
                logging.getLogger(__name__).error(
                    "Event subscriber %r failed while handling %s",
                    sub,
                    type(event).__name__,
                    exc_info=result,
                )

    @staticmethod
    async def _deliver(callback: Subscriber, event: AgentEvent) -> None:
        """Invoke *callback* inside the task so that it behaves uniformly.

        Running the call here means synchronous errors are captured by the
        task, and any awaitable is accepted, not only coroutines.
        """
        await callback(event)


# Global default bus — used by agent and WebSocket.
# Starts as None and is created lazily by get_event_bus(); set_event_bus()
# can replace it (e.g. with an isolated bus in tests).
_default_bus: EventBus | None = None


def get_event_bus() -> EventBus:
    """Return the process-wide EventBus, creating it on first access.

    Later calls return the same instance until set_event_bus() swaps it.
    """
    global _default_bus
    bus = _default_bus
    if bus is None:
        # First access: build the singleton and remember it globally.
        bus = _default_bus = EventBus()
    return bus


def set_event_bus(bus: EventBus) -> None:
    """Install *bus* as the global bus that get_event_bus() hands out.

    Mainly useful for testing, where an isolated bus is wanted.
    """
    global _default_bus
    _default_bus = bus