Newer
Older
navi-1 / navi / core / scheduler.py
@Eugene Sukhodolskiy Eugene Sukhodolskiy on 15 May 13 KB Add self-recall (scheduled callback) system
"""Scheduled recall system — background scheduler that fires headless agent runs."""

import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any

import structlog

from navi.llm.base import Message

# Module-level structured logger shared by the scheduler loop and helpers.
log = structlog.get_logger()

# Schema for the recall table. Every statement uses IF NOT EXISTS, so the
# script can be executed each time a connection pool is created.
#
# * session_recalls.session_id references sessions(id) with ON DELETE CASCADE,
#   so recalls are removed together with their session.
# * call_type is restricted to 'once' / 'recurring' / 'immediate' and status to
#   'pending' / 'fired' / 'cancelled' by CHECK constraints.
# * idx_recalls_active_pending is a partial UNIQUE index on session_id
#   (WHERE status = 'pending'): a session can have at most one pending recall.
_DDL = """
CREATE TABLE IF NOT EXISTS session_recalls (
    id                  TEXT PRIMARY KEY,
    session_id          TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    call_type           TEXT NOT NULL CHECK (call_type IN ('once', 'recurring', 'immediate')),
    trigger_at          TIMESTAMPTZ NOT NULL,
    interval_seconds    INTEGER,
    internal_comment    TEXT,
    additional_context_message TEXT NOT NULL,
    status              TEXT NOT NULL DEFAULT 'pending'
                        CHECK (status IN ('pending', 'fired', 'cancelled')),
    created_at          TIMESTAMPTZ NOT NULL,
    updated_at          TIMESTAMPTZ NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_recalls_trigger_at_status
    ON session_recalls (trigger_at, status);

CREATE INDEX IF NOT EXISTS idx_recalls_session_status
    ON session_recalls (session_id, status);

CREATE UNIQUE INDEX IF NOT EXISTS idx_recalls_active_pending
    ON session_recalls (session_id) WHERE status = 'pending';
"""


@dataclass(frozen=True, slots=True)
class Recall:
    """Immutable in-memory copy of one ``session_recalls`` row.

    Field names match the table's column names one-to-one.
    """

    id: str  # primary key of the recall row
    session_id: str  # session the recall belongs to
    call_type: str  # 'once', 'recurring' or 'immediate' (table CHECK constraint)
    trigger_at: datetime  # moment at or after which the recall is due
    interval_seconds: int | None  # nullable column; used as the repeat period for recurring recalls
    internal_comment: str | None  # optional free-form note stored with the recall
    additional_context_message: str  # text handed to the agent when the recall fires
    status: str  # 'pending', 'fired' or 'cancelled' (table CHECK constraint)
    created_at: datetime  # row creation timestamp
    updated_at: datetime  # timestamp of the last modification


class RecallExistsError(Exception):
    """Raised when a session already has a pending recall.

    ``RecallScheduler.schedule_recall`` raises it when the INSERT is rejected
    as a duplicate; the original database exception is chained as the cause.
    """


class RecallScheduler:
    """PostgreSQL-backed scheduler for session recalls.

    All persistence goes through a lazily created asyncpg connection pool.
    The first call that needs the pool also executes ``_DDL`` so the
    ``session_recalls`` table and its indexes exist before any query runs.
    """

    def __init__(self, dsn: str) -> None:
        """Store the connection string; no connection is opened here."""
        self._dsn = dsn
        self._pool: Any | None = None
        # Serialises pool creation so concurrent first callers build one pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> Any:
        """Return the shared pool, creating it and applying the schema on first use.

        If the schema script fails, the freshly created pool is torn down
        before the error propagates, so a later retry starts clean instead of
        leaking one pool per failed attempt.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check: another task may have finished creation while we waited.
            if self._pool is not None:
                return self._pool
            import asyncpg

            pool = await asyncpg.create_pool(self._dsn)
            try:
                async with pool.acquire() as conn:
                    await conn.execute(_DDL)
            except BaseException:
                # terminate() is synchronous, so it is safe even while the
                # current task is being cancelled.
                pool.terminate()
                raise
            self._pool = pool
        return self._pool

    async def ensure_tables(self) -> None:
        """Make sure the pool exists, which also applies the schema."""
        await self._get_pool()

    @staticmethod
    def _affected_rows(status: str) -> int:
        """Extract the row count from a command status such as ``"UPDATE 1"``.

        Returns 0 when the status does not end in a number.
        """
        count = status.rpartition(" ")[2]
        return int(count) if count.isdigit() else 0

    def _row_to_recall(self, row: Any) -> Recall:
        """Convert a database row (mapping-style access by column name) to a ``Recall``."""
        return Recall(
            id=row["id"],
            session_id=row["session_id"],
            call_type=row["call_type"],
            trigger_at=row["trigger_at"],
            interval_seconds=row["interval_seconds"],
            internal_comment=row["internal_comment"],
            additional_context_message=row["additional_context_message"],
            status=row["status"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    async def schedule_recall(
        self,
        *,
        session_id: str,
        call_type: str,
        trigger_at: datetime,
        interval_seconds: int | None = None,
        internal_comment: str | None = None,
        additional_context_message: str,
    ) -> Recall:
        """Insert a new pending recall and return it.

        Raises:
            RecallExistsError: the insert violated a uniqueness constraint,
                i.e. the session already has a pending recall.
        """
        import asyncpg

        pool = await self._get_pool()
        recall_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            try:
                row = await conn.fetchrow(
                    """
                    INSERT INTO session_recalls (
                        id, session_id, call_type, trigger_at, interval_seconds,
                        internal_comment, additional_context_message, status, created_at, updated_at
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $8)
                    RETURNING *
                    """,
                    recall_id,
                    session_id,
                    call_type,
                    trigger_at,
                    interval_seconds,
                    internal_comment,
                    additional_context_message,
                    now,
                )
            except asyncpg.UniqueViolationError as exc:
                # Match on the exception type rather than on message text, so
                # unrelated errors are never misreported as "already exists".
                raise RecallExistsError(
                    f"Session {session_id} already has a pending recall."
                ) from exc
        return self._row_to_recall(row)

    async def cancel_recall(self, session_id: str) -> bool:
        """Cancel the session's pending recall; return True if a row was changed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE session_id = $2 AND status = 'pending'
                """,
                datetime.now(timezone.utc),
                session_id,
            )
        return self._affected_rows(result) > 0

    async def skip_next_recall(self, session_id: str) -> bool:
        """Push the session's pending recurring recall forward by one interval.

        Rows without an interval are left untouched: adding a NULL interval
        would produce a NULL ``trigger_at`` and violate the NOT NULL constraint.
        Returns True if a row was changed.
        """
        pool = await self._get_pool()
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = trigger_at + (interval_seconds || ' seconds')::interval,
                    updated_at = $1
                WHERE session_id = $2 AND status = 'pending' AND call_type = 'recurring'
                  AND interval_seconds IS NOT NULL
                """,
                now,
                session_id,
            )
        return self._affected_rows(result) > 0

    async def list_recalls(
        self,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Recall]:
        """Return recalls of any status, latest ``trigger_at`` first, with paging.

        Args:
            session_id: when truthy, restrict to this session.
            user_id: when truthy and ``is_admin`` is False, restrict to
                sessions owned by this user.
            is_admin: skip the ownership filter.
            limit: maximum number of rows returned.
            offset: number of rows skipped.
        """
        pool = await self._get_pool()
        params: list[Any] = []
        where_clauses: list[str] = []

        # Only placeholder numbers are interpolated into the SQL text; the
        # values themselves always travel as bound parameters.
        if session_id:
            params.append(session_id)
            where_clauses.append(f"session_id = ${len(params)}")

        # NOTE(review): a non-admin caller without a user_id gets no ownership
        # filter at all — confirm callers always pass user_id for non-admins.
        if not is_admin and user_id:
            params.append(user_id)
            where_clauses.append(
                f"session_id IN (SELECT id FROM sessions WHERE user_id = ${len(params)})"
            )

        where_sql = "WHERE " + " AND ".join(where_clauses) if where_clauses else ""

        params.append(limit)
        limit_idx = len(params)
        params.append(offset)
        offset_idx = len(params)

        query = f"""
            SELECT * FROM session_recalls
            {where_sql}
            ORDER BY trigger_at DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
        """

        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_recall(r) for r in rows]

    async def get_pending_recalls(self, before: datetime) -> list[Recall]:
        """Return pending recalls with ``trigger_at <= before``, earliest first."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM session_recalls
                WHERE status = 'pending' AND trigger_at <= $1
                ORDER BY trigger_at ASC
                """,
                before,
            )
        return [self._row_to_recall(r) for r in rows]

    async def get_next_trigger_at(self) -> datetime | None:
        """Return the earliest ``trigger_at`` among pending recalls, or None if there are none."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT trigger_at FROM session_recalls
                WHERE status = 'pending'
                ORDER BY trigger_at ASC
                LIMIT 1
                """
            )
        return row["trigger_at"] if row else None

    async def mark_fired(self, recall_id: str) -> None:
        """Set the recall's status to 'fired', whatever its current status is."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'fired', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def mark_cancelled(self, recall_id: str) -> None:
        """Set the recall's status to 'cancelled', whatever its current status is."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def reschedule(self, recall_id: str, new_trigger_at: datetime) -> None:
        """Move the recall's ``trigger_at`` to ``new_trigger_at``; status is not changed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = $1, updated_at = $2
                WHERE id = $3
                """,
                new_trigger_at,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def get_pending_session_ids(self, session_ids: list[str]) -> set[str]:
        """Return the subset of session_ids that have a pending recall."""
        if not session_ids:
            # Nothing to look up; avoid a round trip to the database.
            return set()
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT DISTINCT session_id FROM session_recalls
                WHERE status = 'pending' AND session_id = ANY($1)
                """,
                session_ids,
            )
        return {r["session_id"] for r in rows}


async def recall_scheduler_loop(scheduler: RecallScheduler, store: Any) -> None:
    """Background task: poll for due recalls and fire them.

    Each iteration loads every pending recall that is already due, fires
    them concurrently (at most three at a time), then sleeps until the next
    pending trigger. The sleep is never shorter than one second and never
    longer than the poll interval, so a recall scheduled while the loop is
    asleep is noticed within one poll interval rather than only when the
    previously known trigger arrives.

    Args:
        scheduler: recall storage used to query and update recalls.
        store: session store, passed through to ``_fire_recall``.

    Runs until cancelled; ``asyncio.CancelledError`` is re-raised. Any other
    error inside an iteration is logged and the loop retries after a pause.
    """
    await scheduler.ensure_tables()
    semaphore = asyncio.Semaphore(3)
    poll_interval = 60  # seconds; also the upper bound for any single sleep

    while True:
        try:
            now = datetime.now(timezone.utc)
            pending = await scheduler.get_pending_recalls(before=now)

            if pending:
                tasks = [
                    asyncio.create_task(_fire_recall(semaphore, recall, scheduler, store))
                    for recall in pending
                ]
                results = await asyncio.gather(*tasks, return_exceptions=True)
                # One failing recall must not abort the batch, but it must
                # not vanish silently either.
                for recall, outcome in zip(pending, results):
                    if isinstance(outcome, BaseException):
                        log.error(
                            "scheduler.recall_task_error",
                            recall_id=recall.id,
                            exc_info=outcome,
                        )

            next_trigger = await scheduler.get_next_trigger_at()
            if next_trigger:
                remaining = (next_trigger - datetime.now(timezone.utc)).total_seconds()
                sleep_for = min(poll_interval, max(1, remaining))
            else:
                sleep_for = poll_interval
            await asyncio.sleep(sleep_for)

        except asyncio.CancelledError:
            raise
        except Exception:
            log.exception("scheduler.loop_error")
            await asyncio.sleep(60)


async def _fire_recall(
    semaphore: asyncio.Semaphore,
    recall: Recall,
    scheduler: RecallScheduler,
    store: Any,
) -> None:
    """Run one due recall as a headless agent turn and record the outcome.

    Steps, all performed while holding ``semaphore``:

    1. If the session is currently running, push the recall 60 seconds into
       the future and return.
    2. If the session no longer exists, cancel the recall and return.
    3. Append a ``[Scheduled recall]`` system message to the session and
       save it.
    4. Build an ``Agent`` and run it with the recall's context message while
       the session is marked busy.
    5. Reschedule a recurring recall by its interval; otherwise mark the
       recall fired on success or cancelled on failure.

    Args:
        semaphore: limits how many recalls run at the same time.
        recall: the due recall to execute.
        scheduler: used to reschedule or finalise the recall.
        store: session store providing ``get`` and ``save``.
    """
    # Imported at call time rather than module import time (presumably to
    # avoid an import cycle with the API layer — TODO confirm).
    from navi.api.deps import (
        get_backend_registry,
        get_cp_registry,
        get_memory_store,
        get_mcp_manager,
        get_profile_registry,
        get_tool_registry,
        get_workers,
    )
    from navi.api.websocket import is_session_running
    from navi.core.agent import Agent

    async with semaphore:
        # Guard: if user is actively chatting, defer by 60 seconds
        if is_session_running(recall.session_id):
            log.info("scheduler.defer_busy", session_id=recall.session_id)
            await scheduler.reschedule(
                recall.id, datetime.now(timezone.utc) + timedelta(seconds=60)
            )
            return

        session = await store.get(recall.session_id)
        if session is None:
            log.warning("scheduler.session_missing", recall_id=recall.id)
            await scheduler.mark_cancelled(recall.id)
            return

        # Inject system message with recall context
        sys_msg = Message(
            role="system",
            content=f"[Scheduled recall] {recall.additional_context_message}",
        )
        session.messages.append(sys_msg)
        session.context.append(sys_msg)
        await store.save(session)

        # Build agent (same deps pattern as websocket handler)
        tools = get_tool_registry()
        profiles = get_profile_registry()
        backends = get_backend_registry()
        cp_registry = get_cp_registry()
        try:
            mcp_manager = await get_mcp_manager()
        except Exception:
            # Best effort: the run proceeds without MCP, but record why.
            log.warning(
                "scheduler.mcp_unavailable", recall_id=recall.id, exc_info=True
            )
            mcp_manager = None

        agent = Agent(
            store,
            profiles,
            tools,
            backends,
            workers=get_workers(),
            memory_store=get_memory_store(),
            cp_registry=cp_registry,
            mcp_manager=mcp_manager,
        )

        # Block user messages during headless run
        # NOTE(review): the session is marked busy only here, after several
        # awaits that follow the is_session_running() check above; a user
        # message could arrive in between — confirm whether that matters.
        from navi.api.websocket import _busy_sessions

        _busy_sessions.add(recall.session_id)
        try:
            reply = await agent.run(
                session_id=recall.session_id,
                user_message=recall.additional_context_message,
            )
            log.info(
                "scheduler.recall_fired",
                recall_id=recall.id,
                session_id=recall.session_id,
                reply_len=len(reply) if reply else 0,
            )
            await _settle_recall(scheduler, recall, succeeded=True)
        except Exception:
            log.exception("scheduler.recall_failed", recall_id=recall.id)
            await _settle_recall(scheduler, recall, succeeded=False)
        finally:
            _busy_sessions.discard(recall.session_id)


async def _settle_recall(
    scheduler: RecallScheduler, recall: Recall, *, succeeded: bool
) -> None:
    """Persist the outcome of a run: reschedule a recurring recall or finalise it.

    A recurring recall is moved forward by its interval whether the run
    succeeded or failed. A recurring recall without a positive interval
    cannot be rescheduled meaningfully — moving it to "now" would make it
    fire again on every loop iteration — so it is finalised like a one-shot
    recall and a warning is logged.
    """
    interval = recall.interval_seconds or 0
    if recall.call_type == "recurring" and interval > 0:
        next_trigger = datetime.now(timezone.utc) + timedelta(seconds=interval)
        await scheduler.reschedule(recall.id, next_trigger)
        return

    if recall.call_type == "recurring":
        log.warning(
            "scheduler.recurring_without_interval",
            recall_id=recall.id,
            interval_seconds=recall.interval_seconds,
        )

    if succeeded:
        await scheduler.mark_fired(recall.id)
    else:
        await scheduler.mark_cancelled(recall.id)