"""Scheduled recall system — background scheduler that fires headless agent runs."""

import asyncio
import uuid
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any

import structlog

from navi.llm.base import Message

log = structlog.get_logger()


async def _publish_recall_update(
    session_id: str,
    recall_id: str | None = None,
    call_type: str | None = None,
    trigger_at: str | None = None,
    status: str | None = None,
    action: str | None = None,
) -> None:
    """Broadcast a ``RecallUpdate`` event for *session_id* on the global event bus.

    Every field except ``session_id`` is optional and forwarded unchanged to
    the event. ``trigger_at`` is expected as an ISO-8601 string, as produced
    by the callers in this module via ``datetime.isoformat()``.
    """
    # Deferred imports, matching the rest of this module's lazy-import style.
    from navi.core.event_bus import get_event_bus
    from navi.core.events import RecallUpdate

    bus = get_event_bus()
    update = RecallUpdate(
        session_id=session_id,
        recall_id=recall_id,
        call_type=call_type,
        trigger_at=trigger_at,
        status=status,
        action=action,
    )
    await bus.publish(update)

# Schema for the session_recalls table. RecallScheduler._get_pool executes it
# once, right after creating its connection pool; every statement uses
# IF NOT EXISTS, so re-running it against an existing database is harmless.
# - ON DELETE CASCADE removes a session's recalls when the session row is deleted.
# - idx_recalls_trigger_at_status / idx_recalls_session_status are presumably
#   meant for the "due recalls" poll and per-session lookups respectively.
# - idx_recalls_active_pending is a partial UNIQUE index: at most one 'pending'
#   recall may exist per session; RecallScheduler.schedule_recall reports a
#   violation of it as RecallExistsError.
_DDL = """
CREATE TABLE IF NOT EXISTS session_recalls (
    id                  TEXT PRIMARY KEY,
    session_id          TEXT NOT NULL REFERENCES sessions(id) ON DELETE CASCADE,
    call_type           TEXT NOT NULL CHECK (call_type IN ('once', 'recurring', 'immediate')),
    trigger_at          TIMESTAMPTZ NOT NULL,
    interval_seconds    INTEGER,
    internal_comment    TEXT,
    additional_context_message TEXT NOT NULL,
    status              TEXT NOT NULL DEFAULT 'pending'
                        CHECK (status IN ('pending', 'fired', 'cancelled')),
    created_at          TIMESTAMPTZ NOT NULL,
    updated_at          TIMESTAMPTZ NOT NULL
);

CREATE INDEX IF NOT EXISTS idx_recalls_trigger_at_status
    ON session_recalls (trigger_at, status);

CREATE INDEX IF NOT EXISTS idx_recalls_session_status
    ON session_recalls (session_id, status);

CREATE UNIQUE INDEX IF NOT EXISTS idx_recalls_active_pending
    ON session_recalls (session_id) WHERE status = 'pending';
"""


@dataclass(frozen=True, slots=True)
class Recall:
    id: str
    session_id: str
    call_type: str
    trigger_at: datetime
    interval_seconds: int | None
    internal_comment: str | None
    additional_context_message: str
    status: str
    created_at: datetime
    updated_at: datetime


class RecallExistsError(Exception):
    """A session already has a pending recall, so another cannot be scheduled.

    Each session may hold at most one recall in the ``pending`` state; this
    error signals an attempt to schedule a second one.
    """


class RecallScheduler:
    """PostgreSQL-backed scheduler for session recalls.

    Recalls live in the ``session_recalls`` table and are accessed through a
    lazily created asyncpg pool. The schema (``_DDL``) is applied the first
    time the pool is created. A partial unique index allows at most one
    ``pending`` recall per session.
    """

    def __init__(self, dsn: str) -> None:
        self._dsn = dsn
        self._pool: Any | None = None
        # Guards first-time pool creation so concurrent callers share one pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> Any:
        """Return the shared pool, creating it and applying the schema on first use.

        If applying the schema fails, the freshly created pool is terminated
        (so its connections are not leaked) and the error propagates; the
        next call will try again.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check: another coroutine may have created the pool while we waited.
            if self._pool is not None:
                return self._pool
            import asyncpg

            pool = await asyncpg.create_pool(self._dsn)
            try:
                async with pool.acquire() as conn:
                    await conn.execute(_DDL)
            except BaseException:
                # terminate() is synchronous, so cleanup also happens on cancellation.
                pool.terminate()
                raise
            self._pool = pool
        return self._pool

    async def ensure_tables(self) -> None:
        """Create the pool (and thereby apply the schema) if that has not happened yet."""
        await self._get_pool()

    @staticmethod
    def _affected_rows(status: str) -> int:
        """Return the row count from an asyncpg command tag such as ``"UPDATE 1"``.

        Returns 0 when the tag does not end in an integer.
        """
        parts = status.split()
        try:
            return int(parts[-1])
        except (IndexError, ValueError):
            return 0

    def _row_to_recall(self, row: Any) -> Recall:
        """Convert a ``session_recalls`` record into a :class:`Recall`."""
        return Recall(
            id=row["id"],
            session_id=row["session_id"],
            call_type=row["call_type"],
            trigger_at=row["trigger_at"],
            interval_seconds=row["interval_seconds"],
            internal_comment=row["internal_comment"],
            additional_context_message=row["additional_context_message"],
            status=row["status"],
            created_at=row["created_at"],
            updated_at=row["updated_at"],
        )

    async def schedule_recall(
        self,
        *,
        session_id: str,
        call_type: str,
        trigger_at: datetime,
        interval_seconds: int | None = None,
        internal_comment: str | None = None,
        additional_context_message: str,
    ) -> Recall:
        """Insert a new ``pending`` recall for *session_id* and return it.

        Raises:
            RecallExistsError: the session already has a pending recall (the
                insert violated a unique constraint).
        """
        import asyncpg

        pool = await self._get_pool()
        recall_id = str(uuid.uuid4())
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            try:
                row = await conn.fetchrow(
                    """
                    INSERT INTO session_recalls (
                        id, session_id, call_type, trigger_at, interval_seconds,
                        internal_comment, additional_context_message, status, created_at, updated_at
                    ) VALUES ($1, $2, $3, $4, $5, $6, $7, 'pending', $8, $8)
                    RETURNING *
                    """,
                    recall_id,
                    session_id,
                    call_type,
                    trigger_at,
                    interval_seconds,
                    internal_comment,
                    additional_context_message,
                    now,
                )
            except asyncpg.exceptions.UniqueViolationError as exc:
                # The id is a fresh uuid4, so in practice this is the partial
                # unique index on (session_id) WHERE status = 'pending'.
                raise RecallExistsError(
                    f"Session {session_id} already has a pending recall."
                ) from exc
        return self._row_to_recall(row)

    async def cancel_recall(self, session_id: str) -> bool:
        """Cancel the session's pending recall; return True if one was cancelled."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE session_id = $2 AND status = 'pending'
                """,
                datetime.now(timezone.utc),
                session_id,
            )
        return self._affected_rows(result) > 0

    async def skip_next_recall(self, session_id: str) -> bool:
        """Skip one occurrence of the session's pending recurring recall.

        The new trigger time is one ``interval_seconds`` after the later of
        the current trigger time and now. Returns True if a recall was updated.
        """
        pool = await self._get_pool()
        now = datetime.now(timezone.utc)
        async with pool.acquire() as conn:
            result = await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = GREATEST(trigger_at, $1) + (interval_seconds || ' seconds')::interval,
                    updated_at = $1
                WHERE session_id = $2 AND status = 'pending' AND call_type = 'recurring'
                """,
                now,
                session_id,
            )
        return self._affected_rows(result) > 0

    async def list_recalls(
        self,
        *,
        session_id: str | None = None,
        user_id: str | None = None,
        is_admin: bool = False,
        limit: int = 50,
        offset: int = 0,
    ) -> list[Recall]:
        """Return recalls of any status, latest ``trigger_at`` first.

        Args:
            session_id: When given, only recalls of this session.
            user_id: For non-admins, only recalls of sessions owned by this user.
            is_admin: When True, the ownership filter is not applied.
            limit: Maximum number of rows returned.
            offset: Number of rows skipped (pagination).
        """
        pool = await self._get_pool()
        params: list[Any] = []
        where_clauses: list[str] = []

        # All values are bound as parameters; only "$N" placeholders are
        # interpolated into the SQL text.
        if session_id:
            params.append(session_id)
            where_clauses.append(f"session_id = ${len(params)}")

        # NOTE(review): when is_admin is False and user_id is None, no
        # ownership filter is applied and recalls of every session are
        # returned — confirm callers always pass user_id for non-admins.
        if not is_admin and user_id:
            params.append(user_id)
            where_clauses.append(
                f"session_id IN (SELECT id FROM sessions WHERE user_id = ${len(params)})"
            )

        where_sql = f"WHERE {' AND '.join(where_clauses)}" if where_clauses else ""

        params.append(limit)
        limit_idx = len(params)
        params.append(offset)
        offset_idx = len(params)

        query = f"""
            SELECT * FROM session_recalls
            {where_sql}
            ORDER BY trigger_at DESC
            LIMIT ${limit_idx} OFFSET ${offset_idx}
        """

        async with pool.acquire() as conn:
            rows = await conn.fetch(query, *params)
        return [self._row_to_recall(r) for r in rows]

    async def get_pending_recalls(self, before: datetime) -> list[Recall]:
        """Return pending recalls with ``trigger_at <= before``, earliest first."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT * FROM session_recalls
                WHERE status = 'pending' AND trigger_at <= $1
                ORDER BY trigger_at ASC
                """,
                before,
            )
        return [self._row_to_recall(r) for r in rows]

    async def get_next_trigger_at(self) -> datetime | None:
        """Return the earliest ``trigger_at`` among pending recalls, or None if none exist."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            row = await conn.fetchrow(
                """
                SELECT trigger_at FROM session_recalls
                WHERE status = 'pending'
                ORDER BY trigger_at ASC
                LIMIT 1
                """
            )
        return row["trigger_at"] if row else None

    async def mark_fired(self, recall_id: str) -> None:
        """Set the recall's status to ``fired`` regardless of its current status."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'fired', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def mark_cancelled(self, recall_id: str) -> None:
        """Set the recall's status to ``cancelled`` regardless of its current status."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET status = 'cancelled', updated_at = $1
                WHERE id = $2
                """,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def reschedule(self, recall_id: str, new_trigger_at: datetime) -> None:
        """Move the recall's ``trigger_at``; its status is left unchanged."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                UPDATE session_recalls
                SET trigger_at = $1, updated_at = $2
                WHERE id = $3
                """,
                new_trigger_at,
                datetime.now(timezone.utc),
                recall_id,
            )

    async def get_pending_session_ids(self, session_ids: list[str]) -> set[str]:
        """Return the subset of session_ids that have a pending recall."""
        if not session_ids:
            return set()
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT DISTINCT session_id FROM session_recalls
                WHERE status = 'pending' AND session_id = ANY($1)
                """,
                session_ids,
            )
        return {r["session_id"] for r in rows}


async def recall_scheduler_loop(
    scheduler: RecallScheduler,
    store: Any,
    *,
    poll_interval: float = 60.0,
    max_concurrency: int = 3,
) -> None:
    """Background task: poll for due recalls and fire them.

    Each iteration fires every pending recall whose trigger time has passed,
    with at most *max_concurrency* running at once. It then sleeps until the
    next pending trigger, but never longer than *poll_interval* seconds.
    The cap matters because recalls scheduled while the loop sleeps may be
    due sooner than the trigger it was waiting for.

    Failures of individual recalls are logged and never stop the loop. An
    unexpected error in the loop itself is logged, followed by a
    *poll_interval* back-off. Cancellation propagates.
    """
    await scheduler.ensure_tables()
    semaphore = asyncio.Semaphore(max_concurrency)

    while True:
        try:
            now = datetime.now(timezone.utc)
            pending = await scheduler.get_pending_recalls(before=now)

            if pending:
                results = await asyncio.gather(
                    *(
                        _fire_recall(semaphore, recall, scheduler, store)
                        for recall in pending
                    ),
                    return_exceptions=True,
                )
                # gather(return_exceptions=True) hands failures back as values;
                # log them instead of discarding them silently.
                for recall, result in zip(pending, results):
                    if isinstance(result, BaseException):
                        log.error(
                            "scheduler.fire_failed",
                            recall_id=recall.id,
                            session_id=recall.session_id,
                            exc_info=result,
                        )

            sleep_for = poll_interval
            next_trigger = await scheduler.get_next_trigger_at()
            if next_trigger is not None:
                until_next = (next_trigger - datetime.now(timezone.utc)).total_seconds()
                # Floor of 1s avoids a busy loop on overdue recalls.
                sleep_for = min(poll_interval, max(1.0, until_next))
            await asyncio.sleep(sleep_for)

        except asyncio.CancelledError:
            raise
        except Exception:
            log.exception("scheduler.loop_error")
            await asyncio.sleep(poll_interval)


async def _fire_recall(
    semaphore: asyncio.Semaphore,
    recall: Recall,
    scheduler: RecallScheduler,
    store: Any,
) -> None:
    """Fire one due recall as a headless agent turn for its session.

    The recall is deferred by 60 seconds if a run is already registered for
    the session. It is cancelled if the session no longer exists. Otherwise
    the run is registered in ``_runs``/``_busy_sessions`` and the agent turn
    is streamed. The recall's post-run state is then persisted and published.
    Connected clients always receive a final ``session_sync``.
    """
    from navi.api.websocket import (
        _AgentRun,
        _busy_sessions,
        _notify_session,
        _runs,
    )
    from navi.tools._internal.base import (
        current_stop_event,
        current_user_id as _uid_var,
        current_user_info as _uinfo_var,
        current_user_role as _role_var,
    )

    async with semaphore:
        # Guard: if a websocket run is active for this session, defer by 60 seconds.
        if recall.session_id in _runs:
            await _defer_busy_recall(scheduler, recall)
            return

        session = await store.get(recall.session_id)
        if session is None:
            log.warning("scheduler.session_missing", recall_id=recall.id)
            await scheduler.mark_cancelled(recall.id)
            return

        # Set user context for tools so sandboxing and ownership checks work.
        # session is known to be non-None here, so user_id (possibly None) is
        # used as-is.
        _uid_var.set(session.user_id)
        _role_var.set("user")
        _uinfo_var.set(None)

        # Re-check: a websocket run may have started while store.get() was
        # awaited. No await separates this check from the registration below,
        # so no other coroutine can slip in between them.
        if recall.session_id in _runs:
            await _defer_busy_recall(scheduler, recall)
            return

        stop_event = asyncio.Event()
        # Register a headless run so reconnecting clients can replay events.
        run = _AgentRun()
        _busy_sessions[recall.session_id] = stop_event
        _runs[recall.session_id] = run
        token = current_stop_event.set(stop_event)

        try:
            succeeded = await _run_recall_turn(recall, store, run)
            await _finish_recall(scheduler, recall, succeeded=succeeded)
        finally:
            _busy_sessions.pop(recall.session_id, None)
            current_stop_event.reset(token)
            _runs.pop(recall.session_id, None)
            # Tell all connected clients to reload session history so they see the
            # full headless turn (recall message + assistant response).
            await _notify_session(recall.session_id, {"type": "session_sync"})


async def _defer_busy_recall(scheduler: RecallScheduler, recall: Recall) -> None:
    """Push *recall* 60 seconds into the future because its session is busy."""
    log.info("scheduler.defer_busy", session_id=recall.session_id)
    await scheduler.reschedule(
        recall.id, datetime.now(timezone.utc) + timedelta(seconds=60)
    )


async def _build_recall_agent(store: Any) -> Any:
    """Construct an Agent with the same dependencies the websocket handler uses."""
    from navi.api.deps import (
        get_backend_registry,
        get_cp_registry,
        get_memory_store,
        get_mcp_manager,
        get_profile_registry,
        get_tool_registry,
        get_workers,
    )
    from navi.core.agent import Agent

    tools = get_tool_registry()
    profiles = get_profile_registry()
    backends = get_backend_registry()
    cp_registry = get_cp_registry()
    try:
        mcp_manager = await get_mcp_manager()
    except Exception:
        # Best effort: run without MCP, but leave a trace of why.
        log.warning("scheduler.mcp_unavailable", exc_info=True)
        mcp_manager = None

    return Agent(
        store,
        profiles,
        tools,
        backends,
        workers=get_workers(),
        memory_store=get_memory_store(),
        cp_registry=cp_registry,
        mcp_manager=mcp_manager,
    )


async def _run_recall_turn(recall: Recall, store: Any, run: Any) -> bool:
    """Stream one headless agent turn for *recall* to the event bus and clients.

    Returns True when the turn completed or stopped at the agent's iteration
    cap (``MaxIterationsReached``). Returns False when it failed with any
    other exception, which is logged rather than raised.
    """
    from navi.api.websocket import _notify_session
    from navi.core.agent import MaxIterationsReached
    from navi.core.event_bus import get_event_bus
    from navi.core.events import StreamEnd

    reply = ""
    try:
        # Built inside the try: a construction failure is handled like any
        # other failed run. Otherwise it would escape with the recall still
        # pending and be re-fired on every poll.
        agent = await _build_recall_agent(store)

        # Notify any open WebSocket clients that a headless turn is starting.
        # We intentionally do NOT send session_sync here — it races with
        # stream_start and can wipe the in-progress streaming message before
        # any deltas arrive. The session_sync sent after the run is released
        # ensures the saved turn (recall user msg + assistant response) is
        # loaded once streaming is complete.
        await _notify_session(recall.session_id, {"type": "stream_start"})

        async for event in agent.run_stream(
            session_id=recall.session_id,
            user_message=recall.additional_context_message,
            display_message=recall.additional_context_message,
            is_recall=True,
        ):
            await get_event_bus().publish(event)
            wire = event.to_wire()
            if wire:
                await _notify_session(recall.session_id, wire)
                await run.broadcast(("event", event))
            if isinstance(event, StreamEnd):
                reply = event.full_content
    except MaxIterationsReached:
        log.info("scheduler.max_iterations", recall_id=recall.id)
        return True
    except Exception:
        log.exception("scheduler.recall_failed", recall_id=recall.id)
        return False

    log.info(
        "scheduler.recall_fired",
        recall_id=recall.id,
        session_id=recall.session_id,
        reply_len=len(reply),
    )
    return True


async def _finish_recall(
    scheduler: RecallScheduler, recall: Recall, *, succeeded: bool
) -> None:
    """Persist *recall*'s post-run state and publish a matching RecallUpdate.

    Recurring recalls are rescheduled ``interval_seconds`` from now whether or
    not the run succeeded. One-shot recalls become ``fired`` on success and
    ``cancelled`` on failure.
    """
    if recall.call_type == "recurring":
        # NOTE(review): a recurring recall without interval_seconds is
        # rescheduled for "now" and fires again on the next poll.
        next_trigger = datetime.now(timezone.utc) + timedelta(
            seconds=recall.interval_seconds or 0
        )
        await scheduler.reschedule(recall.id, next_trigger)
        await _publish_recall_update(
            recall.session_id, recall.id, recall.call_type,
            trigger_at=next_trigger.isoformat(), status="pending", action="rescheduled"
        )
        return

    if succeeded:
        await scheduler.mark_fired(recall.id)
        outcome = "fired"
    else:
        await scheduler.mark_cancelled(recall.id)
        outcome = "cancelled"
    await _publish_recall_update(
        recall.session_id, recall.id, recall.call_type,
        trigger_at=recall.trigger_at.isoformat(), status=outcome, action=outcome
    )
