"""asyncpg helpers for the eval system.
Phase 1 surface: message feedback (like / dislike / clear).
Phase 3 surface: evaluation runs — bulk insert, per-session list, completion check.
Schema is applied lazily on the first pool acquire — same pattern as
navi/core/pg_session_store.py. The full DDL lives in schema.sql alongside.
"""
from __future__ import annotations
import asyncio
import json
from pathlib import Path
import asyncpg
from .schema import EvalRunMetadata, ExpertResult, StoredEvaluation
# DDL file that sits next to this module; EvalDB executes it once when it
# first creates its connection pool.
_SCHEMA_PATH = Path(__file__).with_name("schema.sql")
class EvalDB:
    """Database access for the eval system, backed by its own asyncpg pool.

    The pool is created lazily on the first query, and the DDL in
    ``schema.sql`` is executed once right after the pool is created.
    The DSN is supplied by the caller (originally documented as reusing
    ``settings.database_url``). Call :meth:`close` to release the pool.
    """

    def __init__(self, dsn: str) -> None:
        """Remember the DSN; no connection is opened until the first query."""
        self._dsn = dsn
        self._pool: asyncpg.Pool | None = None
        # Serializes pool creation so concurrent first callers build one pool.
        self._lock = asyncio.Lock()

    async def _get_pool(self) -> asyncpg.Pool:
        """Return the shared pool, creating it and applying the schema once.

        Raises:
            OSError: if ``schema.sql`` cannot be read.
            Exception: whatever asyncpg raises while connecting or while
                executing the DDL. In that case no pool is kept, so the
                next call retries from scratch.
        """
        # Fast path: pool already initialised, no need to take the lock.
        if self._pool is not None:
            return self._pool
        async with self._lock:
            # Re-check under the lock: another task may have finished the
            # initialisation while this one was waiting.
            if self._pool is None:
                # Read the DDL first so a missing/unreadable file fails
                # before any connection is opened.
                ddl = _SCHEMA_PATH.read_text(encoding="utf-8")
                pool = await asyncpg.create_pool(self._dsn)
                try:
                    async with pool.acquire() as conn:
                        await conn.execute(ddl)
                except BaseException:
                    # Do not leak the half-initialised pool's connections.
                    # terminate() is synchronous, so it is also safe while
                    # this task is being cancelled.
                    pool.terminate()
                    raise
                self._pool = pool
            return self._pool

    async def close(self) -> None:
        """Close the pool if one was created; safe to call more than once.

        A later query transparently creates a fresh pool.
        """
        async with self._lock:
            pool, self._pool = self._pool, None
        if pool is not None:
            await pool.close()

    # ── Feedback ──────────────────────────────────────────────────────────
    async def set_feedback(
        self, session_id: str, message_index: int, rating: int
    ) -> None:
        """Insert or overwrite the rating for one message of a session.

        Args:
            session_id: Session the message belongs to.
            message_index: Index of the message within the session.
            rating: ``1`` or ``-1``.

        Raises:
            ValueError: if ``rating`` is neither ``-1`` nor ``1``.
        """
        if rating not in (-1, 1):
            raise ValueError("rating must be -1 or 1")
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            # Upsert keyed on (session_id, message_index); an existing row
            # gets the new rating and a refreshed updated_at.
            await conn.execute(
                """
                INSERT INTO message_feedback (session_id, message_index, rating)
                VALUES ($1, $2, $3)
                ON CONFLICT (session_id, message_index)
                DO UPDATE SET rating = EXCLUDED.rating, updated_at = now()
                """,
                session_id,
                message_index,
                rating,
            )

    async def clear_feedback(self, session_id: str, message_index: int) -> None:
        """Delete the feedback row for one message, if it exists."""
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            await conn.execute(
                "DELETE FROM message_feedback WHERE session_id = $1 AND message_index = $2",
                session_id,
                message_index,
            )

    async def list_feedback(self, session_id: str) -> list[dict]:
        """Return all feedback rows of a session, ordered by message index.

        Each item has the keys ``message_index``, ``rating``, ``created_at``
        and ``updated_at``; the two timestamps are ISO-8601 strings.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                "SELECT message_index, rating, created_at, updated_at "
                "FROM message_feedback WHERE session_id = $1 ORDER BY message_index",
                session_id,
            )
        return [
            {
                "message_index": r["message_index"],
                "rating": r["rating"],
                "created_at": r["created_at"].isoformat(),
                "updated_at": r["updated_at"].isoformat(),
            }
            for r in rows
        ]

    async def feedback_by_index(self, session_id: str) -> dict[int, int]:
        """Convenience: {message_index: rating} for the judge renderer."""
        rows = await self.list_feedback(session_id)
        return {r["message_index"]: r["rating"] for r in rows}

    # ── Evaluations ───────────────────────────────────────────────────────
    async def insert_evaluation_run(
        self,
        metadata: EvalRunMetadata,
        session_id: str,
        results: list[ExpertResult],
    ) -> None:
        """Persist all expert results from one run under a single eval_run_id.

        One ``evaluations`` row is written per entry of ``results``; all rows
        share the run-level fields taken from ``metadata``. The rows are
        written in one transaction, so a failure leaves none of them behind.
        An empty ``results`` list is a no-op and does not touch the database.
        """
        if not results:
            return
        # Build every argument tuple up front so the whole batch can be sent
        # with a single executemany() instead of one round trip per row.
        rows = [
            (
                session_id,
                metadata.eval_run_id,
                metadata.eval_date,
                metadata.judge_model,
                metadata.judge_version,
                metadata.rubric_version,
                r.expert_id,
                r.scores.model_dump_json(),  # JSON text, cast to jsonb in SQL
                r.comment,
            )
            for r in results
        ]
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            async with conn.transaction():
                await conn.executemany(
                    """
                    INSERT INTO evaluations
                    (id, session_id, eval_run_id, eval_date, judge_model,
                     judge_version, rubric_version, expert_id, scores, comment)
                    VALUES (gen_random_uuid(), $1, $2, $3, $4, $5, $6, $7, $8::jsonb, $9)
                    """,
                    rows,
                )

    async def list_evaluations(self, session_id: str) -> list[StoredEvaluation]:
        """All evaluation rows for a session, newest run first.

        Rows are ordered by ``eval_date`` descending. ``eval_run_id`` is used
        as a tiebreak so that runs sharing the same ``eval_date`` are not
        interleaved, and rows within a run are ordered by ``expert_id``.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT id, session_id, eval_run_id, eval_date, judge_model,
                       judge_version, rubric_version, expert_id, scores, comment
                FROM evaluations
                WHERE session_id = $1
                ORDER BY eval_date DESC, eval_run_id, expert_id ASC
                """,
                session_id,
            )
        return [
            StoredEvaluation.model_validate(
                {
                    "id": r["id"],
                    "session_id": r["session_id"],
                    "eval_run_id": r["eval_run_id"],
                    "eval_date": r["eval_date"],
                    "judge_model": r["judge_model"],
                    "judge_version": r["judge_version"],
                    "rubric_version": r["rubric_version"],
                    "expert_id": r["expert_id"],
                    # The jsonb column may arrive as JSON text or already
                    # decoded, depending on the connection's codecs.
                    "scores": json.loads(r["scores"])
                    if isinstance(r["scores"], str)
                    else r["scores"],
                    "comment": r["comment"],
                }
            )
            for r in rows
        ]

    async def evaluated_session_ids(
        self, judge_version: str, rubric_version: str, expected_experts: int = 3
    ) -> set[str]:
        """Sessions fully evaluated at the given judge/rubric versions.

        A session counts as evaluated only when at least `expected_experts`
        distinct expert rows exist for that (judge_version, rubric_version).

        Note: the count is taken over all of the session's rows at these
        versions, not per ``eval_run_id``, so experts spread across several
        partial runs also add up to a "complete" session.
        """
        pool = await self._get_pool()
        async with pool.acquire() as conn:
            rows = await conn.fetch(
                """
                SELECT session_id
                FROM evaluations
                WHERE judge_version = $1 AND rubric_version = $2
                GROUP BY session_id
                HAVING COUNT(DISTINCT expert_id) >= $3
                """,
                judge_version,
                rubric_version,
                expected_experts,
            )
        return {r["session_id"] for r in rows}