diff --git a/navi/api/websocket.py b/navi/api/websocket.py index c2fba7e..576fa08 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -272,28 +272,33 @@ # Set user context for tool sandboxing (inherited by the agent task) from navi.tools._internal.base import current_user_id as _uid_var, current_user_role as _role_var, current_user_info as _uinfo_var if user is not None: - _uid_var.set(user.id) - _role_var.set(user.role) - _uinfo_var.set(user.model_dump(mode="json")) + uid_token = _uid_var.set(user.id) + role_token = _role_var.set(user.role) + uinfo_token = _uinfo_var.set(user.model_dump(mode="json")) else: - _uid_var.set(None) - _role_var.set("user") - _uinfo_var.set(None) + uid_token = _uid_var.set(None) + role_token = _role_var.set("user") + uinfo_token = _uinfo_var.set(None) - run.task = asyncio.create_task( - orchestrator.run_agent( - session_id, user_content, raw_images, original_content, uploaded_files, session_store + try: + run.task = asyncio.create_task( + orchestrator.run_agent( + session_id, user_content, raw_images, original_content, uploaded_files, session_store + ) ) - ) - await websocket.send_json({"type": "stream_start"}) - connected = await _stream_to_client(websocket, queue) - run.unsubscribe(queue) - queue = None - current_run = None + await websocket.send_json({"type": "stream_start"}) + connected = await _stream_to_client(websocket, queue) + run.unsubscribe(queue) + queue = None + current_run = None - if not connected: - break # avoid calling receive_text() on a dead socket + if not connected: + break # avoid calling receive_text() on a dead socket + finally: + _uid_var.reset(uid_token) + _role_var.reset(role_token) + _uinfo_var.reset(uinfo_token) except (WebSocketDisconnect, RuntimeError): log.info("ws.disconnected", session_id=session_id) diff --git a/navi/core/orchestrator.py b/navi/core/orchestrator.py index 69fb292..91bcef6 100644 --- a/navi/core/orchestrator.py +++ b/navi/core/orchestrator.py @@ -68,19 +68,12 @@ class AgentSessionOrchestrator: - """Owns all active agent runs, headless recall sessions, and connected transports. - - Transport-agnostic — the WebSocket handler (or any other transport) - sets a *notify* callback so the orchestrator can push events to - connected clients without knowing about WebSockets directly. - """ + """Owns all active agent runs, headless recall sessions, and connected transports.""" def __init__(self, container: AppContainer) -> None: self._container = container self._sessions: dict[str, SessionState] = {} self._session_locks: dict[str, asyncio.Lock] = {} - # Callback injected by the transport layer (e.g. WebSocket handler). - self._notify: Callable[[str, dict], Any] | None = None # Wire event bus subscriber so recall updates reach connected clients from navi.core.event_bus import get_event_bus @@ -88,20 +81,11 @@ get_event_bus().subscribe(RecallUpdate, self._on_recall_update) - def set_notify(self, notify: Callable[[str, dict], Any] | None) -> None: - self._notify = notify - async def _notify_session(self, session_id: str, payload: dict) -> None: """Send a JSON payload to every open WebSocket for the given session.""" state = self._sessions.get(session_id) if state is None: return - if self._notify is not None: - try: - await self._notify(session_id, payload) - except Exception: - pass - # Fallback: send directly to tracked websockets for ws in list(state.websockets): try: await ws.send_json(payload) @@ -269,6 +253,42 @@ # ── Headless recall run ─────────────────────────────────────────────────── + async def _finalize_recall(self, recall, scheduler, *, outcome: str) -> None: + """Reschedule/mark fired/mark cancelled and publish update.""" + from navi.core.scheduler import _publish_recall_update + + if recall.call_type == "recurring": + next_trigger = datetime.now(timezone.utc) + timedelta( + seconds=recall.interval_seconds or 0 + ) + await scheduler.reschedule(recall.id, next_trigger) + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" + ) + return + + # one-time + if outcome == "success": + await scheduler.mark_fired(recall.id) + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + ) + elif outcome == "failed": + await scheduler.mark_cancelled(recall.id) + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled" + ) + else: + # max_iterations for one-time -> mark_fired (preserving existing behaviour) + await scheduler.mark_fired(recall.id) + await _publish_recall_update( + recall.session_id, recall.id, recall.call_type, + trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" + ) + async def run_recall( self, recall, @@ -281,43 +301,44 @@ from navi.core.events import StreamEnd from navi.tools._internal.base import current_stop_event - # Guard: if a websocket run is active for this session, defer by 60 seconds - state = self._sessions.get(recall.session_id) - if state is not None and state.run is not None: - log.info("scheduler.defer_busy", session_id=recall.session_id) - await scheduler.reschedule( - recall.id, datetime.now(timezone.utc) + timedelta(seconds=60) + async with self.session_lock(recall.session_id): + # Guard: if a websocket run is active for this session, defer by 60 seconds + state = self._sessions.get(recall.session_id) + if state is not None and state.run is not None: + log.info("scheduler.defer_busy", session_id=recall.session_id) + await scheduler.reschedule( + recall.id, datetime.now(timezone.utc) + timedelta(seconds=60) + ) + return + + session = await store.get(recall.session_id) + if session is None: + log.warning("scheduler.session_missing", recall_id=recall.id) + await scheduler.mark_cancelled(recall.id) + return + + # Set user context for tools + from navi.tools._internal.base import ( + current_user_id as _uid_var, + current_user_role as _role_var, + current_user_info as _uinfo_var, ) - return + if session and session.user_id is not None: + uid_token = _uid_var.set(session.user_id) + role_token = _role_var.set("user") + uinfo_token = _uinfo_var.set(None) + else: + uid_token = _uid_var.set(None) + role_token = _role_var.set("user") + uinfo_token = _uinfo_var.set(None) - session = await store.get(recall.session_id) - if session is None: - log.warning("scheduler.session_missing", recall_id=recall.id) - await scheduler.mark_cancelled(recall.id) - return + agent = self._build_agent(store) - # Set user context for tools - from navi.tools._internal.base import ( - current_user_id as _uid_var, - current_user_role as _role_var, - current_user_info as _uinfo_var, - ) - if session and session.user_id is not None: - _uid_var.set(session.user_id) - _role_var.set("user") - _uinfo_var.set(None) - else: - _uid_var.set(None) - _role_var.set("user") - _uinfo_var.set(None) + stop_event = asyncio.Event() + self.mark_busy(recall.session_id, stop_event) + token = current_stop_event.set(stop_event) - agent = self._build_agent(store) - - stop_event = asyncio.Event() - self.mark_busy(recall.session_id, stop_event) - token = current_stop_event.set(stop_event) - - run = self.create_run(recall.session_id) + run = self.create_run(recall.session_id) accumulated_text = "" try: @@ -343,64 +364,19 @@ session_id=recall.session_id, reply_len=len(accumulated_text), ) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_fired(recall.id) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" - ) + await self._finalize_recall(recall, scheduler, outcome="success") except MaxIterationsReached: log.info("scheduler.max_iterations", recall_id=recall.id) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_fired(recall.id) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="fired", action="fired" - ) + await self._finalize_recall(recall, scheduler, outcome="max_iterations") except Exception: log.exception("scheduler.recall_failed", recall_id=recall.id) - if recall.call_type == "recurring": - next_trigger = datetime.now(timezone.utc) + timedelta( - seconds=recall.interval_seconds or 0 - ) - await scheduler.reschedule(recall.id, next_trigger) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled" - ) - else: - await scheduler.mark_cancelled(recall.id) - from navi.core.scheduler import _publish_recall_update - await _publish_recall_update( - recall.session_id, recall.id, recall.call_type, - trigger_at=recall.trigger_at.isoformat(), status="cancelled", action="cancelled" - ) + await self._finalize_recall(recall, scheduler, outcome="failed") finally: self.clear_busy(recall.session_id) current_stop_event.reset(token) + _uid_var.reset(uid_token) + _role_var.reset(role_token) + _uinfo_var.reset(uinfo_token) state = self._sessions.get(recall.session_id) if state is not None: state.run = None