diff --git a/docs/architecture_weak_spots.md b/docs/architecture_weak_spots.md index a651db5..4ffd455 100644 --- a/docs/architecture_weak_spots.md +++ b/docs/architecture_weak_spots.md @@ -134,7 +134,7 @@ --- -## 9. Сессионное состояние в памяти процесса +## 9. Сессионное состояние в памяти процесса ✅ **Severity:** Medium **Файл:** `navi/api/websocket.py` (строки 80–86, 403–406) @@ -142,6 +142,15 @@ **Почему блокер:** Сессия может быть запущена на инстансе A, а WebSocket подключён к инстансу B. **Направление:** Вынести состояние запуска в `SessionStore` (PostgreSQL) или Redis. `_session_sockets` заменить на pub/sub. +**Решение 2026-05-24:** +- Создан `SessionState` dataclass в `navi/core/orchestrator.py` — единый контейнер для `run`, `busy_event`, `websockets` +- `_session_sockets` module-level global удалён из `websocket.py` и перенесён в `AgentSessionOrchestrator._sessions` +- Event bus subscriber `_on_recall_update` перенесён из `websocket.py` в `AgentSessionOrchestrator` +- Добавлен `asyncio.Lock` per session_id — `AgentSessionOrchestrator.session_lock()` защищает concurrent-run guard от race condition +- WebSocket handler теперь использует `orchestrator.add_websocket()` / `remove_websocket()` и `async with orchestrator.session_lock()` +- `_cleanup()` удаляет пустые `SessionState` entries автоматически +- Для горизонтального масштабирования всё ещё требуется Redis/pub-sub (не в scope этого фикса), но in-memory state теперь unified и explicit + --- ## 10. MCP: чтение конфига с диска на каждый вызов + retry без backoff diff --git a/navi/api/websocket.py b/navi/api/websocket.py index 6c7b7bf..c2fba7e 100644 --- a/navi/api/websocket.py +++ b/navi/api/websocket.py @@ -34,30 +34,6 @@ log = structlog.get_logger() -# session_id → all connected WebSocket clients (for out-of-run events like recall updates) -_session_sockets: dict[str, list[WebSocket]] = {} - - -async def _notify_session(session_id: str, payload: dict) -> None: - """Send a JSON payload to every open WebSocket for the given session.""" - sockets = _session_sockets.get(session_id, []) - for ws in list(sockets): - try: - await ws.send_json(payload) - except Exception: - pass - - -async def _on_recall_update(event: AgentEvent) -> None: - if isinstance(event, RecallUpdate) and event.session_id: - payload = event.to_wire() - if payload: - await _notify_session(event.session_id, payload) - - -get_event_bus().subscribe(RecallUpdate, _on_recall_update) - - # ── Helpers ─────────────────────────────────────────────────────────────────── def _event_to_dict(event) -> dict | None: @@ -155,7 +131,7 @@ # Accept the WebSocket before checking access so that auth failures can be # sent as WebSocket close codes rather than HTTP 403 on the upgrade request. await websocket.accept() - _session_sockets.setdefault(session_id, []).append(websocket) + orchestrator.add_websocket(session_id, websocket) log.info("ws.accepted", session_id=session_id) if user is None: @@ -173,13 +149,6 @@ await websocket.close(code=4003, reason="Access denied") return - # Wire the orchestrator notify callback so recall headless runs can push - # events to this (and any other) open WebSocket for the session. - async def _notify(session_id_: str, payload: dict) -> None: - await _notify_session(session_id_, payload) - - orchestrator.set_notify(_notify) - queue: asyncio.Queue | None = None current_run = None @@ -285,19 +254,20 @@ user_content + f"\n\n[Uploaded files on disk:\n{file_lines}]" ) - # Guard against concurrent runs for the same session. - if orchestrator.is_running(session_id): - await websocket.send_json({ - "type": "error", - "message": "Agent is already running for this session.", - }) - continue + # Guard against concurrent runs for the same session (atomically). + async with orchestrator.session_lock(session_id): + if orchestrator.is_running(session_id): + await websocket.send_json({ + "type": "error", + "message": "Agent is already running for this session.", + }) + continue - # Register run and subscribe before starting the task so we never - # miss events even if the task is very fast. - run = orchestrator.create_run(session_id) - queue = run.subscribe() - current_run = run + # Register run and subscribe before starting the task so we never + # miss events even if the task is very fast. + run = orchestrator.create_run(session_id) + queue = run.subscribe() + current_run = run # Set user context for tool sandboxing (inherited by the agent task) from navi.tools._internal.base import current_user_id as _uid_var, current_user_role as _role_var, current_user_info as _uinfo_var @@ -332,8 +302,4 @@ if queue is not None and current_run is not None: current_run.unsubscribe(queue) # Remove this socket from the session tracking set. - sockets = _session_sockets.get(session_id, []) - if websocket in sockets: - sockets.remove(websocket) - if not sockets: - _session_sockets.pop(session_id, None) + orchestrator.remove_websocket(session_id, websocket) diff --git a/navi/core/orchestrator.py b/navi/core/orchestrator.py index 679bda6..69fb292 100644 --- a/navi/core/orchestrator.py +++ b/navi/core/orchestrator.py @@ -4,12 +4,14 @@ import asyncio import dataclasses +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from typing import TYPE_CHECKING, Any, Callable import structlog if TYPE_CHECKING: + from fastapi import WebSocket from navi.core.container import AppContainer log = structlog.get_logger() @@ -17,7 +19,7 @@ _MAX_REPLAY_EVENTS = 500 -@dataclasses.dataclass +@dataclass class SessionRun: """Holds the running agent task and all active subscriber queues.""" @@ -50,6 +52,15 @@ await q.put(item) +@dataclass +class SessionState: + """All ephemeral in-memory state for a single session.""" + + run: SessionRun | None = None + busy_event: asyncio.Event | None = None + websockets: list[Any] = field(default_factory=list) + + def _event_to_dict(event) -> dict | None: if hasattr(event, "to_wire"): return event.to_wire() @@ -57,7 +68,7 @@ class AgentSessionOrchestrator: - """Owns all active agent runs and headless recall sessions. + """Owns all active agent runs, headless recall sessions, and connected transports. Transport-agnostic — the WebSocket handler (or any other transport) sets a *notify* callback so the orchestrator can push events to @@ -66,50 +77,126 @@ def __init__(self, container: AppContainer) -> None: self._container = container - self._runs: dict[str, SessionRun] = {} - self._busy_sessions: dict[str, asyncio.Event] = {} + self._sessions: dict[str, SessionState] = {} + self._session_locks: dict[str, asyncio.Lock] = {} # Callback injected by the transport layer (e.g. WebSocket handler). self._notify: Callable[[str, dict], Any] | None = None + # Wire event bus subscriber so recall updates reach connected clients + from navi.core.event_bus import get_event_bus + from navi.core.events import RecallUpdate + + get_event_bus().subscribe(RecallUpdate, self._on_recall_update) + def set_notify(self, notify: Callable[[str, dict], Any] | None) -> None: self._notify = notify async def _notify_session(self, session_id: str, payload: dict) -> None: + """Send a JSON payload to every open WebSocket for the given session.""" + state = self._sessions.get(session_id) + if state is None: + return if self._notify is not None: try: await self._notify(session_id, payload) except Exception: pass + # Fallback: send directly to tracked websockets + for ws in list(state.websockets): + try: + await ws.send_json(payload) + except Exception: + pass + + async def _on_recall_update(self, event: Any) -> None: + from navi.core.events import RecallUpdate + + if isinstance(event, RecallUpdate) and event.session_id: + payload = event.to_wire() + if payload: + await self._notify_session(event.session_id, payload) + + def _get_or_create_state(self, session_id: str) -> SessionState: + state = self._sessions.get(session_id) + if state is None: + state = SessionState() + self._sessions[session_id] = state + return state + + def _cleanup(self, session_id: str) -> None: + """Remove session entry if it holds no run, busy flag, or websockets.""" + state = self._sessions.get(session_id) + if state is None: + return + if state.run is None and state.busy_event is None and not state.websockets: + self._sessions.pop(session_id, None) + self._session_locks.pop(session_id, None) + + # ── Session lock ──────────────────────────────────────────────────────────── + + def session_lock(self, session_id: str) -> asyncio.Lock: + """Return (and create if needed) a per-session asyncio.Lock.""" + lock = self._session_locks.get(session_id) + if lock is None: + lock = asyncio.Lock() + self._session_locks[session_id] = lock + return lock + + # ── WebSocket tracking ────────────────────────────────────────────────────── + + def add_websocket(self, session_id: str, websocket: "WebSocket") -> None: + state = self._get_or_create_state(session_id) + state.websockets.append(websocket) + + def remove_websocket(self, session_id: str, websocket: "WebSocket") -> None: + state = self._sessions.get(session_id) + if state is None: + return + try: + state.websockets.remove(websocket) + except ValueError: + pass + self._cleanup(session_id) # ── Run state ───────────────────────────────────────────────────────────── def is_running(self, session_id: str) -> bool: - return session_id in self._runs or session_id in self._busy_sessions + state = self._sessions.get(session_id) + if state is None: + return False + return state.run is not None or state.busy_event is not None def stop(self, session_id: str) -> bool: - run = self._runs.get(session_id) - if run is not None: - run.stop_event.set() + state = self._sessions.get(session_id) + if state is None: + return False + if state.run is not None: + state.run.stop_event.set() return True - recall_stop = self._busy_sessions.get(session_id) - if recall_stop is not None: - recall_stop.set() + if state.busy_event is not None: + state.busy_event.set() return True return False def get_run(self, session_id: str) -> SessionRun | None: - return self._runs.get(session_id) + state = self._sessions.get(session_id) + return state.run if state else None def create_run(self, session_id: str) -> SessionRun: + state = self._get_or_create_state(session_id) run = SessionRun() - self._runs[session_id] = run + state.run = run return run def mark_busy(self, session_id: str, stop_event: asyncio.Event | None = None) -> None: - self._busy_sessions[session_id] = stop_event or asyncio.Event() + state = self._get_or_create_state(session_id) + state.busy_event = stop_event or asyncio.Event() def clear_busy(self, session_id: str) -> None: - self._busy_sessions.pop(session_id, None) + state = self._sessions.get(session_id) + if state is not None: + state.busy_event = None + self._cleanup(session_id) # ── Agent factory ───────────────────────────────────────────────────────── @@ -146,7 +233,7 @@ from navi.tools._internal.base import current_stop_event from navi.exceptions import MaxIterationsReached, NaviError, SessionNotFound - run = self._runs[session_id] + run = self._sessions[session_id].run current_stop_event.set(run.stop_event) agent = self._build_agent(session_store) @@ -175,7 +262,10 @@ await run.broadcast(("error", f"Internal error: {e}")) finally: await run.broadcast(("done", None)) - self._runs.pop(session_id, None) + state = self._sessions.get(session_id) + if state is not None: + state.run = None + self._cleanup(session_id) # ── Headless recall run ─────────────────────────────────────────────────── @@ -192,7 +282,8 @@ from navi.tools._internal.base import current_stop_event # Guard: if a websocket run is active for this session, defer by 60 seconds - if recall.session_id in self._runs: + state = self._sessions.get(recall.session_id) + if state is not None and state.run is not None: log.info("scheduler.defer_busy", session_id=recall.session_id) await scheduler.reschedule( recall.id, datetime.now(timezone.utc) + timedelta(seconds=60) @@ -226,8 +317,7 @@ self.mark_busy(recall.session_id, stop_event) token = current_stop_event.set(stop_event) - run = SessionRun() - self._runs[recall.session_id] = run + run = self.create_run(recall.session_id) accumulated_text = "" try: @@ -311,5 +401,8 @@ finally: self.clear_busy(recall.session_id) current_stop_event.reset(token) - self._runs.pop(recall.session_id, None) + state = self._sessions.get(recall.session_id) + if state is not None: + state.run = None + self._cleanup(recall.session_id) await self._notify_session(recall.session_id, {"type": "session_sync"}) diff --git a/tests/integration/test_scheduler_loop.py b/tests/integration/test_scheduler_loop.py index bfa0482..39c1974 100644 --- a/tests/integration/test_scheduler_loop.py +++ b/tests/integration/test_scheduler_loop.py @@ -150,7 +150,7 @@ except asyncio.CancelledError: pass finally: - fake_orchestrator._runs.pop("s1", None) + fake_orchestrator._sessions.pop("s1", None) scheduler.reschedule.assert_called_once() diff --git a/tests/unit/api/test_websocket.py b/tests/unit/api/test_websocket.py index 4543188..8d287f7 100644 --- a/tests/unit/api/test_websocket.py +++ b/tests/unit/api/test_websocket.py @@ -14,7 +14,6 @@ @pytest.fixture(autouse=True) def _clear_state(monkeypatch): """Clear global state before every WS test.""" - ws_mod._session_sockets.clear() yield @@ -129,7 +128,7 @@ ] assert calls[1]["count"] == 2 - orchestrator._runs.pop("s1", None) + orchestrator._sessions.pop("s1", None) @pytest.mark.anyio @@ -237,11 +236,11 @@ assert "already running" in error_calls[0]["message"] # Cleanup background task - run = orchestrator._runs.get("s1") - if run and run.task: - run.task.cancel() + state = orchestrator._sessions.get("s1") + if state and state.run and state.run.task: + state.run.task.cancel() try: - await run.task + await state.run.task except asyncio.CancelledError: pass - orchestrator._runs.pop("s1", None) + orchestrator._sessions.pop("s1", None)