diff --git a/navi/api/deps.py b/navi/api/deps.py index 855ae70..57bf3b4 100644 --- a/navi/api/deps.py +++ b/navi/api/deps.py @@ -94,6 +94,10 @@ return _resolve_container().get_agent() +def get_orchestrator(): + return _resolve_container().orchestrator + + async def register_mcp_tools(registry: ToolRegistry, manager: McpManager) -> None: """(kept for backward compat; MCP tools are already registered at startup).""" pass diff --git a/navi/api/websocket.py b/navi/api/websocket.py index ad17a3b..6c7b7bf 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -1,5 +1,4 @@ -""" -WebSocket endpoint for streaming agent responses. +"""WebSocket endpoint for streaming agent responses. Protocol (client -> server): {"type": "message", "content": "..."} @@ -17,80 +16,28 @@ """ import asyncio -import dataclasses import json import structlog from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect from typing import Annotated -from navi.api.deps import get_session_store +from navi.api.deps import get_orchestrator, get_session_store from navi.auth.deps import get_current_user, get_current_user_ws from navi.auth import User from navi.auth.deps import check_session_access -from navi.core import Agent, SessionStore +from navi.core import SessionStore from navi.core.event_bus import get_event_bus from navi.core.events import AgentEvent, RecallUpdate -from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound router = APIRouter(tags=["websocket"]) log = structlog.get_logger() -_MAX_REPLAY_EVENTS = 500 # cap replay buffer to avoid unbounded growth - -# ── Per-session run state ────────────────────────────────────────────────────── - -@dataclasses.dataclass -class _AgentRun: - """Holds the running agent task and all active subscriber queues.""" - task: asyncio.Task | None = None - stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) - subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list) - # Replay buffer: all serialised event dicts emitted so far this turn. - # Used to reconstruct the UI for clients that reconnect mid-stream. - events: list[dict] = dataclasses.field(default_factory=list) - - def subscribe(self) -> asyncio.Queue: - q: asyncio.Queue = asyncio.Queue() - self.subscribers.append(q) - return q - - def unsubscribe(self, q: asyncio.Queue) -> None: - try: - self.subscribers.remove(q) - except ValueError: - pass - - async def broadcast(self, item) -> None: - kind, payload = item - # Serialise and buffer every agent event so reconnecting clients can replay. - if kind == "event": - ev_dict = _event_to_dict(payload) - if ev_dict: - self.events.append(ev_dict) - # Evict oldest events to keep memory and replay cost bounded. - if len(self.events) > _MAX_REPLAY_EVENTS: - self.events.pop(0) - for q in list(self.subscribers): - await q.put(item) - - -# session_id → active run (present only while agent is executing) -_runs: dict[str, _AgentRun] = {} - -# session_id → stop_event for headless scheduled runs (blocks user messages) -_busy_sessions: dict[str, asyncio.Event] = {} - # session_id → all connected WebSocket clients (for out-of-run events like recall updates) _session_sockets: dict[str, list[WebSocket]] = {} -def is_session_running(session_id: str) -> bool: - """True if the session has an active WebSocket run OR a headless scheduled run.""" - return session_id in _runs or session_id in _busy_sessions - - async def _notify_session(session_id: str, payload: dict) -> None: """Send a JSON payload to every open WebSocket for the given session.""" sockets = _session_sockets.get(session_id, []) @@ -119,47 +66,6 @@ return None -async def _run_agent( - run: _AgentRun, - agent: Agent, - session_id: str, - user_content: str, - raw_images: list[str] | None, - display_content: str | None = None, - files: list[dict] | None = None, -) -> None: - """ - Execute the agent to completion, broadcasting events to all subscribers. - The session is saved by run_stream before StreamEnd — guaranteed even on disconnect. - """ - from navi.tools._internal.base import current_stop_event - current_stop_event.set(run.stop_event) - - try: - async for event in agent.run_stream( - session_id, user_content, images=raw_images, display_message=display_content, - files=files, - ): - await get_event_bus().publish(event) - await run.broadcast(("event", event)) - except asyncio.CancelledError: - log.info("ws.agent_stopped", session_id=session_id) - await run.broadcast(("stopped", None)) - raise # re-raise so the task is properly marked cancelled - except SessionNotFound: - await run.broadcast(("error", "Session not found")) - except MaxIterationsReached as e: - await run.broadcast(("error", str(e))) - except NaviError as e: - await run.broadcast(("error", str(e))) - except Exception as e: - log.exception("ws.agent_error", session_id=session_id) - await run.broadcast(("error", f"Internal error: {e}")) - finally: - await run.broadcast(("done", None)) - _runs.pop(session_id, None) - - _HEARTBEAT_INTERVAL = 20.0 # seconds — keeps the browser from dropping idle connections @@ -214,16 +120,9 @@ session = await store.get(session_id) if session is not None: check_session_access(session, user) - run = _runs.get(session_id) - if run is not None: - run.stop_event.set() - return {"ok": True} - # Also stop a headless recall run - recall_stop = _busy_sessions.get(session_id) - if recall_stop is not None: - recall_stop.set() - return {"ok": True} - return {"ok": False, "reason": "no active run"} + orchestrator = get_orchestrator() + ok = orchestrator.stop(session_id) + return {"ok": ok} if ok else {"ok": False, "reason": "no active run"} @router.websocket("/ws/sessions/{session_id}") @@ -243,6 +142,7 @@ log.info("ws.user_resolved", session_id=session_id, user_id=user.id if user else None) session_store = get_session_store() + orchestrator = get_orchestrator() session = await session_store.get(session_id) if session is None: @@ -273,29 +173,19 @@ await websocket.close(code=4003, reason="Access denied") return - from navi.api.deps import _resolve_container, get_memory_store, get_mcp_manager, get_workers - container = _resolve_container() - try: - mcp_manager = await get_mcp_manager() - except Exception: - mcp_manager = None - agent = Agent( - session_store, - container.profile_registry, - container.tool_registry, - container.backend_registry, - workers=get_workers(), - memory_store=get_memory_store(), - cp_registry=container.cp_registry, - mcp_manager=mcp_manager, - ) + # Wire the orchestrator notify callback so recall headless runs can push + # events to this (and any other) open WebSocket for the session. + async def _notify(session_id_: str, payload: dict) -> None: + await _notify_session(session_id_, payload) + + orchestrator.set_notify(_notify) queue: asyncio.Queue | None = None - current_run: _AgentRun | None = None + current_run = None try: # Re-attach to an in-progress run (e.g. client reloaded the page mid-stream). - existing = _runs.get(session_id) + existing = orchestrator.get_run(session_id) if existing is not None: current_run = existing # Subscribe BEFORE noting replay_count — single-threaded async, no race: @@ -396,7 +286,7 @@ ) # Guard against concurrent runs for the same session. - if session_id in _runs or session_id in _busy_sessions: + if orchestrator.is_running(session_id): await websocket.send_json({ "type": "error", "message": "Agent is already running for this session.", @@ -405,10 +295,9 @@ # Register run and subscribe before starting the task so we never # miss events even if the task is very fast. - run = _AgentRun() + run = orchestrator.create_run(session_id) queue = run.subscribe() current_run = run - _runs[session_id] = run # Set user context for tool sandboxing (inherited by the agent task) from navi.tools._internal.base import current_user_id as _uid_var, current_user_role as _role_var, current_user_info as _uinfo_var @@ -422,7 +311,9 @@ _uinfo_var.set(None) run.task = asyncio.create_task( - _run_agent(run, agent, session_id, user_content, raw_images, original_content, uploaded_files) + orchestrator.run_agent( + session_id, user_content, raw_images, original_content, uploaded_files, session_store + ) ) await websocket.send_json({"type": "stream_start"}) diff --git a/navi/core/container.py b/navi/core/container.py index 46ec786..7656595 100644 --- a/navi/core/container.py +++ b/navi/core/container.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from navi.core import Agent, BackendRegistry, ProfileRegistry, SessionStore, ToolRegistry + from navi.core.orchestrator import AgentSessionOrchestrator from navi.core.scheduler import RecallScheduler from navi.memory import MemoryStore from navi.mcp import McpManager @@ -32,6 +33,7 @@ cp_registry: "ContextProviderRegistry" workers: list["Worker"] mcp_manager: "McpManager | None" = None + orchestrator: "AgentSessionOrchestrator | None" = None _agent: "Agent | None" = field(default=None, repr=False) @@ -151,7 +153,7 @@ except Exception: pass - return AppContainer( + container = AppContainer( memory_store=memory_store, session_store=session_store, kv_store=kv_store, @@ -163,3 +165,7 @@ workers=workers, mcp_manager=mcp_manager, ) + from navi.core.orchestrator import AgentSessionOrchestrator + + container.orchestrator = AgentSessionOrchestrator(container) + return container diff --git a/navi/core/orchestrator.py b/navi/core/orchestrator.py index 00094d9..679bda6 100644 --- a/navi/core/orchestrator.py +++ b/navi/core/orchestrator.py @@ -1,23 +1,315 @@ -""" -Orchestrator stub — foundation for multi-agent scenarios. - -When multi-agent support is needed: -1. Implement OrchestratorAgent that decomposes tasks into subtasks -2. Each subtask is dispatched to a worker Agent with a specialized profile -3. Workers communicate results via asyncio.Queue (local) or Redis Pub/Sub (distributed) -4. Orchestrator aggregates results and produces a final response - -The existing Agent class requires no modification — orchestration is purely additive. -""" +"""Agent session orchestrator — manages active runs and recall lifecycle.""" from __future__ import annotations +import asyncio +import dataclasses +from datetime import datetime, timedelta, timezone +from typing import TYPE_CHECKING, Any, Callable -class Orchestrator: - """Placeholder for future multi-agent orchestration.""" +import structlog - def __init__(self) -> None: - raise NotImplementedError( - "Multi-agent orchestration is not yet implemented. " - "Use Agent directly for single-agent scenarios." +if TYPE_CHECKING: + from navi.core.container import AppContainer + +log = structlog.get_logger() + +_MAX_REPLAY_EVENTS = 500 + + +@dataclasses.dataclass +class SessionRun: + """Holds the running agent task and all active subscriber queues.""" + + task: asyncio.Task | None = None + stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event) + subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list) + # Replay buffer: all serialised event dicts emitted so far this turn. + events: list[dict] = dataclasses.field(default_factory=list) + + def subscribe(self) -> asyncio.Queue: + q: asyncio.Queue = asyncio.Queue() + self.subscribers.append(q) + return q + + def unsubscribe(self, q: asyncio.Queue) -> None: + try: + self.subscribers.remove(q) + except ValueError: + pass + + async def broadcast(self, item) -> None: + kind, payload = item + if kind == "event": + ev_dict = _event_to_dict(payload) + if ev_dict: + self.events.append(ev_dict) + if len(self.events) > _MAX_REPLAY_EVENTS: + self.events.pop(0) + for q in list(self.subscribers): + await q.put(item) + + +def _event_to_dict(event) -> dict | None: + if hasattr(event, "to_wire"): + return event.to_wire() + return None + + +class AgentSessionOrchestrator: + """Owns all active agent runs and headless recall sessions. + + Transport-agnostic — the WebSocket handler (or any other transport) + sets a *notify* callback so the orchestrator can push events to + connected clients without knowing about WebSockets directly. + """ + + def __init__(self, container: AppContainer) -> None: + self._container = container + self._runs: dict[str, SessionRun] = {} + self._busy_sessions: dict[str, asyncio.Event] = {} + # Callback injected by the transport layer (e.g. WebSocket handler). + self._notify: Callable[[str, dict], Any] | None = None + + def set_notify(self, notify: Callable[[str, dict], Any] | None) -> None: + self._notify = notify + + async def _notify_session(self, session_id: str, payload: dict) -> None: + if self._notify is not None: + try: + await self._notify(session_id, payload) + except Exception: + pass + + # ── Run state ───────────────────────────────────────────────────────────── + + def is_running(self, session_id: str) -> bool: + return session_id in self._runs or session_id in self._busy_sessions + + def stop(self, session_id: str) -> bool: + run = self._runs.get(session_id) + if run is not None: + run.stop_event.set() + return True + recall_stop = self._busy_sessions.get(session_id) + if recall_stop is not None: + recall_stop.set() + return True + return False + + def get_run(self, session_id: str) -> SessionRun | None: + return self._runs.get(session_id) + + def create_run(self, session_id: str) -> SessionRun: + run = SessionRun() + self._runs[session_id] = run + return run + + def mark_busy(self, session_id: str, stop_event: asyncio.Event | None = None) -> None: + self._busy_sessions[session_id] = stop_event or asyncio.Event() + + def clear_busy(self, session_id: str) -> None: + self._busy_sessions.pop(session_id, None) + + # ── Agent factory ───────────────────────────────────────────────────────── + + def _build_agent(self, session_store): + from navi.core import Agent + + try: + mcp_manager = self._container.mcp_manager + except Exception: + mcp_manager = None + return Agent( + session_store, + self._container.profile_registry, + self._container.tool_registry, + self._container.backend_registry, + workers=self._container.workers, + memory_store=self._container.memory_store, + cp_registry=self._container.cp_registry, + mcp_manager=mcp_manager, ) + + # ── Main run loop (WebSocket-driven) ────────────────────────────────────── + + async def run_agent( + self, + session_id: str, + user_content: str, + raw_images: list[str] | None, + display_content: str | None, + files: list[dict] | None, + session_store, + ) -> None: + """Execute the agent to completion, broadcasting events to subscribers.""" + from navi.tools._internal.base import current_stop_event + from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound + + run = self._runs[session_id] + current_stop_event.set(run.stop_event) + + agent = self._build_agent(session_store) + + try: + async for event in agent.run_stream( + session_id, + user_content, + images=raw_images, + display_message=display_content, + files=files, + ): + await run.broadcast(("event", event)) + except asyncio.CancelledError: + log.info("ws.agent_stopped", session_id=session_id) + await run.broadcast(("stopped", None)) + raise + except SessionNotFound: + await run.broadcast(("error", "Session not found")) + except MaxIterationsReached as e: + await run.broadcast(("error", str(e))) + except NaviError as e: + await run.broadcast(("error", str(e))) + except Exception as e: + log.exception("ws.agent_error", session_id=session_id) + await run.broadcast(("error", f"Internal error: {e}")) + finally: + await run.broadcast(("done", None)) + self._runs.pop(session_id, None) + + # ── Headless recall run ─────────────────────────────────────────────────── + + async def run_recall( + self, + recall, + scheduler, + store, + ) -> None: + """Execute a scheduled recall headlessly and notify connected clients.""" + from navi.core.agent import Agent, MaxIterationsReached + from navi.core.event_bus import get_event_bus + from navi.core.events import StreamEnd + from navi.tools._internal.base import current_stop_event + + # Guard: if a websocket run is active for this session, defer by 60 seconds + if recall.session_id in self._runs: + log.info("scheduler.defer_busy", session_id=recall.session_id) + await scheduler.reschedule( + recall.id, datetime.now(timezone.utc) + timedelta(seconds=60) + ) + return + + session = await store.get(recall.session_id) + if session is None: + log.warning("scheduler.session_missing", recall_id=recall.id) + await scheduler.mark_cancelled(recall.id) + return + + # Set user context for tools + from navi.tools._internal.base import ( + current_user_id as _uid_var, + current_user_role as _role_var, + current_user_info as _uinfo_var, + ) + if session and session.user_id is not None: + _uid_var.set(session.user_id) + _role_var.set("user") + _uinfo_var.set(None) + else: + _uid_var.set(None) + _role_var.set("user") + _uinfo_var.set(None) + + agent = self._build_agent(store) + + stop_event = asyncio.Event() + self.mark_busy(recall.session_id, stop_event) + token = current_stop_event.set(stop_event) + + run = SessionRun() + self._runs[recall.session_id] = run + + accumulated_text = "" + try: + await self._notify_session(recall.session_id, {"type": "stream_start"}) + + async for event in agent.run_stream( + session_id=recall.session_id, + user_message=recall.additional_context_message, + display_message=recall.additional_context_message, + is_recall=True, + ): + await get_event_bus().publish(event) + wire = event.to_wire() + if wire: + await self._notify_session(recall.session_id, wire) + await run.broadcast(("event", event)) + if isinstance(event, StreamEnd): + accumulated_text = event.full_content + + log.info( + "scheduler.recall_fired", + recall_id=recall.id, + session_id=recall.session_id, + reply_len=len(accumulated_text), + ) + if recall.call_type == "recurring": + next_trigger = datetime.now(timezone.utc) + timedelta( + seconds=recall.interval_seconds or 0 + ) + await scheduler.reschedule(recall.id, next_trigger) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" + ) + else: + await scheduler.mark_fired(recall.id) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + ) + except MaxIterationsReached: + log.info("scheduler.max_iterations", recall_id=recall.id) + if recall.call_type == "recurring": + next_trigger = datetime.now(timezone.utc) + timedelta( + seconds=recall.interval_seconds or 0 + ) + await scheduler.reschedule(recall.id, next_trigger) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" + ) + else: + await scheduler.mark_fired(recall.id) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + ) + except Exception: + log.exception("scheduler.recall_failed", recall_id=recall.id) + if recall.call_type == "recurring": + next_trigger = datetime.now(timezone.utc) + timedelta( + seconds=recall.interval_seconds or 0 + ) + await scheduler.reschedule(recall.id, next_trigger) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" + ) + else: + await scheduler.mark_cancelled(recall.id) + from navi.core.scheduler import _publish_recall_update + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled" + ) + finally: + self.clear_busy(recall.session_id) + current_stop_event.reset(token) + self._runs.pop(recall.session_id, None) + await self._notify_session(recall.session_id, {"type": "session_sync"}) diff --git a/navi/core/scheduler.py b/navi/core/scheduler.py index b838f15..be3ad89 100644 --- a/navi/core/scheduler.py +++ b/navi/core/scheduler.py @@ -315,7 +315,7 @@ return {r["session_id"] for r in rows} -async def recall_scheduler_loop(scheduler: RecallScheduler, store: Any) -> None: +async def recall_scheduler_loop(scheduler: RecallScheduler, store: Any, orchestrator) -> None: """Background task: poll for due recalls and fire them.""" await scheduler.ensure_tables() semaphore = asyncio.Semaphore(3) @@ -327,7 +327,7 @@ if pending: tasks = [ - asyncio.create_task(_fire_recall(semaphore, recall, scheduler, store)) + asyncio.create_task(_fire_recall(semaphore, recall, scheduler, store, orchestrator)) for recall in pending ] await asyncio.gather(*tasks, return_exceptions=True) @@ -351,170 +351,7 @@ recall: Recall, scheduler: RecallScheduler, store: Any, + orchestrator, ) -> None: - from navi.api.deps import ( - get_backend_registry, - get_cp_registry, - get_memory_store, - get_mcp_manager, - get_profile_registry, - get_tool_registry, - get_workers, - ) - from navi.api.websocket import ( - _AgentRun, - _busy_sessions, - _notify_session, - _runs, - ) - from navi.core.agent import Agent, MaxIterationsReached - async with semaphore: - # Guard: if a websocket run is active for this session, defer by 60 seconds - if recall.session_id in _runs: - log.info("scheduler.defer_busy", session_id=recall.session_id) - await scheduler.reschedule( - recall.id, datetime.now(timezone.utc) + timedelta(seconds=60) - ) - return - - session = await store.get(recall.session_id) - if session is None: - log.warning("scheduler.session_missing", recall_id=recall.id) - await scheduler.mark_cancelled(recall.id) - return - - # Set user context for tools so sandboxing and ownership checks work - from navi.tools._internal.base import ( - current_user_id as _uid_var, - current_user_role as _role_var, - current_user_info as _uinfo_var, - ) - if session and session.user_id is not None: - _uid_var.set(session.user_id) - _role_var.set("user") - _uinfo_var.set(None) - else: - _uid_var.set(None) - _role_var.set("user") - _uinfo_var.set(None) - - # Build agent (same deps pattern as websocket handler) - tools = get_tool_registry() - profiles = get_profile_registry() - backends = get_backend_registry() - cp_registry = get_cp_registry() - try: - mcp_manager = await get_mcp_manager() - except Exception: - mcp_manager = None - - agent = Agent( - store, - profiles, - tools, - backends, - workers=get_workers(), - memory_store=get_memory_store(), - cp_registry=cp_registry, - mcp_manager=mcp_manager, - ) - - from navi.core.event_bus import get_event_bus - from navi.core.events import StreamEnd - from navi.tools._internal.base import current_stop_event - - stop_event = asyncio.Event() - _busy_sessions[recall.session_id] = stop_event - token = current_stop_event.set(stop_event) - - # Register a headless run so reconnecting clients can replay events - run = _AgentRun() - _runs[recall.session_id] = run - - accumulated_text = "" - try: - # Notify any open WebSocket clients that a headless turn is starting. - # We intentionally do NOT send session_sync here — it races with - # stream_start and can wipe the in-progress streaming message before - # any deltas arrive. The final session_sync in the finally block - # ensures the saved turn (recall user msg + assistant response) is - # loaded once streaming is complete. - await _notify_session(recall.session_id, {"type": "stream_start"}) - - async for event in agent.run_stream( - session_id=recall.session_id, - user_message=recall.additional_context_message, - display_message=recall.additional_context_message, - is_recall=True, - ): - await get_event_bus().publish(event) - wire = event.to_wire() - if wire: - await _notify_session(recall.session_id, wire) - await run.broadcast(("event", event)) - if isinstance(event, StreamEnd): - accumulated_text = event.full_content - - log.info( - "scheduler.recall_fired", - recall_id=recall.id, - session_id=recall.session_id, - reply_len=len(accumulated_text), - ) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_fired(recall.id) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" - ) - except MaxIterationsReached: - log.info("scheduler.max_iterations", recall_id=recall.id) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_fired(recall.id) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" - ) - except Exception: - log.exception("scheduler.recall_failed", recall_id=recall.id) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_cancelled(recall.id) - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled" - ) - finally: - _busy_sessions.pop(recall.session_id, None) - current_stop_event.reset(token) - _runs.pop(recall.session_id, None) - # Tell all connected clients to reload session history so they see the - # full headless turn (recall message + assistant response). - await _notify_session(recall.session_id, {"type": "session_sync"}) + await orchestrator.run_recall(recall, scheduler, store) diff --git a/navi/main.py b/navi/main.py index 91742f9..c620079 100644 --- a/navi/main.py +++ b/navi/main.py @@ -83,7 +83,7 @@ # Start background tasks cleanup_task = asyncio.create_task(cleanup_loop(container.session_store)) scheduler_task = asyncio.create_task( - recall_scheduler_loop(container.scheduler, container.session_store) + recall_scheduler_loop(container.scheduler, container.session_store, container.orchestrator) ) yield diff --git a/tests/integration/test_scheduler_loop.py b/tests/integration/test_scheduler_loop.py index e57967c..bfa0482 100644 --- a/tests/integration/test_scheduler_loop.py +++ b/tests/integration/test_scheduler_loop.py @@ -2,14 +2,28 @@ import asyncio from datetime import datetime, timedelta, timezone -from unittest.mock import AsyncMock +from unittest.mock import AsyncMock, MagicMock import pytest from navi.core.events import StreamEnd +from navi.core.orchestrator import AgentSessionOrchestrator from navi.core.scheduler import Recall, _fire_recall, recall_scheduler_loop +@pytest.fixture +def fake_orchestrator(): + container = MagicMock() + container.profile_registry = None + container.tool_registry = None + container.backend_registry = None + container.cp_registry = None + container.workers = [] + container.memory_store = None + container.mcp_manager = None + return AgentSessionOrchestrator(container) + + @pytest.fixture(autouse=True) def patch_scheduler_deps(monkeypatch): """Prevent real _fire_recall from triggering heavy dependency initialization.""" @@ -42,14 +56,14 @@ # Patch _fire_recall to avoid full Agent construction fire_calls = [] - async def _fake_fire(semaphore, recall, scheduler, store): + async def _fake_fire(semaphore, recall, scheduler, store, orchestrator): fire_calls.append(recall.id) await scheduler.mark_fired(recall.id) monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire) # Run loop for one iteration then cancel - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, AsyncMock())) await asyncio.sleep(0.1) task.cancel() try: @@ -60,7 +74,7 @@ assert "r1" in fire_calls @pytest.mark.anyio - async def test_loop_respects_semaphore(self, monkeypatch): + async def test_loop_respects_semaphore(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() scheduler.get_pending_recalls.return_value = [ Recall( @@ -81,7 +95,7 @@ max_concurrent = 0 current = 0 - async def _slow_fire(semaphore, recall, scheduler, store): + async def _slow_fire(semaphore, recall, scheduler, store, orchestrator): nonlocal max_concurrent, current async with semaphore: current += 1 @@ -92,7 +106,7 @@ monkeypatch.setattr("navi.core.scheduler._fire_recall", _slow_fire) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.15) task.cancel() try: @@ -103,7 +117,7 @@ assert max_concurrent <= 3 @pytest.mark.anyio - async def test_loop_defers_when_session_busy(self, monkeypatch): + async def test_loop_defers_when_session_busy(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() scheduler.get_pending_recalls.return_value = [ Recall( @@ -120,17 +134,15 @@ store = AsyncMock() - # Simulate an active websocket run - import navi.api.websocket as ws_mod - from navi.api.websocket import _AgentRun - ws_mod._runs["s1"] = _AgentRun() + # Simulate an active websocket run in the orchestrator + fake_orchestrator.create_run("s1") - async def _fake_fire(semaphore, recall, scheduler, store): - await _fire_recall(semaphore, recall, scheduler, store) + async def _fake_fire(semaphore, recall, scheduler, store, orchestrator): + await _fire_recall(semaphore, recall, scheduler, store, orchestrator) monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: @@ -138,12 +150,12 @@ except asyncio.CancelledError: pass finally: - ws_mod._runs.pop("s1", None) + fake_orchestrator._runs.pop("s1", None) scheduler.reschedule.assert_called_once() @pytest.mark.anyio - async def test_loop_cancels_when_session_missing(self, monkeypatch): + async def test_loop_cancels_when_session_missing(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() scheduler.get_pending_recalls.return_value = [ Recall( @@ -161,12 +173,12 @@ store = AsyncMock() store.get.return_value = None - async def _fake_fire(semaphore, recall, scheduler, store): - await _fire_recall(semaphore, recall, scheduler, store) + async def _fake_fire(semaphore, recall, scheduler, store, orchestrator): + await _fire_recall(semaphore, recall, scheduler, store, orchestrator) monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: @@ -177,7 +189,7 @@ scheduler.mark_cancelled.assert_called_once_with("r1") @pytest.mark.anyio - async def test_recurring_rescheduled_on_success(self, monkeypatch): + async def test_recurring_rescheduled_on_success(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() recall = Recall( id="r1", session_id="s1", call_type="recurring", @@ -203,7 +215,7 @@ monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: FakeAgent()) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: @@ -214,7 +226,7 @@ scheduler.reschedule.assert_called_once() @pytest.mark.anyio - async def test_recurring_rescheduled_on_failure(self, monkeypatch): + async def test_recurring_rescheduled_on_failure(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() recall = Recall( id="r1", session_id="s1", call_type="recurring", @@ -240,7 +252,7 @@ monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: FakeAgent()) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: @@ -251,7 +263,7 @@ scheduler.reschedule.assert_called_once() @pytest.mark.anyio - async def test_one_time_cancelled_on_failure(self, monkeypatch): + async def test_one_time_cancelled_on_failure(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() recall = Recall( id="r1", session_id="s1", call_type="once", @@ -277,7 +289,7 @@ monkeypatch.setattr("navi.core.agent.Agent", lambda *a, **kw: FakeAgent()) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: @@ -288,7 +300,7 @@ scheduler.mark_cancelled.assert_called_once_with("r1") @pytest.mark.anyio - async def test_loop_picks_up_after_restart(self, monkeypatch): + async def test_loop_picks_up_after_restart(self, monkeypatch, fake_orchestrator): scheduler = AsyncMock() scheduler.get_pending_recalls.return_value = [ Recall( @@ -306,13 +318,13 @@ store = AsyncMock() fire_calls = [] - async def _fake_fire(semaphore, recall, scheduler, store): + async def _fake_fire(semaphore, recall, scheduler, store, orchestrator): fire_calls.append(recall.id) await scheduler.mark_fired(recall.id) monkeypatch.setattr("navi.core.scheduler._fire_recall", _fake_fire) - task = asyncio.create_task(recall_scheduler_loop(scheduler, store)) + task = asyncio.create_task(recall_scheduler_loop(scheduler, store, fake_orchestrator)) await asyncio.sleep(0.1) task.cancel() try: diff --git a/tests/unit/api/test_websocket.py b/tests/unit/api/test_websocket.py index 1c0764f..4543188 100644 --- a/tests/unit/api/test_websocket.py +++ b/tests/unit/api/test_websocket.py @@ -8,13 +8,12 @@ from fastapi import WebSocketDisconnect from navi.api import websocket as ws_mod +from navi.core.orchestrator import AgentSessionOrchestrator, SessionRun @pytest.fixture(autouse=True) def _clear_state(monkeypatch): """Clear global state before every WS test.""" - ws_mod._runs.clear() - ws_mod._busy_sessions.clear() ws_mod._session_sockets.clear() yield @@ -43,12 +42,25 @@ return user -# ── _AgentRun buffer tests ─────────────────────────────────────────────────── +@pytest.fixture +def fake_orchestrator(): + container = MagicMock() + container.profile_registry = None + container.tool_registry = None + container.backend_registry = None + container.cp_registry = None + container.workers = [] + container.memory_store = None + container.mcp_manager = None + return AgentSessionOrchestrator(container) + + +# ── SessionRun buffer tests ───────────────────────────────────────────────── @pytest.mark.anyio async def test_event_buffer_appended_and_replayed(): """Broadcast stores serialised events; oldest evicted when cap exceeded.""" - run = ws_mod._AgentRun() + run = SessionRun() class FakeEvent: def __init__(self, idx: int) -> None: @@ -66,15 +78,15 @@ # Fill buffer past limit run.events.clear() - for i in range(ws_mod._MAX_REPLAY_EVENTS + 5): + for i in range(500 + 5): await run.broadcast(("event", FakeEvent(i))) - assert len(run.events) == ws_mod._MAX_REPLAY_EVENTS + assert len(run.events) == 500 assert run.events[0] == {"type": "stream_delta", "delta": "5"} - assert run.events[-1] == {"type": "stream_delta", "delta": str(ws_mod._MAX_REPLAY_EVENTS + 4)} + assert run.events[-1] == {"type": "stream_delta", "delta": str(500 + 4)} -# ── Reconnect / replay tests ─────────────────────────────────────────────────── +# ── Reconnect / replay tests ───────────────────────────────────────────────── @pytest.mark.anyio async def test_reconnect_replays_buffered_events(mock_websocket, mock_session, mock_user, monkeypatch): @@ -83,22 +95,22 @@ mock_store = MagicMock() mock_store.get = AsyncMock(return_value=mock_session) monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) - monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) - # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() fake_container.profile_registry = None fake_container.tool_registry = None fake_container.backend_registry = None fake_container.cp_registry = None + fake_container.orchestrator = AgentSessionOrchestrator(fake_container) monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) - run = ws_mod._AgentRun() + orchestrator = fake_container.orchestrator + run = orchestrator.create_run("s1") run.events = [ {"type": "stream_delta", "delta": "hello"}, {"type": "thinking_delta", "delta": "hmm"}, ] - ws_mod._runs["s1"] = run mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) @@ -117,7 +129,7 @@ ] assert calls[1]["count"] == 2 - ws_mod._runs.pop("s1", None) + orchestrator._runs.pop("s1", None) @pytest.mark.anyio @@ -127,14 +139,14 @@ mock_store = MagicMock() mock_store.get = AsyncMock(return_value=mock_session) monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) - monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) - # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() fake_container.profile_registry = None fake_container.tool_registry = None fake_container.backend_registry = None fake_container.cp_registry = None + fake_container.orchestrator = AgentSessionOrchestrator(fake_container) monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) @@ -153,17 +165,18 @@ mock_store = MagicMock() mock_store.get = AsyncMock(return_value=mock_session) monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) - monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) - # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() fake_container.profile_registry = None fake_container.tool_registry = None fake_container.backend_registry = None fake_container.cp_registry = None + fake_container.orchestrator = AgentSessionOrchestrator(fake_container) monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) - ws_mod._busy_sessions["s1"] = asyncio.Event() + orchestrator = fake_container.orchestrator + orchestrator.mark_busy("s1") mock_websocket.receive_text = AsyncMock(side_effect=WebSocketDisconnect()) @@ -173,7 +186,7 @@ types = [c["type"] for c in calls] assert types == ["session_sync"] - ws_mod._busy_sessions.pop("s1", None) + orchestrator.clear_busy("s1") # ── Concurrent run guard ───────────────────────────────────────────────────── @@ -185,21 +198,23 @@ mock_store = MagicMock() mock_store.get = AsyncMock(return_value=mock_session) monkeypatch.setattr(ws_mod, "get_session_store", lambda: mock_store) - monkeypatch.setattr(ws_mod, "Agent", lambda *a, **kw: MagicMock()) monkeypatch.setattr(ws_mod, "_stream_to_client", AsyncMock(return_value=True)) - # Provide a dummy container so _resolve_container() succeeds + fake_container = MagicMock() fake_container.profile_registry = None fake_container.tool_registry = None fake_container.backend_registry = None fake_container.cp_registry = None + fake_container.orchestrator = AgentSessionOrchestrator(fake_container) monkeypatch.setattr("navi.api.deps._resolve_container", lambda: fake_container) + orchestrator = fake_container.orchestrator + # _run_agent sleeps so the run stays registered async def fake_run_agent(*a, **kw): await asyncio.sleep(3600) - monkeypatch.setattr(ws_mod, "_run_agent", fake_run_agent) + monkeypatch.setattr(orchestrator, "run_agent", fake_run_agent) message_count = 0 @@ -222,11 +237,11 @@ assert "already running" in error_calls[0]["message"] # Cleanup background task - run = ws_mod._runs.get("s1") + run = orchestrator._runs.get("s1") if run and run.task: run.task.cancel() try: await run.task except asyncio.CancelledError: pass - ws_mod._runs.pop("s1", None) + orchestrator._runs.pop("s1", None)