# Newer
# Older
# navi-1 / navi / core / orchestrator.py
"""Agent session orchestrator — manages active runs and recall lifecycle."""

from __future__ import annotations

import asyncio
import dataclasses
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable

import structlog

from navi.config import settings

if TYPE_CHECKING:
    from fastapi import WebSocket
    from navi.core.container import AppContainer

log = structlog.get_logger()


@dataclass
class SessionRun:
    """Holds the running agent task and all active subscriber queues."""

    task: asyncio.Task | None = None
    stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event)
    subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)
    # Replay buffer: all serialised event dicts emitted so far this turn.
    events: list[dict] = dataclasses.field(default_factory=list)

    def subscribe(self) -> asyncio.Queue:
        q: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(q)
        return q

    def unsubscribe(self, q: asyncio.Queue) -> None:
        try:
            self.subscribers.remove(q)
        except ValueError:
            pass

    async def broadcast(self, item) -> None:
        kind, payload = item
        if kind == "event":
            ev_dict = _event_to_dict(payload)
            if ev_dict:
                self.events.append(ev_dict)
                if len(self.events) > settings.ws_replay_buffer_size:
                    self.events.pop(0)
        for q in list(self.subscribers):
            await q.put(item)


@dataclass
class SessionState:
    """Ephemeral, in-memory bookkeeping for one session.

    Every field defaults to an "empty" value, so a freshly constructed
    instance describes an idle session with no attached transports.
    """

    # Active agent run for the session, or None when nothing is running.
    run: SessionRun | None = None
    # Event registered through mark_busy(); None when the session is not busy.
    busy_event: asyncio.Event | None = None
    # Open transports attached to this session (a new list per instance).
    websockets: list[Any] = dataclasses.field(default_factory=list)


def _event_to_dict(event) -> dict | None:
    if hasattr(event, "to_wire"):
        return event.to_wire()
    return None


class AgentSessionOrchestrator:
    """Owns all active agent runs, headless recall sessions, and connected transports.

    Per-session state lives in ``_sessions`` (one :class:`SessionState` per
    session id). A per-session :class:`asyncio.Lock` lives in
    ``_session_locks``. Entries are removed by :meth:`_cleanup` once a session
    has no run, no busy marker and no websockets.
    """

    def __init__(self, container: AppContainer) -> None:
        """Store the container and subscribe to recall / MCP status events.

        Args:
            container: Application container supplying the registries, stores,
                workers and MCP manager used to build agents.
        """
        self._container = container
        self._sessions: dict[str, SessionState] = {}
        self._session_locks: dict[str, asyncio.Lock] = {}

        # Wire event bus subscriber so recall updates reach connected clients
        from navi.core.event_bus import get_event_bus
        from navi.core.events import RecallUpdate, McpStatusUpdate

        get_event_bus().subscribe(RecallUpdate, self._on_recall_update)
        get_event_bus().subscribe(McpStatusUpdate, self._on_mcp_status_update)

    @staticmethod
    async def _send_to_websockets(websockets, payload: dict, session_id: str) -> None:
        """Best-effort send of *payload* to each websocket in *websockets*.

        A failing socket is logged at debug level and skipped, so one bad
        client cannot prevent delivery to the others.
        """
        # Copy: the list may be mutated (add/remove_websocket) while we await.
        for ws in list(websockets):
            try:
                await ws.send_json(payload)
            except Exception:
                log.debug("ws.send_failed", session_id=session_id, exc_info=True)

    async def _notify_session(self, session_id: str, payload: dict) -> None:
        """Send a JSON payload to every open WebSocket for the given session."""
        state = self._sessions.get(session_id)
        if state is None:
            return
        await self._send_to_websockets(state.websockets, payload, session_id)

    async def _broadcast_all_sessions(self, payload: dict) -> None:
        """Send a JSON payload to every open WebSocket across all sessions."""
        # Snapshot the mapping: sessions can be created or cleaned up while we
        # await each send, and mutating a dict during iteration raises RuntimeError.
        for session_id, state in list(self._sessions.items()):
            await self._send_to_websockets(state.websockets, payload, session_id)

    async def _on_recall_update(self, event: Any) -> None:
        """Forward a ``RecallUpdate`` to the websockets of its session."""
        from navi.core.events import RecallUpdate

        if isinstance(event, RecallUpdate) and event.session_id:
            payload = event.to_wire()
            if payload:
                await self._notify_session(event.session_id, payload)

    async def _on_mcp_status_update(self, event: Any) -> None:
        """Forward an ``McpStatusUpdate`` to every connected websocket."""
        from navi.core.events import McpStatusUpdate

        if isinstance(event, McpStatusUpdate):
            payload = event.to_wire()
            if payload:
                await self._broadcast_all_sessions(payload)

    def _get_or_create_state(self, session_id: str) -> SessionState:
        """Return the session's state, creating an empty one if absent."""
        state = self._sessions.get(session_id)
        if state is None:
            state = SessionState()
            self._sessions[session_id] = state
        return state

    def _cleanup(self, session_id: str) -> None:
        """Remove session entry if it holds no run, busy flag, or websockets.

        The per-session lock is discarded too, unless it is currently held.
        A coroutine inside ``async with session_lock(...)`` (e.g. ``run_recall``
        awaiting the session store) must keep its lock. Otherwise the next
        ``session_lock()`` caller would get a fresh lock and mutual exclusion
        would silently break.
        """
        state = self._sessions.get(session_id)
        if state is None:
            return
        if state.run is None and state.busy_event is None and not state.websockets:
            self._sessions.pop(session_id, None)
            lock = self._session_locks.get(session_id)
            if lock is not None and not lock.locked():
                self._session_locks.pop(session_id, None)

    # ── Session lock ────────────────────────────────────────────────────────────

    def session_lock(self, session_id: str) -> asyncio.Lock:
        """Return (and create if needed) a per-session asyncio.Lock."""
        lock = self._session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            self._session_locks[session_id] = lock
        return lock

    # ── WebSocket tracking ──────────────────────────────────────────────────────

    def add_websocket(self, session_id: str, websocket: "WebSocket") -> None:
        """Attach *websocket* to the session so it receives notifications."""
        state = self._get_or_create_state(session_id)
        state.websockets.append(websocket)

    def remove_websocket(self, session_id: str, websocket: "WebSocket") -> None:
        """Detach *websocket* (if attached) and drop the session if now idle."""
        state = self._sessions.get(session_id)
        if state is None:
            return
        try:
            state.websockets.remove(websocket)
        except ValueError:
            pass
        self._cleanup(session_id)

    # ── Run state ─────────────────────────────────────────────────────────────

    def is_running(self, session_id: str) -> bool:
        """Return True if the session has an active run or a busy marker."""
        state = self._sessions.get(session_id)
        if state is None:
            return False
        return state.run is not None or state.busy_event is not None

    def stop(self, session_id: str) -> bool:
        """Request that the session's active run and/or busy operation stop.

        Both the run's ``stop_event`` and the busy event are set when present.
        A run and a busy marker registered with different events are therefore
        both signalled.

        Returns:
            True if at least one stop event was set, False if nothing was active.
        """
        state = self._sessions.get(session_id)
        if state is None:
            return False
        signalled = False
        if state.run is not None:
            state.run.stop_event.set()
            signalled = True
        if state.busy_event is not None:
            state.busy_event.set()
            signalled = True
        return signalled

    def get_run(self, session_id: str) -> SessionRun | None:
        """Return the session's active run, or None."""
        state = self._sessions.get(session_id)
        return state.run if state else None

    def create_run(
        self, session_id: str, stop_event: asyncio.Event | None = None
    ) -> SessionRun:
        """Create and register a new :class:`SessionRun` for the session.

        Args:
            session_id: Session to attach the run to (replaces any existing run).
            stop_event: Optional event to use as the run's ``stop_event``, so
                that ``stop()`` signals an event the caller already listens to.
                A fresh event is created when omitted.
        """
        state = self._get_or_create_state(session_id)
        run = SessionRun() if stop_event is None else SessionRun(stop_event=stop_event)
        state.run = run
        return run

    def mark_busy(self, session_id: str, stop_event: asyncio.Event | None = None) -> None:
        """Flag the session as busy, storing *stop_event* (or a new event)."""
        state = self._get_or_create_state(session_id)
        state.busy_event = stop_event if stop_event is not None else asyncio.Event()

    def clear_busy(self, session_id: str) -> None:
        """Clear the session's busy marker and drop the session if now idle."""
        state = self._sessions.get(session_id)
        if state is not None:
            state.busy_event = None
        self._cleanup(session_id)

    # ── Agent factory ─────────────────────────────────────────────────────────

    def _build_agent(self, session_store):
        """Construct an ``Agent`` wired to the container's registries.

        The MCP manager is optional: if reading it from the container raises,
        the agent is built without one.
        """
        from navi.core import Agent

        try:
            mcp_manager = self._container.mcp_manager
        except Exception:
            log.debug("orchestrator.mcp_manager_unavailable", exc_info=True)
            mcp_manager = None
        return Agent(
            session_store,
            self._container.profile_registry,
            self._container.tool_registry,
            self._container.backend_registry,
            workers=self._container.workers,
            memory_store=self._container.memory_store,
            cp_registry=self._container.cp_registry,
            mcp_manager=mcp_manager,
        )

    # ── Main run loop (WebSocket-driven) ──────────────────────────────────────

    async def run_agent(
        self,
        session_id: str,
        user_content: str,
        raw_images: list[str] | None,
        display_content: str | None,
        files: list[dict] | None,
        session_store,
    ) -> None:
        """Execute the agent to completion, broadcasting events to subscribers.

        Expects a run to have been registered with :meth:`create_run` first.
        Subscribers always receive a terminal ``("done", None)``, preceded by
        ``("error", msg)`` on failure or ``("stopped", None)`` on cancellation.
        The run is then unregistered.

        Raises:
            asyncio.CancelledError: Re-raised after notifying subscribers.
        """
        from navi.tools._internal.base import current_stop_event
        from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound

        run = self._sessions[session_id].run
        stop_token = current_stop_event.set(run.stop_event)

        try:
            # Built inside the try so a construction failure is reported to
            # subscribers and the run is still unregistered in ``finally``.
            agent = self._build_agent(session_store)
            async for event in agent.run_stream(
                session_id,
                user_content,
                images=raw_images,
                display_message=display_content,
                files=files,
            ):
                await run.broadcast(("event", event))
        except asyncio.CancelledError:
            log.info("ws.agent_stopped", session_id=session_id)
            await run.broadcast(("stopped", None))
            raise
        except SessionNotFound:
            await run.broadcast(("error", "Session not found"))
        except (MaxIterationsReached, NaviError) as e:
            await run.broadcast(("error", str(e)))
        except Exception as e:
            log.exception("ws.agent_error", session_id=session_id)
            await run.broadcast(("error", f"Internal error: {e}"))
        finally:
            await run.broadcast(("done", None))
            current_stop_event.reset(stop_token)
            state = self._sessions.get(session_id)
            # Only unregister our own run; never clobber a newer one.
            if state is not None and state.run is run:
                state.run = None
            self._cleanup(session_id)

    # ── Headless recall run ───────────────────────────────────────────────────

    async def _finalize_recall(self, recall, scheduler, *, outcome: str) -> None:
        """Reschedule/mark fired/mark cancelled and publish update.

        Recurring recalls are always rescheduled ``interval_seconds`` from now,
        whatever the outcome. One-time recalls are cancelled when *outcome* is
        ``"failed"``. Any other outcome ("success", "max_iterations") marks
        them fired.
        """
        from navi.core.scheduler import _publish_recall_update

        if recall.call_type == "recurring":
            next_trigger = datetime.now(timezone.utc) + timedelta(
                seconds=recall.interval_seconds or 0
            )
            await scheduler.reschedule(recall.id, next_trigger)
            await _publish_recall_update(
                recall.session_id, recall.id, recall.call_type,
                trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled"
            )
            return

        # one-time: max_iterations is treated like success (preserving existing behaviour)
        if outcome == "failed":
            await scheduler.mark_cancelled(recall.id)
            status = "cancelled"
        else:
            await scheduler.mark_fired(recall.id)
            status = "fired"
        await _publish_recall_update(
            recall.session_id, recall.id, recall.call_type,
            trigger_at=recall.trigger_at.isoformat(), status=status, action=status
        )

    async def run_recall(
        self,
        recall,
        scheduler,
        store,
    ) -> None:
        """Execute a scheduled recall headlessly and notify connected clients.

        Under the session lock, the recall is deferred by 60 seconds if a run is
        already active. It is cancelled if the session no longer exists.
        Otherwise the tool user context is set and a busy marker plus a
        :class:`SessionRun` are registered, all sharing one stop event.

        The agent then runs outside the lock. Each event is published on the
        event bus; events with a wire form are also sent to the session's
        websockets and broadcast to run subscribers. The recall is finalized
        according to the outcome. Session state and context variables are
        restored, and clients receive ``session_sync``.

        Args:
            recall: Recall record (uses ``id``, ``session_id``, ``call_type``,
                ``interval_seconds``, ``trigger_at`` and
                ``additional_context_message``).
            scheduler: Scheduler used to reschedule / mark the recall.
            store: Session store used for lookup and handed to the agent.
        """
        from navi.core.agent import MaxIterationsReached
        from navi.core.event_bus import get_event_bus
        from navi.core.events import StreamEnd
        from navi.tools._internal.base import (
            current_stop_event,
            current_user_id as _uid_var,
            current_user_info as _uinfo_var,
            current_user_role as _role_var,
        )

        session_id = recall.session_id

        async with self.session_lock(session_id):
            # Guard: if a websocket run is active for this session, defer by 60 seconds
            state = self._sessions.get(session_id)
            if state is not None and state.run is not None:
                log.info("scheduler.defer_busy", session_id=session_id)
                await scheduler.reschedule(
                    recall.id, datetime.now(timezone.utc) + timedelta(seconds=60)
                )
                return

            session = await store.get(session_id)
            if session is None:
                log.warning("scheduler.session_missing", recall_id=recall.id)
                await scheduler.mark_cancelled(recall.id)
                return

            # Set user context for tools (user_id may legitimately be None).
            uid_token = _uid_var.set(session.user_id)
            role_token = _role_var.set("user")
            uinfo_token = _uinfo_var.set(None)

            # One event is shared by the busy marker, the tool context and the
            # run. stop() signals run.stop_event, which must be the event the
            # agent is actually watching.
            stop_event = asyncio.Event()
            self.mark_busy(session_id, stop_event)
            stop_token = current_stop_event.set(stop_event)
            run = self.create_run(session_id, stop_event=stop_event)

        accumulated_text = ""
        try:
            # Built inside the try so a construction failure still finalizes the
            # recall and restores state/context vars in ``finally``.
            agent = self._build_agent(store)

            await self._notify_session(session_id, {"type": "stream_start"})

            async for event in agent.run_stream(
                session_id=session_id,
                user_message=recall.additional_context_message,
                display_message=recall.additional_context_message,
                is_recall=True,
            ):
                await get_event_bus().publish(event)
                wire = event.to_wire()
                if wire:
                    await self._notify_session(session_id, wire)
                    await run.broadcast(("event", event))
                if isinstance(event, StreamEnd):
                    accumulated_text = event.full_content

            log.info(
                "scheduler.recall_fired",
                recall_id=recall.id,
                session_id=session_id,
                reply_len=len(accumulated_text),
            )
            await self._finalize_recall(recall, scheduler, outcome="success")
        except MaxIterationsReached:
            log.info("scheduler.max_iterations", recall_id=recall.id)
            await self._finalize_recall(recall, scheduler, outcome="max_iterations")
        except Exception:
            log.exception("scheduler.recall_failed", recall_id=recall.id)
            await self._finalize_recall(recall, scheduler, outcome="failed")
        finally:
            # Like run_agent: subscribers waiting on this run need a terminal item.
            await run.broadcast(("done", None))
            self.clear_busy(session_id)
            current_stop_event.reset(stop_token)
            _uid_var.reset(uid_token)
            _role_var.reset(role_token)
            _uinfo_var.reset(uinfo_token)
            state = self._sessions.get(session_id)
            # Only unregister our own run; never clobber a newer one.
            if state is not None and state.run is run:
                state.run = None
            self._cleanup(session_id)
            await self._notify_session(session_id, {"type": "session_sync"})