"""Scheduled recall system — background scheduler that fires headless agent runs."""
import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
import structlog
from navi.llm.base import Message
# Module-level structured logger shared by the scheduler loop and the recall runner.
log = structlog.get_logger()
# Schema for the recall table and its indexes. Every statement is guarded by
# IF NOT EXISTS, and RecallScheduler._get_pool() executes this string each
# time it creates a pool.
#
# Points worth knowing when reading the queries in this module:
#   * rows are removed together with their session (ON DELETE CASCADE);
#   * call_type and status are restricted by CHECK constraints;
#   * the partial unique index idx_recalls_active_pending permits at most one
#     row with status = 'pending' per session_id.
_DDL = """
CREATE TABLE IF NOT EXISTS session_recalls (
id TEXT PRIMARY KEY,
session_id TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
call_type TEXT NOT NULL CHECK (call_type IN ('once', 'recurring', 'immediate')),
trigger_at TIMESTAMPTZ NOT NULL,
interval_seconds INTEGER,
internal_comment TEXT,
additional_context_message TEXT NOT NULL,
status TEXT NOT NULL DEFAULT 'pending'
CHECK (status IN ('pending', 'fired', 'cancelled')),
created_at TIMESTAMPTZ NOT NULL,
updated_at TIMESTAMPTZ NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_recalls_trigger_at_status
ON session_recalls (trigger_at, status);
CREATE INDEX IF NOT EXISTS idx_recalls_session_status
ON session_recalls (session_id, status);
CREATE UNIQUE INDEX IF NOT EXISTS idx_recalls_active_pending
ON session_recalls (session_id) WHERE status = 'pending';
"""
@dataclass(frozen=True, slots=True)
class Recall:
id: str
session_id: str
call_type: str
trigger_at: datetime
interval_seconds: int | None
internal_comment: str | None
additional_context_message: str
status: str
created_at: datetime
updated_at: datetime
class RecallExistsError(Exception):
    """Signals that scheduling failed because the session already has a pending recall."""
class RecallScheduler:
    """PostgreSQL-backed scheduler for session recalls.

    The asyncpg pool is created lazily on first use, and the schema in
    ``_DDL`` is applied at that moment. All timestamps written by this class
    are timezone-aware UTC values.
    """

    def __init__(self, dsn: str) -> None:
        """Remember the DSN; no connection is opened until first use."""
        self._dsn = dsn
        self._pool: Any | None = None
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> Any:
        """Return the shared pool, creating it and applying ``_DDL`` on first call."""
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished setup
            # while this one was waiting.
            if self._pool is None:
                import asyncpg

                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        await conn.execute(_DDL)
                except BaseException:
                    # Do not leak the connections of a pool we will never publish.
                    pool.terminate()
                    raise
                self._pool = pool
            return self._pool

    async def ensure_tables(self) -> None:
        """Make sure the pool exists, which also applies the schema."""
        await self._get_pool()

    @staticmethod
    def _affected_rows(command_tag: str) -> int:
        """Extract the row count from a command tag such as ``"UPDATE 1"``."""
        try:
            return int(command_tag.rsplit(" ", 1)[-1])
        except ValueError:
            return 0

    def _row_to_recall(self, row: Any) -> Recall:
        """Convert a fetched row (mapping-style access by column name) to a ``Recall``."""
        return Recall(
            id=row["id"],
            session_id=row["session_id"],
            call_type=row["call_type"],
            trigger_at=row["trigger_at"],
            interval_seconds=row["interval_seconds"],
            internal_comment=row["internal_comment"],
            additional_context_message=row["additional_context_message"],
            status=row["status"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    async def schedule_recall(
        self,
        *,
        session_id: str,
        call_type: str,
        trigger_at: datetime,
        interval_seconds: int | None = None,
        internal_comment: str | None = None,
        additional_context_message: str,
    ) -> Recall:
        """Insert a new pending recall and return it.

        Raises:
            RecallExistsError: the insert hit a unique constraint, which for
                this table means the session already has a pending recall.
        """
        import asyncpg

        pool = await self._get_pool()
        recall_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            try:
                row = await conn.fetchrow(
                    """
                    INSERT INTO session_recalls (
                        id, session_id, call_type, trigger_at, interval_seconds,
                        internal_comment, additional_context_message, status,
                        created_at, updated_at
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $8)
                    RETURNING *
                    """,
                    recall_id,
                    session_id,
                    call_type,
                    trigger_at,
                    interval_seconds,
                    internal_comment,
                    additional_context_message,
                    now,
                )
            except asyncpg.UniqueViolationError as exc:
                # Match on the driver's exception class rather than on words
                # in the error text, which may be localised or contain
                # caller-supplied values.
                raise RecallExistsError(
                    f"Session {session_id} already has a pending recall."
                ) from exc
        return self._row_to_recall(row)

    async def cancel_recall(self, session_id: str) -> bool:
        """Cancel the session's pending recall; return True if a row was updated."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE session_id = $2 AND status = 'pending'
                """,
                datetime.now(timezone.utc),
                session_id,
            )
        return self._affected_rows(result) > 0

    async def skip_next_recall(self, session_id: str) -> bool:
        """Push a pending recurring recall forward by one interval.

        Returns True if a row was updated. Rows without an interval are left
        untouched: adding a NULL interval would produce a NULL ``trigger_at``
        and violate the column's NOT NULL constraint.
        """
        pool = await self._get_pool()
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = trigger_at + (interval_seconds || ' seconds')::interval,
                    updated_at = $1
                WHERE session_id = $2
                  AND status = 'pending'
                  AND call_type = 'recurring'
                  AND interval_seconds IS NOT NULL
                """,
                now,
                session_id,
            )
        return self._affected_rows(result) > 0

    async def list_recalls(
        self,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Recall]:
        """Return one page of recalls, latest ``trigger_at`` first.

        ``session_id`` narrows the result to one session. For non-admin
        callers that pass a ``user_id``, only recalls whose session belongs
        to that user are returned.
        """
        pool = await self._get_pool()
        params: list[Any] = []
        where_clauses: list[str] = []
        if session_id:
            params.append(session_id)
            where_clauses.append(f"session_id = ${len(params)}")
        # NOTE(review): a non-admin caller that passes no user_id gets no
        # ownership filter at all — confirm this is intended.
        if not is_admin and user_id:
            params.append(user_id)
            where_clauses.append(
                f"session_id IN (SELECT id FROM sessions WHERE user_id = ${len(params)})"
            )
        where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""
        params.append(limit)
        limit_idx = len(params)
        params.append(offset)
        offset_idx = len(params)
        # Only placeholder numbers are interpolated; every value is bound.
        # The id tiebreaker keeps LIMIT/OFFSET pages stable when several rows
        # share the same trigger_at.
        query = f"""
            SELECT * FROM session_recalls
            {where_sql}
            ORDER BY trigger_at DESC, id DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
        """
        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_recall(r) for r in rows]

    async def get_pending_recalls(self, before: datetime) -> list[Recall]:
        """Return pending recalls with ``trigger_at <= before``, earliest first."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM session_recalls
                WHERE status = 'pending' AND trigger_at <= $1
                ORDER BY trigger_at ASC
                """,
                before,
            )
        return [self._row_to_recall(r) for r in rows]

    async def get_next_trigger_at(self) -> datetime | None:
        """Return the earliest ``trigger_at`` among pending recalls, or None."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT trigger_at FROM session_recalls
                WHERE status = 'pending'
                ORDER BY trigger_at ASC
                LIMIT 1
                """
            )
        return row["trigger_at"] if row else None

    async def _set_status(self, recall_id: str, status: str) -> None:
        """Set ``status`` and refresh ``updated_at`` for one recall id."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = $1, updated_at = $2
                WHERE id = $3
                """,
                status,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def mark_fired(self, recall_id: str) -> None:
        """Mark the recall as fired."""
        await self._set_status(recall_id, "fired")

    async def mark_cancelled(self, recall_id: str) -> None:
        """Mark the recall as cancelled."""
        await self._set_status(recall_id, "cancelled")

    async def reschedule(self, recall_id: str, new_trigger_at: datetime) -> None:
        """Move the recall's ``trigger_at``; its status is not changed."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = $1, updated_at = $2
                WHERE id = $3
                """,
                new_trigger_at,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def get_pending_session_ids(self, session_ids: list[str]) -> set[str]:
        """Return the subset of session_ids that have a pending recall."""
        if not session_ids:
            return set()
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT DISTINCT session_id FROM session_recalls
                WHERE status = 'pending' AND session_id = ANY($1)
                """,
                session_ids,
            )
        return {r["session_id"] for r in rows}
async def recall_scheduler_loop(
    scheduler: RecallScheduler,
    store: Any,
    *,
    poll_interval: float = 60.0,
    max_concurrency: int = 3,
) -> None:
    """Background task: poll for due recalls and fire them.

    Runs until cancelled. Each iteration fires every recall that is due,
    then sleeps until the next known trigger, but never longer than
    ``poll_interval`` seconds.

    Args:
        scheduler: Source of due recalls; also passed to each recall run.
        store: Session store handed through to each recall run.
        poll_interval: Upper bound on one sleep, and the back-off after an
            unexpected error in the loop itself.
        max_concurrency: Maximum number of recalls running at once.
    """
    await scheduler.ensure_tables()
    semaphore = asyncio.Semaphore(max_concurrency)
    while True:
        try:
            pending = await scheduler.get_pending_recalls(
                before=datetime.now(timezone.utc)
            )
            if pending:
                results = await asyncio.gather(
                    *(
                        _fire_recall(semaphore, recall, scheduler, store)
                        for recall in pending
                    ),
                    return_exceptions=True,
                )
                # gather() hands failures back as values; surface them rather
                # than dropping them on the floor.
                for recall, result in zip(pending, results):
                    if isinstance(result, BaseException):
                        log.error(
                            "scheduler.fire_error",
                            recall_id=recall.id,
                            session_id=recall.session_id,
                            exc_info=result,
                        )
            next_trigger = await scheduler.get_next_trigger_at()
            sleep_for = poll_interval
            if next_trigger is not None:
                until_due = (next_trigger - datetime.now(timezone.utc)).total_seconds()
                # Cap the sleep: a recall scheduled while we are asleep may be
                # due before the trigger we currently know about.
                sleep_for = min(poll_interval, max(1.0, until_due))
            await asyncio.sleep(sleep_for)
        except asyncio.CancelledError:
            raise
        except Exception:
            log.exception("scheduler.loop_error")
            await asyncio.sleep(poll_interval)
async def _fire_recall(
    semaphore: asyncio.Semaphore,
    recall: Recall,
    scheduler: RecallScheduler,
    store: Any,
) -> None:
    """Run one due recall as a headless agent turn.

    Steps, all performed while holding ``semaphore``:
      1. If the session is currently running, push the recall back 60 seconds.
      2. Mark the session busy for the rest of the run.
      3. Load the session; cancel the recall if the session no longer exists.
      4. Append a system message carrying the recall context and save.
      5. Build an ``Agent`` and run it with the recall message.
      6. Reschedule (recurring) or close out (everything else) the recall.
    """
    from navi.api.deps import (
        get_backend_registry,
        get_cp_registry,
        get_memory_store,
        get_mcp_manager,
        get_profile_registry,
        get_tool_registry,
        get_workers,
    )
    from navi.api.websocket import _busy_sessions, is_session_running
    from navi.core.agent import Agent

    async with semaphore:
        # Guard: if user is actively chatting, defer by 60 seconds
        if is_session_running(recall.session_id):
            log.info("scheduler.defer_busy", session_id=recall.session_id)
            await scheduler.reschedule(
                recall.id, datetime.now(timezone.utc) + timedelta(seconds=60)
            )
            return

        # Mark the session busy right after the check, with no await in
        # between, so that the session is not modified and saved below while
        # it is still unmarked.
        _busy_sessions.add(recall.session_id)
        try:
            session = await store.get(recall.session_id)
            if session is None:
                log.warning("scheduler.session_missing", recall_id=recall.id)
                await scheduler.mark_cancelled(recall.id)
                return

            # Inject system message with recall context
            sys_msg = Message(
                role="system",
                content=f"[Scheduled recall] {recall.additional_context_message}",
            )
            session.messages.append(sys_msg)
            session.context.append(sys_msg)
            await store.save(session)

            # Build agent (same deps pattern as websocket handler)
            tools = get_tool_registry()
            profiles = get_profile_registry()
            backends = get_backend_registry()
            cp_registry = get_cp_registry()
            try:
                mcp_manager = await get_mcp_manager()
            except Exception:
                # Best effort: run without MCP, but leave a trace of why.
                log.warning(
                    "scheduler.mcp_unavailable", recall_id=recall.id, exc_info=True
                )
                mcp_manager = None
            agent = Agent(
                store,
                profiles,
                tools,
                backends,
                workers=get_workers(),
                memory_store=get_memory_store(),
                cp_registry=cp_registry,
                mcp_manager=mcp_manager,
            )

            try:
                reply = await agent.run(
                    session_id=recall.session_id,
                    user_message=recall.additional_context_message,
                )
                log.info(
                    "scheduler.recall_fired",
                    recall_id=recall.id,
                    session_id=recall.session_id,
                    reply_len=len(reply) if reply else 0,
                )
                await _settle_recall(scheduler, recall, succeeded=True)
            except Exception:
                log.exception("scheduler.recall_failed", recall_id=recall.id)
                await _settle_recall(scheduler, recall, succeeded=False)
        finally:
            _busy_sessions.discard(recall.session_id)


async def _settle_recall(
    scheduler: RecallScheduler,
    recall: Recall,
    *,
    succeeded: bool,
) -> None:
    """Reschedule a recurring recall, or close a finished one as fired/cancelled."""
    interval = recall.interval_seconds
    if recall.call_type == "recurring":
        if interval is not None and interval > 0:
            await scheduler.reschedule(
                recall.id, datetime.now(timezone.utc) + timedelta(seconds=interval)
            )
            return
        # Without a positive interval the next trigger would be "now" and the
        # recall would fire again on every loop pass; close it out instead.
        log.warning(
            "scheduler.recurring_without_interval",
            recall_id=recall.id,
            interval_seconds=interval,
        )
    if succeeded:
        await scheduler.mark_fired(recall.id)
    else:
        await scheduler.mark_cancelled(recall.id)