Newer
Older
navi-1 / navi / core / scheduler.py
"""Scheduled recall system — background scheduler that fires headless agent runs."""

import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any

import structlog

from navi.llm.base import Message

log = structlog.get_logger()


async def _publish_recall_update(
    session_id: str,
    recall_id: str | None = None,
    call_type: str | None = None,
    trigger_at: str | None = None,
    status: str | None = None,
    action: str | None = None,
) -> None:
    """Publish a ``RecallUpdate`` event for *session_id* on the event bus.

    Every argument is forwarded unchanged as the same-named field of the
    event; only ``session_id`` is required.
    """
    # Resolved at call time rather than at module import.
    # NOTE(review): presumably this avoids an import cycle with navi.core — confirm.
    from navi.core.event_bus import get_event_bus
    from navi.core.events import RecallUpdate

    bus = get_event_bus()
    event = RecallUpdate(
        session_id=session_id,
        recall_id=recall_id,
        call_type=call_type,
        trigger_at=trigger_at,
        status=status,
        action=action,
    )
    await bus.publish(event)

# Schema for the recall table. RecallScheduler._get_pool executes this script
# once per scheduler instance before the first query; every statement uses
# IF NOT EXISTS, so running it against an existing database is a no-op.
#
# Notable constraints:
#   * call_type is limited to 'once', 'recurring' or 'immediate'.
#   * status is limited to 'pending', 'fired' or 'cancelled' (default 'pending').
#   * idx_recalls_active_pending is a partial UNIQUE index on session_id for
#     rows with status = 'pending', i.e. a session can have at most one
#     pending recall at a time.
#   * Rows are removed automatically when the owning sessions row is deleted
#     (ON DELETE CASCADE).
_DDL = """
CREATE TABLE IF NOT EXISTS session_recalls (
    id                  TEXT PRIMARY KEY,
    session_id          TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    call_type           TEXT NOT NULL CHECK (call_type IN ('once', 'recurring', 'immediate')),
    trigger_at          TIMESTAMPTZ NOT NULL,
    interval_seconds    INTEGER,
    internal_comment    TEXT,
    additional_context_message TEXT NOT NULL,
    status              TEXT NOT NULL DEFAULT 'pending'
                        CHECK (status IN ('pending', 'fired', 'cancelled')),
    created_at          TIMESTAMPTZ NOT NULL,
    updated_at          TIMESTAMPTZ NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_recalls_trigger_at_status
    ON session_recalls (trigger_at, status);

CREATE INDEX IF NOT EXISTS idx_recalls_session_status
    ON session_recalls (session_id, status);

CREATE UNIQUE INDEX IF NOT EXISTS idx_recalls_active_pending
    ON session_recalls (session_id) WHERE status = 'pending';
"""


@dataclass(frozen=True, slots=True)
class Recall:
    """Immutable in-memory copy of one ``session_recalls`` row.

    Field names match the column names of the table created by ``_DDL``;
    ``RecallScheduler._row_to_recall`` builds instances by copying each column.
    """

    # Primary key; RecallScheduler.schedule_recall generates it with uuid4.
    id: str
    # Session the recall belongs to (foreign key to sessions.id).
    session_id: str
    # One of 'once', 'recurring' or 'immediate' (enforced by a CHECK constraint).
    call_type: str
    # When the recall becomes due; get_pending_recalls selects rows whose
    # trigger_at is at or before the supplied cut-off.
    trigger_at: datetime
    # Nullable. skip_next_recall adds this many seconds to trigger_at for
    # 'recurring' recalls.
    interval_seconds: int | None
    # Nullable free text; how it is used is not visible in this module.
    internal_comment: str | None
    # Required text stored with the recall.
    # NOTE(review): presumably handed to the agent run when the recall fires — confirm.
    additional_context_message: str
    # One of 'pending', 'fired' or 'cancelled' (enforced by a CHECK constraint).
    status: str
    # Creation / last-modification timestamps written by RecallScheduler.
    created_at: datetime
    updated_at: datetime


class RecallExistsError(Exception):
    """Raised when a session already has a pending recall.

    ``RecallScheduler.schedule_recall`` raises this when its INSERT is rejected
    as a unique/duplicate violation; the original database exception is
    attached as ``__cause__``.
    """


class RecallScheduler:
    """PostgreSQL-backed scheduler for session recalls.

    Wraps a connection pool (any object whose ``acquire()`` returns an async
    context manager yielding a connection with ``execute``, ``fetch`` and
    ``fetchrow``) and persists recalls in the ``session_recalls`` table. The
    schema in ``_DDL`` is applied lazily, once per instance, on first use.
    """

    def __init__(self, pool: Any) -> None:
        """Store *pool*; no database work happens until the first query."""
        self._pool = pool
        self._initialized = False
        # Serializes the one-time schema creation across concurrent callers.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> Any:
        """Return the pool, executing ``_DDL`` first if this instance has not yet."""
        if not self._initialized:
            async with self._lock:
                # Re-check under the lock: another task may have completed the
                # schema setup while this one was waiting.
                if not self._initialized:
                    async with self._pool.acquire() as conn:
                        await conn.execute(_DDL)
                    self._initialized = True
        return self._pool

    async def ensure_tables(self) -> None:
        """Make sure the schema exists (idempotent)."""
        await self._get_pool()

    @staticmethod
    def _rows_affected(command_tag: Any) -> int:
        """Extract the row count from a command tag such as ``"UPDATE 1"``.

        Returns 0 when the tag does not end in an integer.
        """
        try:
            return int(command_tag.rsplit(" ", 1)[-1])
        except (AttributeError, ValueError):
            return 0

    @staticmethod
    def _is_unique_violation(exc: BaseException) -> bool:
        """Tell whether *exc* reports a unique-constraint violation.

        When the exception carries a ``sqlstate`` it is compared with 23505
        (PostgreSQL ``unique_violation``); otherwise the message text is
        inspected, so drivers without SQLSTATE information keep working.
        """
        sqlstate = getattr(exc, "sqlstate", None)
        if sqlstate is not None:
            return sqlstate == "23505"
        message = str(exc).lower()
        return "unique" in message or "duplicate" in message

    def _row_to_recall(self, row: Any) -> Recall:
        """Copy the columns of a ``session_recalls`` row into a ``Recall``."""
        return Recall(
            id=row["id"],
            session_id=row["session_id"],
            call_type=row["call_type"],
            trigger_at=row["trigger_at"],
            interval_seconds=row["interval_seconds"],
            internal_comment=row["internal_comment"],
            additional_context_message=row["additional_context_message"],
            status=row["status"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    async def schedule_recall(
        self,
        *,
        session_id: str,
        call_type: str,
        trigger_at: datetime,
        interval_seconds: int | None = None,
        internal_comment: str | None = None,
        additional_context_message: str,
    ) -> Recall:
        """Insert a new pending recall and return it.

        The row gets a fresh uuid4 id and identical ``created_at`` /
        ``updated_at`` timestamps (current UTC time).

        Raises:
            RecallExistsError: The insert hit a unique violation, i.e. the
                session already has a pending recall.
        """
        pool = await self._get_pool()
        recall_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            try:
                row = await conn.fetchrow(
                    """
                    INSERT INTO session_recalls (
                        id, session_id, call_type, trigger_at, interval_seconds,
                        internal_comment, additional_context_message, status, created_at, updated_at
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $8)
                    RETURNING *
                    """,
                    recall_id,
                    session_id,
                    call_type,
                    trigger_at,
                    interval_seconds,
                    internal_comment,
                    additional_context_message,
                    now,
                )
            except Exception as exc:
                # Only the uniqueness failure is translated; everything else
                # (bad call_type, unknown session, connection loss, ...)
                # propagates unchanged.
                if self._is_unique_violation(exc):
                    raise RecallExistsError(
                        f"Session {session_id} already has a pending recall."
                    ) from exc
                raise
        return self._row_to_recall(row)

    async def cancel_recall(self, session_id: str) -> bool:
        """Cancel the pending recall of *session_id*.

        Returns:
            True if at least one row was switched to 'cancelled', else False.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE session_id = $2 AND status = 'pending'
                """,
                datetime.now(timezone.utc),
                session_id,
            )
        # execute() returns the command tag, e.g. "UPDATE 1".
        return self._rows_affected(result) > 0

    async def skip_next_recall(self, session_id: str) -> bool:
        """Push the pending recurring recall of *session_id* one interval ahead.

        Returns:
            True if a pending 'recurring' recall was moved, else False.
        """
        pool = await self._get_pool()
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            # GREATEST(trigger_at, now) makes an already-overdue recall skip
            # relative to the current time instead of its stale trigger time.
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = GREATEST(trigger_at, $1) + (interval_seconds || ' seconds')::interval,
                    updated_at = $1
                WHERE session_id = $2 AND status = 'pending' AND call_type = 'recurring'
                """,
                now,
                session_id,
            )
        return self._rows_affected(result) > 0

    async def list_recalls(
        self,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Recall]:
        """Return recalls ordered by ``trigger_at`` descending, one page at a time.

        Args:
            session_id: When truthy, only recalls of this session are returned.
            user_id: When truthy and ``is_admin`` is false, only recalls whose
                session belongs to this user are returned.
            is_admin: When true, the ownership filter is skipped.
            limit: Maximum number of rows.
            offset: Number of rows to skip.
        """
        pool = await self._get_pool()
        params: list[Any] = []
        where_clauses: list[str] = []

        # Placeholders are numbered from the current length of ``params`` so
        # the optional filters can be combined freely; values are never
        # interpolated into the SQL text.
        if session_id:
            params.append(session_id)
            where_clauses.append(f"session_id = ${len(params)}")

        # NOTE(review): a non-admin call without a user_id applies no ownership
        # filter at all. Callers must pass user_id for non-admin requests;
        # confirm none relies on the unfiltered result before tightening this.
        if not is_admin and user_id:
            params.append(user_id)
            where_clauses.append(
                f"session_id IN (SELECT id FROM sessions WHERE user_id = ${len(params)})"
            )

        where_sql = ("WHERE " + " AND ".join(where_clauses)) if where_clauses else ""

        params.append(limit)
        limit_idx = len(params)
        params.append(offset)
        offset_idx = len(params)

        query = f"""
            SELECT * FROM session_recalls
            {where_sql}
            ORDER BY trigger_at DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
        """

        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_recall(r) for r in rows]

    async def get_pending_recalls(self, before: datetime) -> list[Recall]:
        """Return pending recalls with ``trigger_at <= before``, earliest first."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM session_recalls
                WHERE status = 'pending' AND trigger_at <= $1
                ORDER BY trigger_at ASC
                """,
                before,
            )
        return [self._row_to_recall(r) for r in rows]

    async def get_next_trigger_at(self) -> datetime | None:
        """Return the earliest ``trigger_at`` among pending recalls, or None."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT trigger_at FROM session_recalls
                WHERE status = 'pending'
                ORDER BY trigger_at ASC
                LIMIT 1
                """
            )
        return row["trigger_at"] if row else None

    async def mark_fired(self, recall_id: str) -> None:
        """Set the status of *recall_id* to 'fired', whatever it was before."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'fired', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def mark_cancelled(self, recall_id: str) -> None:
        """Set the status of *recall_id* to 'cancelled', whatever it was before."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def reschedule(self, recall_id: str, new_trigger_at: datetime) -> None:
        """Move *recall_id* to *new_trigger_at*; the status is left untouched."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = $1, updated_at = $2
                WHERE id = $3
                """,
                new_trigger_at,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def get_pending_session_ids(self, session_ids: list[str]) -> set[str]:
        """Return the subset of session_ids that have a pending recall."""
        if not session_ids:
            # Nothing to look up; skip the round trip.
            return set()
        pool = await self._get_pool()

        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT DISTINCT session_id FROM session_recalls
                WHERE status = 'pending' AND session_id = ANY($1)
                """,
                session_ids,
            )
        return {r["session_id"] for r in rows}


async def recall_scheduler_loop(scheduler: RecallScheduler, store: Any, orchestrator) -> None:
    """Background task: poll for due recalls and fire them.

    Runs until cancelled. Each pass fetches every pending recall that is due,
    runs them through ``_fire_recall`` with at most three in flight, then
    sleeps until the next known trigger time — but never longer than the
    idle poll interval, so recalls scheduled while the loop is asleep are
    still picked up promptly.

    Args:
        scheduler: Source of due recalls; also handed to the orchestrator.
        store: Passed through to the orchestrator untouched.
        orchestrator: Object whose ``run_recall`` executes a recall.

    Raises:
        asyncio.CancelledError: Re-raised so the task can be shut down.
    """
    await scheduler.ensure_tables()
    semaphore = asyncio.Semaphore(3)
    poll_seconds = 60

    while True:
        try:
            now = datetime.now(timezone.utc)
            pending = await scheduler.get_pending_recalls(before=now)

            if pending:
                tasks = [
                    asyncio.create_task(_fire_recall(semaphore, recall, scheduler, store, orchestrator))
                    for recall in pending
                ]
                # return_exceptions=True keeps one failing recall from
                # aborting the rest of the batch; the failures themselves are
                # reported below instead of being dropped.
                outcomes = await asyncio.gather(*tasks, return_exceptions=True)
                for recall, outcome in zip(pending, outcomes):
                    if isinstance(outcome, BaseException):
                        log.error(
                            "scheduler.recall_failed",
                            recall_id=recall.id,
                            session_id=recall.session_id,
                            exc_info=outcome,
                        )

            next_trigger = await scheduler.get_next_trigger_at()
            if next_trigger:
                until_due = (next_trigger - datetime.now(timezone.utc)).total_seconds()
                # Floor of 1s avoids a busy loop when a recall is already due;
                # the cap bounds how late a recall inserted during the sleep
                # can fire.
                sleep_for = min(poll_seconds, max(1, until_due))
            else:
                sleep_for = poll_seconds
            await asyncio.sleep(sleep_for)

        except asyncio.CancelledError:
            raise
        except Exception:
            # Top-level boundary: log and keep the scheduler alive.
            log.exception("scheduler.loop_error")
            await asyncio.sleep(poll_seconds)


async def _fire_recall(
    semaphore: asyncio.Semaphore,
    recall: Recall,
    scheduler: RecallScheduler,
    store: Any,
    orchestrator: Any,
) -> None:
    """Run one recall through the orchestrator while holding *semaphore*.

    Exceptions raised by ``orchestrator.run_recall`` propagate to the caller;
    the semaphore slot is released either way.
    """
    # The slot is held for the entire run, so the semaphore's size bounds how
    # many recalls execute concurrently.
    async with semaphore:
        await orchestrator.run_recall(recall, scheduler, store)