# navi-1 / debug / eval / db.py
"""asyncpg helpers for the eval system.

Phase 1 surface: message feedback (like / dislike / clear).
Phase 3 surface: evaluation runs — bulk insert, per-session list, completion check.

Schema is applied lazily on the first pool acquire — same pattern as
navi/core/pg_session_store.py. The full DDL lives in schema.sql alongside.
"""

from __future__ import annotations

import asyncio
import json
from pathlib import Path

import asyncpg

from .schema import EvalRunMetadata, ExpertResult, StoredEvaluation


_SCHEMA_PATH = Path(__file__).parent / "schema.sql"


class EvalDB:
    """Owns its own asyncpg pool. Reuses settings.database_url."""

    def __init__(self, dsn: str) -> None:
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        if self._pool is not None:
            return self._pool
        async with self._lock:
            if self._pool is None:
                pool = await asyncpg.create_pool(self._dsn)
                async with pool.acquire() as conn:
                    await conn.execute(_SCHEMA_PATH.read_text(encoding="utf-8"))
                self._pool = pool
        return self._pool

    # ── Feedback ──────────────────────────────────────────────────────────

    async def set_feedback(
        self, session_id: str, message_index: int, rating: int
    ) -> None:
        if rating not in (-1, 1):
            raise ValueError("rating must be -1 or 1")
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                """
                INSERT INTO message_feedback (session_id, message_index, rating)
                VALUES ($1, $2, $3)
                ON CONFLICT (session_id, message_index)
                DO UPDATE SET rating = EXCLUDED.rating, updated_at = now()
                """,
                session_id, message_index, rating,
            )

    async def clear_feedback(self, session_id: str, message_index: int) -> None:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM message_feedback WHERE session_id = $1 AND message_index = $2",
                session_id, message_index,
            )

    async def list_feedback(self, session_id: str) -> list[dict]:
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT message_index, rating, created_at, updated_at "
                "FROM message_feedback WHERE session_id = $1 ORDER BY message_index",
                session_id,
            )
        return [
            {
                "message_index": r["message_index"],
                "rating": r["rating"],
                "created_at": r["created_at"].isoformat(),
                "updated_at": r["updated_at"].isoformat(),
            }
            for r in rows
        ]

    async def feedback_by_index(self, session_id: str) -> dict[int, int]:
        """Convenience: {message_index: rating} for the judge renderer."""
        rows = await self.list_feedback(session_id)
        return {r["message_index"]: r["rating"] for r in rows}

    # ── Evaluations ───────────────────────────────────────────────────────

    async def insert_evaluation_run(
        self,
        metadata: EvalRunMetadata,
        session_id: str,
        results: list[ExpertResult],
    ) -> None:
        """Persist all expert results from one run under a single eval_run_id."""
        if not results:
            return
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                for r in results:
                    await conn.execute(
                        """
                        INSERT INTO evaluations
                            (id, session_id, eval_run_id, eval_date, judge_model,
                             judge_version, rubric_version, expert_id, scores, comment)
                        VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                        """,
                        session_id,
                        metadata.eval_run_id,
                        metadata.eval_date,
                        metadata.judge_model,
                        metadata.judge_version,
                        metadata.rubric_version,
                        r.expert_id,
                        r.scores.model_dump_json(),
                        r.comment,
                    )

    async def list_evaluations(self, session_id: str) -> list[StoredEvaluation]:
        """All evaluation rows for a session, newest run first."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, session_id, eval_run_id, eval_date, judge_model,
                       judge_version, rubric_version, expert_id, scores, comment
                FROM evaluations
                WHERE session_id = $1
                ORDER BY eval_date DESC, expert_id ASC
                """,
                session_id,
            )
        return [
            StoredEvaluation.model_validate(
                {
                    "id": r["id"],
                    "session_id": r["session_id"],
                    "eval_run_id": r["eval_run_id"],
                    "eval_date": r["eval_date"],
                    "judge_model": r["judge_model"],
                    "judge_version": r["judge_version"],
                    "rubric_version": r["rubric_version"],
                    "expert_id": r["expert_id"],
                    "scores": json.loads(r["scores"]) if isinstance(r["scores"], str) else r["scores"],
                    "comment": r["comment"],
                }
            )
            for r in rows
        ]

    async def evaluated_session_ids(
        self, judge_version: str, rubric_version: str, expected_experts: int = 3
    ) -> set[str]:
        """Sessions with a complete run at the current rubric/judge versions.

        A session counts as evaluated only when at least `expected_experts`
        distinct expert rows exist for that (judge_version, rubric_version).
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id
                FROM evaluations
                WHERE judge_version = $1 AND rubric_version = $2
                GROUP BY session_id
                HAVING COUNT(DISTINCT expert_id) >= $3
                """,
                judge_version,
                rubric_version,
                expected_experts,
            )
        return {r["session_id"] for r in rows}

    # ── Aggregations for the read endpoints ────────────────────────────────

    async def list_sessions_overview(
        self,
        *,
        judge_version: str,
        rubric_version: str,
        limit: int = 50,
        offset: int = 0,
        profile: str | None = None,
        status: str | None = None,  # "evaluated" | "pending" | "stale" | None
    ) -> list[dict]:
        """One row per session with feedback counts and the latest eval summary.

        Joined in a single query against `sessions`, `message_feedback`,
        `evaluations`. Status is derived against the current pinned versions.
        """
        pool = await self._get_pool()
        clauses: list[str] = []
        params: list = []

        # judge_version / rubric_version are only used when filtering by status;
        # otherwise we don't bind them as parameters (Postgres would error on
        # unused-parameter type inference).
        if status == "evaluated":
            params.extend([judge_version, rubric_version])
            clauses.append(
                f"latest.judge_version = ${len(params) - 1} "
                f"AND latest.rubric_version = ${len(params)}"
            )
        elif status == "pending":
            clauses.append("latest.eval_run_id IS NULL")
        elif status == "stale":
            params.extend([judge_version, rubric_version])
            clauses.append(
                "latest.eval_run_id IS NOT NULL AND "
                f"(latest.judge_version <> ${len(params) - 1} "
                f"OR latest.rubric_version <> ${len(params)})"
            )

        if profile is not None:
            params.append(profile)
            clauses.append(f"s.profile_id = ${len(params)}")

        where = ("WHERE " + " AND ".join(clauses)) if clauses else ""
        params.extend([limit, offset])

        sql = f"""
            WITH latest_run AS (
                SELECT DISTINCT ON (session_id)
                    session_id, eval_run_id, eval_date, judge_version, rubric_version
                FROM evaluations
                ORDER BY session_id, eval_date DESC
            ),
            run_avg AS (
                SELECT e.eval_run_id,
                    AVG((e.scores->>'task_complexity')::numeric)        AS task_complexity,
                    AVG((e.scores->>'goal_completion')::numeric)        AS goal_completion,
                    AVG((e.scores->>'tool_usage_quality')::numeric)     AS tool_usage_quality,
                    AVG((e.scores->>'efficiency')::numeric)             AS efficiency,
                    AVG((e.scores->>'communication')::numeric)          AS communication,
                    AVG(NULLIF(e.scores->>'subagent_orchestration','null')::numeric) AS subagent_orchestration,
                    AVG(NULLIF(e.scores->>'self_extension','null')::numeric)         AS self_extension,
                    COUNT(*) AS expert_count
                FROM evaluations e
                JOIN latest_run lr ON lr.eval_run_id = e.eval_run_id
                GROUP BY e.eval_run_id
            ),
            feedback_counts AS (
                SELECT session_id,
                    SUM(CASE WHEN rating =  1 THEN 1 ELSE 0 END)::int AS likes,
                    SUM(CASE WHEN rating = -1 THEN 1 ELSE 0 END)::int AS dislikes
                FROM message_feedback
                GROUP BY session_id
            )
            SELECT
                s.id, s.profile_id, s.name, s.created_at, s.last_active, s.pinned,
                COALESCE(jsonb_array_length(s.messages::jsonb), 0) AS msg_count,
                COALESCE(fc.likes, 0)    AS likes,
                COALESCE(fc.dislikes, 0) AS dislikes,
                latest.eval_run_id   AS latest_eval_run_id,
                latest.eval_date     AS latest_eval_date,
                latest.judge_version AS latest_judge_version,
                latest.rubric_version AS latest_rubric_version,
                ra.task_complexity, ra.goal_completion, ra.tool_usage_quality,
                ra.efficiency, ra.communication,
                ra.subagent_orchestration, ra.self_extension,
                ra.expert_count
            FROM sessions s
            LEFT JOIN feedback_counts fc ON fc.session_id = s.id
            LEFT JOIN latest_run     latest ON latest.session_id = s.id
            LEFT JOIN run_avg        ra     ON ra.eval_run_id   = latest.eval_run_id
            {where}
            ORDER BY s.pinned DESC, s.last_active DESC
            LIMIT ${len(params) - 1} OFFSET ${len(params)}
        """

        async with pool.acquire() as conn:
            rows = await conn.fetch(sql, *params)

        result = []
        for r in rows:
            if r["latest_eval_run_id"] is None:
                eval_status = "pending"
            elif (
                r["latest_judge_version"] == judge_version
                and r["latest_rubric_version"] == rubric_version
            ):
                eval_status = "evaluated"
            else:
                eval_status = "stale"

            avg = None
            if r["latest_eval_run_id"] is not None:
                avg = {
                    "task_complexity": _round_or_none(r["task_complexity"]),
                    "goal_completion": _round_or_none(r["goal_completion"]),
                    "tool_usage_quality": _round_or_none(r["tool_usage_quality"]),
                    "efficiency": _round_or_none(r["efficiency"]),
                    "communication": _round_or_none(r["communication"]),
                    "subagent_orchestration": _round_or_none(r["subagent_orchestration"]),
                    "self_extension": _round_or_none(r["self_extension"]),
                }

            result.append(
                {
                    "session_id": r["id"],
                    "profile_id": r["profile_id"],
                    "name": r["name"],
                    "created_at": r["created_at"],
                    "last_active": r["last_active"],
                    "pinned": r["pinned"],
                    "msg_count": r["msg_count"],
                    "likes": r["likes"],
                    "dislikes": r["dislikes"],
                    "eval_status": eval_status,
                    "latest_avg": avg,
                    "latest_eval_date": r["latest_eval_date"],
                    "latest_judge_version": r["latest_judge_version"],
                    "latest_rubric_version": r["latest_rubric_version"],
                }
            )
        return result

    async def aggregate_stats(
        self,
        *,
        judge_version: str,
        rubric_version: str,
        days: int = 30,
        by_complexity_bucket: bool = False,
    ) -> dict:
        """Weekly rolling per-axis averages over the last `days` days.

        Returns:
          {
            buckets: [bucket_label, ...],         # ["overall"] or ["0-25", ...]
            weekly: [
              {week_start, bucket, axis_means: {...}, sample_count}
            ]
          }
        Bucket is computed from task_complexity for each session's latest run.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                WITH latest_run AS (
                    SELECT DISTINCT ON (session_id)
                        session_id, eval_run_id, eval_date
                    FROM evaluations
                    WHERE judge_version = $1 AND rubric_version = $2
                      AND eval_date >= now() - ($3::text || ' days')::interval
                    ORDER BY session_id, eval_date DESC
                )
                SELECT
                    date_trunc('week', lr.eval_date) AS week_start,
                    AVG((e.scores->>'task_complexity')::numeric)        AS task_complexity,
                    AVG((e.scores->>'goal_completion')::numeric)        AS goal_completion,
                    AVG((e.scores->>'tool_usage_quality')::numeric)     AS tool_usage_quality,
                    AVG((e.scores->>'efficiency')::numeric)             AS efficiency,
                    AVG((e.scores->>'communication')::numeric)          AS communication,
                    AVG(NULLIF(e.scores->>'subagent_orchestration','null')::numeric) AS subagent_orchestration,
                    AVG(NULLIF(e.scores->>'self_extension','null')::numeric)         AS self_extension,
                    COUNT(DISTINCT lr.session_id) AS sample_count
                FROM evaluations e
                JOIN latest_run lr ON lr.eval_run_id = e.eval_run_id
                GROUP BY week_start
                ORDER BY week_start
                """,
                judge_version,
                rubric_version,
                str(days),
            )

        weekly = [
            {
                "week_start": r["week_start"].isoformat(),
                "bucket": "overall",
                "sample_count": r["sample_count"],
                "axis_means": {
                    "task_complexity": _round_or_none(r["task_complexity"]),
                    "goal_completion": _round_or_none(r["goal_completion"]),
                    "tool_usage_quality": _round_or_none(r["tool_usage_quality"]),
                    "efficiency": _round_or_none(r["efficiency"]),
                    "communication": _round_or_none(r["communication"]),
                    "subagent_orchestration": _round_or_none(r["subagent_orchestration"]),
                    "self_extension": _round_or_none(r["self_extension"]),
                },
            }
            for r in rows
        ]

        # Optional bucket split — second query, grouped by complexity bucket.
        if by_complexity_bucket:
            async with pool.acquire() as conn:
                bucket_rows = await conn.fetch(
                    """
                    WITH latest_run AS (
                        SELECT DISTINCT ON (session_id)
                            session_id, eval_run_id, eval_date
                        FROM evaluations
                        WHERE judge_version = $1 AND rubric_version = $2
                          AND eval_date >= now() - ($3::text || ' days')::interval
                        ORDER BY session_id, eval_date DESC
                    ),
                    run_avg AS (
                        SELECT
                            lr.session_id,
                            date_trunc('week', lr.eval_date) AS week_start,
                            AVG((e.scores->>'task_complexity')::numeric)        AS task_complexity,
                            AVG((e.scores->>'goal_completion')::numeric)        AS goal_completion,
                            AVG((e.scores->>'tool_usage_quality')::numeric)     AS tool_usage_quality,
                            AVG((e.scores->>'efficiency')::numeric)             AS efficiency,
                            AVG((e.scores->>'communication')::numeric)          AS communication
                        FROM evaluations e
                        JOIN latest_run lr ON lr.eval_run_id = e.eval_run_id
                        GROUP BY lr.session_id, week_start
                    )
                    SELECT
                        week_start,
                        CASE
                            WHEN task_complexity <= 25 THEN '0-25'
                            WHEN task_complexity <= 50 THEN '26-50'
                            WHEN task_complexity <= 75 THEN '51-75'
                            ELSE '76+'
                        END AS bucket,
                        AVG(task_complexity)    AS task_complexity,
                        AVG(goal_completion)    AS goal_completion,
                        AVG(tool_usage_quality) AS tool_usage_quality,
                        AVG(efficiency)         AS efficiency,
                        AVG(communication)      AS communication,
                        COUNT(*)                AS sample_count
                    FROM run_avg
                    GROUP BY week_start, bucket
                    ORDER BY week_start, bucket
                    """,
                    judge_version,
                    rubric_version,
                    str(days),
                )
            weekly.extend(
                {
                    "week_start": r["week_start"].isoformat(),
                    "bucket": r["bucket"],
                    "sample_count": r["sample_count"],
                    "axis_means": {
                        "task_complexity": _round_or_none(r["task_complexity"]),
                        "goal_completion": _round_or_none(r["goal_completion"]),
                        "tool_usage_quality": _round_or_none(r["tool_usage_quality"]),
                        "efficiency": _round_or_none(r["efficiency"]),
                        "communication": _round_or_none(r["communication"]),
                    },
                }
                for r in bucket_rows
            )

        buckets = ["0-25", "26-50", "51-75", "76+"] if by_complexity_bucket else ["overall"]
        return {"buckets": buckets, "weekly": weekly}


def _round_or_none(v) -> int | None:
    """Round a numeric value to an int, passing ``None`` through untouched.

    The value is converted with ``float()`` and then rounded with the
    built-in ``round()``, so exact halves go to the even neighbour.
    """
    return None if v is None else round(float(v))