"""Agent session orchestrator — manages active runs and recall lifecycle."""
from __future__ import annotations
import asyncio
import dataclasses
from dataclasses import dataclass, field
from datetime import datetime, timedelta, timezone
from typing import TYPE_CHECKING, Any, Callable
import structlog
from navi.config import settings
if TYPE_CHECKING:
from fastapi import WebSocket
from navi.core.container import AppContainer
log = structlog.get_logger()
@dataclass
class SessionRun:
    """One in-flight agent turn: its task, stop flag, listeners and replay log."""

    # Task executing the turn; nothing in this class assigns it.
    task: asyncio.Task | None = None
    # Set by the orchestrator to request that the turn stop.
    stop_event: asyncio.Event = dataclasses.field(default_factory=asyncio.Event)
    # One queue per listener; every broadcast item is put on each of them.
    subscribers: list[asyncio.Queue] = dataclasses.field(default_factory=list)
    # Replay buffer: all serialised event dicts emitted so far this turn.
    events: list[dict] = dataclasses.field(default_factory=list)

    def subscribe(self) -> asyncio.Queue:
        """Register and return a fresh queue that receives future broadcasts."""
        queue: asyncio.Queue = asyncio.Queue()
        self.subscribers.append(queue)
        return queue

    def unsubscribe(self, q: asyncio.Queue) -> None:
        """Stop delivering to *q*; a queue that is not registered is ignored."""
        if q in self.subscribers:
            self.subscribers.remove(q)

    async def broadcast(self, item) -> None:
        """Record *item* in the replay buffer (if applicable) and fan it out.

        *item* is a ``(kind, payload)`` pair. Only ``"event"`` items whose
        payload serialises to a non-empty dict are kept for replay; the buffer
        is capped at ``settings.ws_replay_buffer_size`` by dropping the oldest
        entry. Every item, whatever its kind, is put on each subscriber queue.
        """
        kind, payload = item
        if kind == "event":
            wire = _event_to_dict(payload)
            if wire:
                self.events.append(wire)
                if len(self.events) > settings.ws_replay_buffer_size:
                    del self.events[0]

        # Snapshot the listeners so (un)subscribing during delivery is safe.
        for queue in tuple(self.subscribers):
            await queue.put(item)
@dataclass
class SessionState:
    """All ephemeral in-memory state for a single session."""

    # Run currently attached to the session, or None when there is none.
    run: SessionRun | None = None
    # Event installed while the session is marked busy; None otherwise.
    busy_event: asyncio.Event | None = None
    # Transports registered for this session (each instance gets its own list).
    websockets: list[Any] = dataclasses.field(default_factory=list)
def _event_to_dict(event) -> dict | None:
if hasattr(event, "to_wire"):
return event.to_wire()
return None
class AgentSessionOrchestrator:
    """Owns all active agent runs, headless recall sessions, and connected transports.

    Two in-memory maps are kept, both keyed by session id: ``_sessions`` holds
    one ``SessionState`` per session and ``_session_locks`` holds a lazily
    created ``asyncio.Lock`` per session. ``_cleanup`` drops both entries once
    a session has no run, no busy flag and no websockets.
    """

    def __init__(self, container: AppContainer) -> None:
        """Store *container* and subscribe to recall and MCP status events."""
        self._container = container
        self._sessions: dict[str, SessionState] = {}
        self._session_locks: dict[str, asyncio.Lock] = {}

        # Wire event bus subscriber so recall updates reach connected clients
        from navi.core.event_bus import get_event_bus
        from navi.core.events import RecallUpdate, McpStatusUpdate

        get_event_bus().subscribe(RecallUpdate, self._on_recall_update)
        get_event_bus().subscribe(McpStatusUpdate, self._on_mcp_status_update)

    async def _send_best_effort(self, websockets: list[Any], payload: dict) -> None:
        """Send *payload* to every websocket, logging instead of raising on failure."""
        # Iterate a snapshot: the list may be mutated by add/remove_websocket
        # while this coroutine is suspended inside ``send_json``.
        for ws in list(websockets):
            try:
                await ws.send_json(payload)
            except Exception:
                # Best effort: one broken transport must not block the others.
                log.debug("ws.send_failed", exc_info=True)

    async def _notify_session(self, session_id: str, payload: dict) -> None:
        """Send a JSON payload to every open WebSocket for the given session."""
        state = self._sessions.get(session_id)
        if state is None:
            return
        await self._send_best_effort(state.websockets, payload)

    async def _broadcast_all_sessions(self, payload: dict) -> None:
        """Send a JSON payload to every open WebSocket across all sessions."""
        # Snapshot the states: sessions can be added or removed while we are
        # suspended in a send, and iterating the live dict would then raise
        # "dictionary changed size during iteration".
        for state in list(self._sessions.values()):
            await self._send_best_effort(state.websockets, payload)

    async def _on_recall_update(self, event: Any) -> None:
        """Forward a ``RecallUpdate`` that carries a session id to that session."""
        from navi.core.events import RecallUpdate

        if isinstance(event, RecallUpdate) and event.session_id:
            payload = event.to_wire()
            if payload:
                await self._notify_session(event.session_id, payload)

    async def _on_mcp_status_update(self, event: Any) -> None:
        """Forward a ``McpStatusUpdate`` to the websockets of every session."""
        from navi.core.events import McpStatusUpdate

        if isinstance(event, McpStatusUpdate):
            payload = event.to_wire()
            if payload:
                await self._broadcast_all_sessions(payload)

    def _get_or_create_state(self, session_id: str) -> SessionState:
        """Return the state for *session_id*, creating an empty one if absent."""
        state = self._sessions.get(session_id)
        if state is None:
            state = SessionState()
            self._sessions[session_id] = state
        return state

    async def _cleanup(self, session_id: str) -> None:
        """Remove session entry if it holds no run, busy flag, or websockets.

        Acquires the per-session lock, so it must not be awaited by a caller
        that already holds that lock (``asyncio.Lock`` is not reentrant).
        """
        async with self.session_lock(session_id):
            state = self._sessions.get(session_id)
            if state is None:
                return
            if state.run is None and state.busy_event is None and not state.websockets:
                self._sessions.pop(session_id, None)
                self._session_locks.pop(session_id, None)

    # ── Session lock ────────────────────────────────────────────────────────────

    def session_lock(self, session_id: str) -> asyncio.Lock:
        """Return (and create if needed) a per-session asyncio.Lock."""
        lock = self._session_locks.get(session_id)
        if lock is None:
            lock = asyncio.Lock()
            self._session_locks[session_id] = lock
        return lock

    # ── WebSocket tracking ──────────────────────────────────────────────────────

    def add_websocket(self, session_id: str, websocket: "WebSocket") -> None:
        """Register *websocket* for *session_id*, creating the state if needed."""
        state = self._get_or_create_state(session_id)
        state.websockets.append(websocket)

    async def remove_websocket(self, session_id: str, websocket: "WebSocket") -> None:
        """Unregister *websocket* and drop the session entry if it became idle."""
        state = self._sessions.get(session_id)
        if state is None:
            return
        try:
            state.websockets.remove(websocket)
        except ValueError:
            # Already removed (or never registered): nothing to undo.
            pass
        await self._cleanup(session_id)

    # ── Run state ─────────────────────────────────────────────────────────────

    def is_running(self, session_id: str) -> bool:
        """Return True when the session has an active run or a busy flag."""
        state = self._sessions.get(session_id)
        if state is None:
            return False
        return state.run is not None or state.busy_event is not None

    def stop(self, session_id: str) -> bool:
        """Signal the session's run (or, failing that, its busy event) to stop.

        Returns True when an event was set, False when there was nothing to stop.
        """
        state = self._sessions.get(session_id)
        if state is None:
            return False
        if state.run is not None:
            state.run.stop_event.set()
            return True
        if state.busy_event is not None:
            state.busy_event.set()
            return True
        return False

    def get_run(self, session_id: str) -> SessionRun | None:
        """Return the session's current run, or None if there is none."""
        state = self._sessions.get(session_id)
        return state.run if state else None

    def create_run(self, session_id: str) -> SessionRun:
        """Attach a fresh ``SessionRun`` to the session (replacing any previous one)."""
        state = self._get_or_create_state(session_id)
        run = SessionRun()
        state.run = run
        return run

    def mark_busy(self, session_id: str, stop_event: asyncio.Event | None = None) -> None:
        """Flag the session as busy, using *stop_event* or a new event."""
        state = self._get_or_create_state(session_id)
        state.busy_event = stop_event if stop_event is not None else asyncio.Event()

    async def clear_busy(self, session_id: str) -> None:
        """Clear both the busy flag and the run, then drop the entry if idle."""
        state = self._sessions.get(session_id)
        if state is None:
            # Nothing to clear; skipping _cleanup avoids creating a lock entry
            # for a session that has no state.
            return
        state.busy_event = None
        state.run = None
        await self._cleanup(session_id)

    # ── Agent factory ─────────────────────────────────────────────────────────

    def _build_agent(self, session_store):
        """Construct an ``Agent`` wired to the container's registries and stores."""
        from navi.core import Agent

        try:
            mcp_manager = self._container.mcp_manager
        except Exception:
            # The MCP manager is treated as optional: any failure to obtain it
            # falls back to None rather than aborting agent construction.
            mcp_manager = None
        return Agent(
            session_store,
            self._container.profile_registry,
            self._container.tool_registry,
            self._container.backend_registry,
            workers=self._container.workers,
            memory_store=self._container.memory_store,
            cp_registry=self._container.cp_registry,
            mcp_manager=mcp_manager,
        )

    # ── Main run loop (WebSocket-driven) ──────────────────────────────────────

    async def run_agent(
        self,
        session_id: str,
        user_content: str,
        raw_images: list[str] | None,
        display_content: str | None,
        files: list[dict] | None,
        session_store,
    ) -> None:
        """Execute the agent to completion, broadcasting events to subscribers.

        Precondition: ``create_run(session_id)`` has been called, so the
        session state exists and carries a run. Every outcome is reported to
        the run's subscribers as ``("event", ev)``, ``("stopped", None)`` or
        ``("error", msg)``, always followed by ``("done", None)``. Only
        ``asyncio.CancelledError`` is re-raised.
        """
        from navi.tools._internal.base import current_stop_event
        from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound

        run = self._sessions[session_id].run
        stop_token = current_stop_event.set(run.stop_event)
        try:
            # Built inside the try so that a construction failure still emits
            # "done" and clears the run instead of leaving the session stuck.
            agent = self._build_agent(session_store)
            async for event in agent.run_stream(
                session_id,
                user_content,
                images=raw_images,
                display_message=display_content,
                files=files,
            ):
                await run.broadcast(("event", event))
        except asyncio.CancelledError:
            log.info("ws.agent_stopped", session_id=session_id)
            await run.broadcast(("stopped", None))
            raise
        except SessionNotFound:
            await run.broadcast(("error", "Session not found"))
        except (MaxIterationsReached, NaviError) as e:
            await run.broadcast(("error", str(e)))
        except Exception as e:
            log.exception("ws.agent_error", session_id=session_id)
            await run.broadcast(("error", f"Internal error: {e}"))
        finally:
            # Restore the previous stop event (mirrors run_recall).
            current_stop_event.reset(stop_token)
            await run.broadcast(("done", None))
            state = self._sessions.get(session_id)
            if state is not None:
                state.run = None
                await self._cleanup(session_id)

    # ── Headless recall run ───────────────────────────────────────────────────

    async def _finalize_recall(self, recall, scheduler, *, outcome: str) -> None:
        """Reschedule/mark fired/mark cancelled and publish update.

        Recurring recalls are always rescheduled, whatever *outcome* is.
        One-time recalls are cancelled when *outcome* is ``"failed"`` and
        marked fired for any other outcome (``"success"``, ``"max_iterations"``).
        """
        from navi.core.scheduler import _publish_recall_update

        if recall.call_type == "recurring":
            # NOTE(review): a missing/zero interval reschedules for "now";
            # confirm the scheduler rejects such recalls upstream.
            next_trigger = datetime.now(timezone.utc) + timedelta(
                seconds=recall.interval_seconds or 0
            )
            await scheduler.reschedule(recall.id, next_trigger)
            await _publish_recall_update(
                recall.session_id, recall.id, recall.call_type,
                trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled"
            )
            return

        # one-time
        if outcome == "failed":
            await scheduler.mark_cancelled(recall.id)
            await _publish_recall_update(
                recall.session_id, recall.id, recall.call_type,
                trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled"
            )
        else:
            # success, and max_iterations for one-time -> mark_fired
            # (preserving existing behaviour)
            await scheduler.mark_fired(recall.id)
            await _publish_recall_update(
                recall.session_id, recall.id, recall.call_type,
                trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired"
            )

    async def run_recall(
        self,
        recall,
        scheduler,
        store,
    ) -> None:
        """Execute a scheduled recall headlessly and notify connected clients.

        The per-session lock is held only while checking for an active run and
        claiming the session (busy flag + run). It is released before the agent
        streams, because the cleanup in ``finally`` re-acquires the same
        non-reentrant lock via ``clear_busy`` -> ``_cleanup``.

        If a run is already active the recall is deferred by 60 seconds; if
        the session no longer exists the recall is cancelled. Otherwise the
        agent is run, the recall is finalized exactly once according to the
        outcome, and clients receive ``stream_start`` / events /
        ``session_sync``.
        """
        from navi.core.agent import MaxIterationsReached
        from navi.core.event_bus import get_event_bus
        from navi.core.events import StreamEnd
        from navi.tools._internal.base import (
            current_stop_event,
            current_user_id as _uid_var,
            current_user_role as _role_var,
            current_user_info as _uinfo_var,
        )

        session_id = recall.session_id

        async with self.session_lock(session_id):
            # Guard: if a websocket run is active for this session, defer by 60 seconds
            state = self._sessions.get(session_id)
            if state is not None and state.run is not None:
                log.info("scheduler.defer_busy", session_id=session_id)
                await scheduler.reschedule(
                    recall.id, datetime.now(timezone.utc) + timedelta(seconds=60)
                )
                return

            session = await store.get(session_id)
            if session is None:
                log.warning("scheduler.session_missing", recall_id=recall.id)
                await scheduler.mark_cancelled(recall.id)
                return

            agent = self._build_agent(store)

            # Set user context for tools: the session owner's id (None when
            # unset) with the fixed "user" role and no extra user info.
            user_id = session.user_id if session else None
            uid_token = _uid_var.set(user_id)
            role_token = _role_var.set("user")
            uinfo_token = _uinfo_var.set(None)

            # Claim the session while still holding the lock so a concurrent
            # recall for the same session sees it as busy.
            stop_event = asyncio.Event()
            self.mark_busy(session_id, stop_event)
            stop_token = current_stop_event.set(stop_event)
            run = self.create_run(session_id)

        accumulated_text = ""
        try:
            try:
                await self._notify_session(session_id, {"type": "stream_start"})
                async for event in agent.run_stream(
                    session_id=session_id,
                    user_message=recall.additional_context_message,
                    display_message=recall.additional_context_message,
                    is_recall=True,
                ):
                    await get_event_bus().publish(event)
                    wire = event.to_wire()
                    if wire:
                        await self._notify_session(session_id, wire)
                    await run.broadcast(("event", event))
                    if isinstance(event, StreamEnd):
                        accumulated_text = event.full_content
            except MaxIterationsReached:
                log.info("scheduler.max_iterations", recall_id=recall.id)
                outcome = "max_iterations"
            except Exception:
                log.exception("scheduler.recall_failed", recall_id=recall.id)
                outcome = "failed"
            else:
                log.info(
                    "scheduler.recall_fired",
                    recall_id=recall.id,
                    session_id=session_id,
                    reply_len=len(accumulated_text),
                )
                outcome = "success"

            # Finalize exactly once, outside the handlers above: an error while
            # recording a successful run must not be re-handled as a failed run
            # (which would finalize the recall a second time as cancelled).
            await self._finalize_recall(recall, scheduler, outcome=outcome)
        finally:
            # Synchronous resets first so they happen even if the awaited
            # cleanup below is interrupted.
            current_stop_event.reset(stop_token)
            _uid_var.reset(uid_token)
            _role_var.reset(role_token)
            _uinfo_var.reset(uinfo_token)
            # clear_busy drops both the busy flag and the run, then cleans up.
            await self.clear_busy(session_id)
            await self._notify_session(session_id, {"type": "session_sync"})