Newer
Older
navi-1 / debug / eval / runner.py
"""Background eval runner used by both the CLI and the REST endpoint.

Owns an in-memory registry of in-flight runs keyed by run_id (uuid). The
caller starts a run via `start_run(...)`, then polls `get_run(run_id)` for
status. Runs persist their results to postgres via EvalDB; the in-memory
registry is purely for live-progress reporting.
"""

from __future__ import annotations

import asyncio
import traceback
from datetime import datetime, timezone
from uuid import uuid4

from .db import EvalDB
from .judge import (
    JUDGE_VERSION,
    RUBRIC_VERSION,
    average_scores,
    evaluate_session,
)
from .schema import RunRequest, RunSessionStatus, RunStatus


class _RunRegistry:
    """Process-local, in-memory index of runs and their background tasks.

    Nothing in here is persisted, so the contents are gone once the process
    exits.
    """

    def __init__(self) -> None:
        # Both maps are keyed by run_id.
        self._runs: dict[str, RunStatus] = {}
        self._tasks: dict[str, asyncio.Task] = {}

    def register(self, status: RunStatus) -> None:
        """Store ``status`` under its own ``run_id``, replacing any prior entry."""
        key = status.run_id
        self._runs[key] = status

    def get(self, run_id: str) -> RunStatus | None:
        """Return the status registered for ``run_id``, or ``None`` if unknown."""
        try:
            return self._runs[run_id]
        except KeyError:
            return None

    def list_runs(self) -> list[RunStatus]:
        """Return every registered status, most recently started first."""
        ordered = list(self._runs.values())
        ordered.sort(key=lambda run: run.started_at, reverse=True)
        return ordered

    def attach_task(self, run_id: str, task: asyncio.Task) -> None:
        """Remember the asyncio task that is executing ``run_id``."""
        self._tasks.update({run_id: task})


# Module-level registry instance; created once when this module is imported.
_registry = _RunRegistry()


def get_registry() -> _RunRegistry:
    """Return the module-level ``_RunRegistry`` instance."""
    return _registry


# ── Picking sessions ────────────────────────────────────────────────────


async def _resolve_sessions(req: RunRequest, db: EvalDB, session_store):
    """Pick the sessions a run should evaluate, based on ``req``.

    * ``scope == "session"``: fetch exactly ``req.session_id`` from
      ``session_store`` and return it as a one-element list; ``since`` and
      ``limit`` are not applied on this path.
    * any other scope: start from ``session_store.list_all()``. For
      ``scope == "unevaluated"`` drop sessions whose id is reported by
      ``db.evaluated_session_ids`` for the current judge/rubric versions.
      Then keep only sessions with ``created_at >= req.since`` (when given)
      and truncate to the first ``req.limit`` entries (when given).

    Raises:
        ValueError: for ``scope == "session"`` when ``session_id`` is missing
            or the store has no such session.
    """
    if req.scope == "session":
        wanted_id = req.session_id
        if not wanted_id:
            raise ValueError("scope=session requires session_id")
        found = await session_store.get(wanted_id)
        if found is None:
            raise ValueError(f"session not found: {wanted_id}")
        return [found]

    candidates = await session_store.list_all()

    if req.scope == "unevaluated":
        done_ids = await db.evaluated_session_ids(JUDGE_VERSION, RUBRIC_VERSION)
        candidates = [c for c in candidates if c.id not in done_ids]

    cutoff = req.since
    if cutoff is not None:
        recent = []
        for candidate in candidates:
            if candidate.created_at >= cutoff:
                recent.append(candidate)
        candidates = recent

    cap = req.limit
    return candidates if cap is None else candidates[:cap]


# ── The actual loop ─────────────────────────────────────────────────────


async def _run_loop(
    *,
    run_id: str,
    req: RunRequest,
    db: EvalDB,
    session_store,
    backend_registry,
    profile_registry,
) -> None:
    """Evaluate the sessions selected by ``req`` and record progress in place.

    Looks up the ``RunStatus`` registered under ``run_id`` (returning
    immediately if there is none) and mutates it as the run advances:

    * If session resolution fails, the run is marked ``"failed"`` with a
      single placeholder session entry carrying the error text.
    * If the backend lookup fails, the run and every session entry are marked
      ``"failed"``.
    * Otherwise sessions are evaluated one after another. A failure in one
      session is recorded on that entry (and its traceback printed) without
      stopping the others; the run itself ends as ``"done"``.
    * If the task is cancelled, the run and every session entry still
      ``"pending"`` or ``"running"`` are marked ``"failed"`` with error
      ``"cancelled"``, and the cancellation is re-raised.

    Exceptions other than cancellation are recorded on the status rather than
    propagated.
    """
    status = _registry.get(run_id)
    if status is None:
        return

    try:
        try:
            sessions = await _resolve_sessions(req, db, session_store)
            status.sessions = [
                RunSessionStatus(session_id=s.id, state="pending") for s in sessions
            ]
        except Exception as e:
            status.state = "failed"
            status.finished_at = datetime.now(timezone.utc)
            status.sessions = [
                RunSessionStatus(session_id="", state="failed", error=str(e))
            ]
            return

        try:
            llm = backend_registry.get(req.backend)
        except Exception as e:
            status.state = "failed"
            status.finished_at = datetime.now(timezone.utc)
            for s in status.sessions:
                s.state = "failed"
                s.error = f"backend not available: {e}"
            return

        by_id = {s.id: s for s in sessions}
        for entry in status.sessions:
            session = by_id.get(entry.session_id)
            if session is None:
                entry.state = "failed"
                entry.error = "session vanished mid-run"
                continue
            entry.state = "running"
            try:
                feedback = await db.feedback_by_index(session.id)
                # A missing/broken profile is tolerated: evaluate without one.
                try:
                    profile = profile_registry.get(session.profile_id)
                except Exception:
                    profile = None
                metadata, results = await evaluate_session(
                    session=session,
                    feedback_by_index=feedback,
                    profile=profile,
                    llm=llm,
                    model=req.model,
                )
                await db.insert_evaluation_run(metadata, session.id, results)
                entry.avg = average_scores(results).model_dump()
                entry.state = "ok"
            except Exception as e:
                # One bad session must not abort the remaining ones.
                entry.state = "failed"
                entry.error = f"{type(e).__name__}: {e}"
                traceback.print_exc()

        status.state = "done"
        status.finished_at = datetime.now(timezone.utc)
    except asyncio.CancelledError:
        # CancelledError derives from BaseException on Python 3.8+, so the
        # `except Exception` handlers above never see it. Without this branch
        # a cancelled run would report state="running" with no finished_at
        # forever. Reuse the existing "failed" state rather than inventing a
        # new one the status schema may not accept.
        for entry in status.sessions:
            if entry.state in ("pending", "running"):
                entry.state = "failed"
                entry.error = "cancelled"
        status.state = "failed"
        status.finished_at = datetime.now(timezone.utc)
        raise


# ── Public entry ────────────────────────────────────────────────────────


def start_run(
    *,
    req: RunRequest,
    db: EvalDB,
    session_store,
    backend_registry,
    profile_registry,
) -> RunStatus:
    """Launch an eval run as a background asyncio task.

    Creates a fresh ``RunStatus`` in the ``"running"`` state with an empty
    session list, registers it, schedules ``_run_loop`` for it via
    ``asyncio.create_task``, and records that task in the registry.

    Returns:
        The newly registered ``RunStatus``; the background loop mutates this
        same object as the run progresses.
    """
    new_id = str(uuid4())
    initial = RunStatus(
        run_id=new_id,
        state="running",
        started_at=datetime.now(timezone.utc),
        finished_at=None,
        judge_model=req.model,
        judge_version=JUDGE_VERSION,
        rubric_version=RUBRIC_VERSION,
        sessions=[],
    )

    # Register first: the loop looks its status up by run_id when it starts.
    _registry.register(initial)

    loop_kwargs = {
        "run_id": new_id,
        "req": req,
        "db": db,
        "session_store": session_store,
        "backend_registry": backend_registry,
        "profile_registry": profile_registry,
    }
    worker = asyncio.create_task(_run_loop(**loop_kwargs))
    _registry.attach_task(new_id, worker)
    return initial