# navi-1/debug/eval/db.py
"""asyncpg helpers for the eval system.

Phase 1 surface: message feedback (like / dislike / clear).
Phase 3 surface: evaluation runs — bulk insert, per-session list, completion check.

Schema is applied lazily on the first pool acquire — same pattern as
navi/core/pg_session_store.py. The full DDL lives in schema.sql alongside.
"""

from __future__ import annotations

import asyncio
import json
from pathlib import Path

import asyncpg

from .schema import EvalRunMetadata, ExpertResult, StoredEvaluation


# schema.sql sits in the same directory as this module.
_SCHEMA_PATH = Path(__file__).with_name("schema.sql")


class EvalDB:
    """Async Postgres access layer for message feedback and evaluation runs.

    Owns its own asyncpg pool. Reuses settings.database_url.

    The pool is created lazily on the first query. Right after creation the
    DDL in ``schema.sql`` is executed once, and only then is the pool
    published on the instance.
    """

    def __init__(self, dsn: str) -> None:
        """Store the connection string; no I/O happens until the first query."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes first-time pool creation across concurrent callers.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and applying the schema once.

        Raises:
            Whatever reading ``schema.sql``, ``asyncpg.create_pool`` or the
            DDL execution raises. In every failure case no pool is left
            open, so a later call can retry from scratch.
        """
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished first.
            if self._pool is None:
                # Read the DDL before opening connections so a missing or
                # unreadable schema file cannot strand a freshly built pool.
                ddl = _SCHEMA_PATH.read_text(encoding="utf-8")
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        await conn.execute(ddl)
                except BaseException:
                    # Includes asyncio.CancelledError. Without this the pool
                    # (and its connections) would leak on every failed
                    # attempt, because it is never stored on the instance.
                    await pool.close()
                    raise
                self._pool = pool
        return self._pool

    # ── Feedback ──────────────────────────────────────────────────────────

    async def set_feedback(
        self, session_id: str, message_index: int, rating: int
    ) -> None:
        """Insert or overwrite the rating for one message of a session.

        Args:
            session_id: Session the message belongs to.
            message_index: Index of the message within the session.
            rating: ``1`` (like) or ``-1`` (dislike).

        Raises:
            ValueError: If ``rating`` is neither ``-1`` nor ``1``. Raised
                before any database access.
        """
        if rating not in (-1, 1):
            raise ValueError("rating must be -1 or 1")
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Upsert keyed on (session_id, message_index); an existing row
            # gets the new rating and a refreshed updated_at.
            await conn.execute(
                """
                INSERT INTO message_feedback (session_id, message_index, rating)
                VALUES ($1, $2, $3)
                ON CONFLICT (session_id, message_index)
                DO UPDATE SET rating = EXCLUDED.rating, updated_at = now()
                """,
                session_id, message_index, rating,
            )

    async def clear_feedback(self, session_id: str, message_index: int) -> None:
        """Delete the feedback row for one message, if there is one."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM message_feedback WHERE session_id = $1 AND message_index = $2",
                session_id, message_index,
            )

    async def list_feedback(self, session_id: str) -> list[dict]:
        """Return every feedback row of a session, ordered by message index.

        Each item has the keys ``message_index``, ``rating``, ``created_at``
        and ``updated_at``; the two timestamps are rendered with
        ``.isoformat()``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT message_index, rating, created_at, updated_at "
                "FROM message_feedback WHERE session_id = $1 ORDER BY message_index",
                session_id,
            )
        # NOTE(review): assumes created_at / updated_at are never NULL —
        # confirm against schema.sql.
        return [
            {
                "message_index": r["message_index"],
                "rating": r["rating"],
                "created_at": r["created_at"].isoformat(),
                "updated_at": r["updated_at"].isoformat(),
            }
            for r in rows
        ]

    async def feedback_by_index(self, session_id: str) -> dict[int, int]:
        """Convenience: {message_index: rating} for the judge renderer."""
        rows = await self.list_feedback(session_id)
        return {r["message_index"]: r["rating"] for r in rows}

    # ── Evaluations ───────────────────────────────────────────────────────

    async def insert_evaluation_run(
        self,
        metadata: EvalRunMetadata,
        session_id: str,
        results: list[ExpertResult],
    ) -> None:
        """Persist all expert results from one run under a single eval_run_id.

        Every row shares the run-level fields taken from ``metadata``; the
        expert id, scores (serialized with ``model_dump_json()``) and comment
        come from each result. All rows are written in one transaction.

        An empty ``results`` list is a no-op and does not touch the database.
        """
        if not results:
            return
        rows = [
            (
                session_id,
                metadata.eval_run_id,
                metadata.eval_date,
                metadata.judge_model,
                metadata.judge_version,
                metadata.rubric_version,
                r.expert_id,
                r.scores.model_dump_json(),
                r.comment,
            )
            for r in results
        ]
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # The explicit transaction keeps the insert all-or-nothing
            # regardless of the driver version's executemany semantics.
            async with conn.transaction():
                # One batched call instead of one awaited execute per row.
                await conn.executemany(
                    """
                    INSERT INTO evaluations
                        (id, session_id, eval_run_id, eval_date, judge_model,
                         judge_version, rubric_version, expert_id, scores, comment)
                    VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                    """,
                    rows,
                )

    async def list_evaluations(self, session_id: str) -> list[StoredEvaluation]:
        """Return all evaluation rows for a session.

        Rows are ordered by ``eval_date`` descending (newest first), then by
        ``expert_id`` ascending.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, session_id, eval_run_id, eval_date, judge_model,
                       judge_version, rubric_version, expert_id, scores, comment
                FROM evaluations
                WHERE session_id = $1
                ORDER BY eval_date DESC, expert_id ASC
                """,
                session_id,
            )
        evaluations: list[StoredEvaluation] = []
        for r in rows:
            scores = r["scores"]
            # The driver may hand jsonb back as text; decode it in that case
            # and pass already-decoded values through untouched.
            if isinstance(scores, str):
                scores = json.loads(scores)
            evaluations.append(
                StoredEvaluation.model_validate(
                    {
                        "id": r["id"],
                        "session_id": r["session_id"],
                        "eval_run_id": r["eval_run_id"],
                        "eval_date": r["eval_date"],
                        "judge_model": r["judge_model"],
                        "judge_version": r["judge_version"],
                        "rubric_version": r["rubric_version"],
                        "expert_id": r["expert_id"],
                        "scores": scores,
                        "comment": r["comment"],
                    }
                )
            )
        return evaluations

    async def evaluated_session_ids(
        self, judge_version: str, rubric_version: str, expected_experts: int = 3
    ) -> set[str]:
        """Return sessions fully evaluated at the given judge/rubric versions.

        A session counts as evaluated when at least ``expected_experts``
        distinct ``expert_id`` values exist among its rows for that
        (judge_version, rubric_version) pair. The count is taken per session
        across all matching rows; it is not restricted to a single
        ``eval_run_id``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id
                FROM evaluations
                WHERE judge_version = $1 AND rubric_version = $2
                GROUP BY session_id
                HAVING COUNT(DISTINCT expert_id) >= $3
                """,
                judge_version,
                rubric_version,
                expected_experts,
            )
        return {r["session_id"] for r in rows}