Newer
Older
navi-1 / debug / eval / judge.py
"""Judge orchestration: render a session, fan out across 3 experts, average."""

from __future__ import annotations

import asyncio
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Iterable
from uuid import uuid4

import yaml

from navi.core.session import Session
from navi.llm.base import LLMBackend, Message
from navi.profiles.base import AgentProfile

from .schema import AXIS_NAMES, EvalRunMetadata, EvalScores, ExpertResult


# Pinned versions for this rubric / judge generation. Bumping either forces
# re-evaluation of the whole archive (or a parallel run, depending on policy).
JUDGE_VERSION: str = "v1"
RUBRIC_VERSION: str = "v1"

EXPERT_IDS: tuple[str, ...] = ("strict_critic", "pragmatist", "tech_lead")

_PROMPTS_DIR = Path(__file__).parent / "prompts"


@dataclass(frozen=True)
class RenderedSession:
    """Immutable payload that the judge receives as its user-message body.

    ``header`` is the session-metadata block; ``transcript`` is the rendered
    conversation (with any appendix already attached by the renderer).
    """

    header: str  # metadata block shown above the transcript
    transcript: str  # conversation text the experts actually score

    def as_user_message(self) -> str:
        """Return header and transcript joined under a fixed section divider."""
        divider = "=== Session transcript ==="
        return f"{self.header}\n\n{divider}\n{self.transcript}"


# ── Prompt + rubric loading ──────────────────────────────────────────────


def load_rubric() -> dict:
    """Load the rubric for the pinned ``RUBRIC_VERSION`` as a dict.

    Reads ``prompts/rubric_<RUBRIC_VERSION>.yaml`` with ``yaml.safe_load``.

    Returns:
        The parsed rubric mapping.

    Raises:
        OSError: if the file cannot be read (e.g. it is missing).
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document parses but is not a mapping.
    """
    path = _PROMPTS_DIR / f"rubric_{RUBRIC_VERSION}.yaml"
    data = yaml.safe_load(path.read_text(encoding="utf-8"))
    # safe_load yields None for an empty file and may yield a list/scalar;
    # fail here with a clear message instead of a TypeError in the renderer.
    if not isinstance(data, dict):
        raise ValueError(
            f"rubric {path} must be a YAML mapping, got {type(data).__name__}"
        )
    return data


def load_expert_prompt(expert_id: str) -> str:
    """Return the system prompt stored in ``prompts/expert_<expert_id>.txt``.

    Raises:
        ValueError: if ``expert_id`` is not one of ``EXPERT_IDS``.
        OSError: if the prompt file cannot be read.
    """
    is_known = expert_id in EXPERT_IDS
    if not is_known:
        raise ValueError(f"unknown expert_id: {expert_id}")
    prompt_path = _PROMPTS_DIR / f"expert_{expert_id}.txt"
    return prompt_path.read_text(encoding="utf-8")


def render_rubric_for_prompt(rubric: dict) -> str:
    """Flatten the rubric mapping into compact text for the judge's user message.

    Expects ``rubric["version"]`` and ``rubric["axes"]``. ``axes`` maps an axis
    name to a spec with ``description``, an optional ``nullable`` flag, and a
    ``levels`` list. Each level has ``label`` and ``what`` plus an optional
    ``score``; a score is shown as a "reference" value when present.

    By rubric design, only the `typical` level carries a score. It gives judges
    a single calibration anchor and is deliberately not a multiple of 5, to
    discourage round-number snapping.
    """
    nullable_note = "(nullable: score this null when the mechanic was not used)"
    parts: list[str] = [f"=== Rubric {rubric['version']} ==="]
    for name, spec in rubric["axes"].items():
        parts += [f"\n## {name}", spec["description"].strip()]
        if spec.get("nullable"):
            parts.append(nullable_note)
        for level in spec["levels"]:
            anchor = level.get("score")
            suffix = "" if anchor is None else f" (reference ≈ {anchor})"
            parts.append(f"  • {level['label']}{suffix} — {level['what']}")
    return "\n".join(parts)


# ── Session rendering ────────────────────────────────────────────────────


# Maps a feedback value to the emoji inlined after "ASSISTANT" in the transcript:
# 1 -> like, -1 -> dislike. Any other value (including the 0 fallback used when a
# message has no feedback) has no entry, so .get() yields None and no emoji is shown.
_REACTION: dict[int, str] = {1: "👍", -1: "👎"}


def _format_duration(start: datetime, end: datetime) -> str:
    delta = end - start
    total = int(delta.total_seconds())
    hours, rem = divmod(total, 3600)
    minutes, seconds = divmod(rem, 60)
    if hours:
        return f"{hours}h{minutes:02d}m"
    if minutes:
        return f"{minutes}m{seconds:02d}s"
    return f"{seconds}s"


def _render_header(
    session: Session,
    profile: AgentProfile | None,
    feedback_by_index: dict[int, int],
) -> str:
    """Build the ``=== Session metadata ===`` block shown above the transcript.

    The block covers:
    - session identity;
    - when ``profile`` is given, its description, model list, enabled
      mechanics and sampling settings;
    - timestamps and duration;
    - per-role message counts;
    - 👍/👎 totals, counted from ``feedback_by_index`` values equal to 1 / -1.
    """
    reactions = list(feedback_by_index.values())
    n_likes = reactions.count(1)
    n_dislikes = reactions.count(-1)

    # Seed the four roles printed below so they show 0 when absent; other
    # roles are tallied too but not printed.
    role_counts = dict.fromkeys(("user", "assistant", "tool", "system"), 0)
    for msg in session.messages:
        role_counts[msg.role] = role_counts.get(msg.role, 0) + 1

    out = ["=== Session metadata ===", f"Session id: {session.id}"]
    if session.name:
        out.append(f"Name: {session.name}")
    out.append(f"Profile: {session.profile_id}")
    if profile is not None:
        out.append(f"Profile description: {profile.description}")
        if profile.model:
            out.append(f"Model priority list: {', '.join(profile.model)}")
        # (profile attribute, label shown to the judge), in display order.
        mechanic_flags = (
            ("planning_enabled", "planning"),
            ("planning_phase2_enabled", "phase2-review"),
            ("planning_phase3_enabled", "phase3-plan"),
            ("think_enabled", "think"),
            ("adaptive_replan_enabled", "adaptive-replan"),
            ("step_validation_enabled", "step-validation"),
        )
        active = [label for attr, label in mechanic_flags if getattr(profile, attr)]
        out.append(f"Active mechanics: {', '.join(active) or '(none)'}")
        out.append(f"Max iterations: {profile.max_iterations}")
        out.append(
            f"Temperature: {profile.temperature} | top_k: {profile.top_k} | top_p: {profile.top_p}"
        )
    created, last = session.created_at, session.last_active
    out.append(f"Started: {created.isoformat()}")
    out.append(f"Last active: {last.isoformat()}")
    out.append(f"Duration: {_format_duration(created, last)}")
    out.append(
        f"Messages: total={len(session.messages)} "
        f"user={role_counts['user']} assistant={role_counts['assistant']} "
        f"tool={role_counts['tool']} system={role_counts['system']}"
    )
    out.append(f"User feedback: 👍 {n_likes} | 👎 {n_dislikes}")
    return "\n".join(out)


def _truncate(text: str, limit: int = 4000) -> str:
    if len(text) <= limit:
        return text
    return text[:limit] + f"\n[...truncated, {len(text) - limit} more chars]"


def _format_assistant_meta(m) -> str | None:
    """Build the ``[meta] ...`` line for an assistant message, or None if empty.

    Lists ``elapsed=<s>s``, ``tool_calls=<n>`` and ``tokens=<n>`` for each of
    ``m.elapsed_seconds`` / ``m.tool_call_count`` / ``m.token_count`` that is
    not None.
    """
    parts: list[str] = []
    if m.elapsed_seconds is not None:
        parts.append(f"elapsed={m.elapsed_seconds}s")
    if m.tool_call_count is not None:
        parts.append(f"tool_calls={m.tool_call_count}")
    if m.token_count is not None:
        parts.append(f"tokens={m.token_count}")
    return f"[meta] {' '.join(parts)}" if parts else None


def _render_transcript(
    session: Session,
    feedback_by_index: dict[int, int],
) -> str:
    """Render ``session.messages`` in order as the judge-facing transcript.

    Each block is prefixed with ``[#i]``. Here ``i`` is the message's position
    in ``session.messages``, which is the same key ``feedback_by_index`` uses.
    A 👍/👎 reaction is inlined only on assistant blocks, looked up by that
    assistant message's index.

    Tool results that directly follow an assistant message with
    ``tool_calls`` are folded into that assistant's block. A later assistant
    message (e.g. the final reply) gets its own block, index and reaction.

    Blocks are separated by blank lines; messages with an unrecognized role
    are skipped.
    """
    messages = session.messages
    out: list[str] = []
    i = 0
    n = len(messages)
    while i < n:
        m = messages[i]
        idx_marker = f"[#{i}]"
        reaction = _REACTION.get(feedback_by_index.get(i, 0))

        if m.is_compression:
            out.append(f"{idx_marker} [Context compression event] {m.content or ''}")
            i += 1
            continue

        if m.is_summary:
            # Summaries are inserted into context, not the displayed history.
            # If one ended up in messages, surface it but mark it explicitly so
            # the judge knows it isn't original work.
            out.append(f"{idx_marker} [Compressed history block — not original]")
            out.append(_truncate(m.content or ""))
            i += 1
            continue

        if m.role == "system":
            # Bare system messages aren't usually saved to .messages, but if
            # they are, keep them visible.
            out.append(f"{idx_marker} SYSTEM: {_truncate(m.content or '')}")
            i += 1
            continue

        if m.role == "user":
            block = [f"{idx_marker} USER:"]
            if m.content:
                block.append(_truncate(m.content))
            if m.images:
                block.append(f"[user attached {len(m.images)} image(s)]")
            out.append("\n".join(block))
            i += 1
            continue

        if m.role == "assistant":
            head = f"{idx_marker} ASSISTANT"
            if reaction:
                head += f" {reaction}"
            block = [head + ":"]

            if m.is_plan:
                block.append("[plan]")
                block.append(_truncate(m.content or ""))
                out.append("\n".join(block))
                i += 1
                continue

            if m.thinking:
                block.append("[thinking]")
                block.append(_truncate(m.thinking))

            # Render text written alongside tool calls (e.g. a preamble) as
            # well; the judge should see everything the assistant said.
            if m.content:
                block.append(_truncate(m.content))

            i += 1
            if m.tool_calls:
                for tc in m.tool_calls:
                    args_preview = json.dumps(tc.arguments, ensure_ascii=False)
                    if len(args_preview) > 300:
                        # Mark the cut so the judge doesn't read it as complete.
                        args_preview = args_preview[:300] + "…"
                    block.append(f"[tool_call: {tc.name}] {args_preview}")
                # Fold in the tool results that directly follow this message.
                while i < n and messages[i].role == "tool":
                    tr = messages[i]
                    block.append(
                        f"[tool_result: {tr.name or '?'} (id={tr.tool_call_id or '?'})]\n"
                        + _truncate(tr.content or "")
                    )
                    i += 1

            meta = _format_assistant_meta(m)
            if meta is not None:
                block.append(meta)
            out.append("\n".join(block))
            continue

        if m.role == "tool":
            # Stray tool message without preceding tool_calls — render bare.
            out.append(
                f"{idx_marker} [tool_result orphan: {m.name or '?'}]\n"
                + _truncate(m.content or "")
            )
            i += 1
            continue

        # Unknown role — keep moving.
        i += 1

    return "\n\n".join(out)


def _render_planning_appendix(session: Session) -> str:
    """Render ``session.planning_logs`` as a text appendix, or "" if there are none.

    Each log is read as a dict with ``timestamp``, ``result`` and ``phases``.
    ``phases`` maps a phase id to a dict holding an ``output`` text. Phase
    outputs are clipped to 2000 chars.
    """
    if not session.planning_logs:
        return ""
    lines = ["=== Planning logs ==="]
    for turn_no, log in enumerate(session.planning_logs, 1):
        ts = log.get("timestamp", "?")
        result = log.get("result", "?")
        lines.append(f"\n--- Turn {turn_no} (timestamp={ts}, classification={result}) ---")
        # Logs are loosely structured; treat explicit nulls like missing keys
        # so one sparse entry can't crash rendering of the whole session.
        phases = log.get("phases") or {}
        for phase_id, phase in phases.items():
            output = phase.get("output") or ""
            lines.append(f"\n[phase {phase_id}]")
            lines.append(_truncate(output, limit=2000))
    return "\n".join(lines)


def render_session(
    session: Session,
    feedback_by_index: dict[int, int] | None = None,
    profile: AgentProfile | None = None,
) -> RenderedSession:
    """Produce the exact text the judge will read for ``session``.

    - ``header``: session metadata and message/feedback counts.
    - ``transcript``: every message in order, with reactions inlined and tool
      results folded under the calling assistant message. Compressed-history
      summaries are not substituted in, so the judge sees the original
      messages.
    - If the session has planning logs, they are appended to the transcript
      after a blank line.
    """
    feedback = feedback_by_index if feedback_by_index else {}
    header = _render_header(session, profile, feedback)
    transcript = _render_transcript(session, feedback)
    appendix = _render_planning_appendix(session)
    full_transcript = f"{transcript}\n\n{appendix}" if appendix else transcript
    return RenderedSession(header=header, transcript=full_transcript)


# ── Expert call ──────────────────────────────────────────────────────────


async def run_expert(
    *,
    expert_id: str,
    rendered: RenderedSession,
    rubric_text: str,
    llm: LLMBackend,
    model: str | list[str],
) -> ExpertResult:
    """Have one expert score ``rendered`` and return its validated verdict.

    The first attempt runs at temperature 0.5. If parsing or validation fails
    (json.JSONDecodeError / ValueError), the invalid reply is echoed back with
    a corrective instruction, and one more attempt is made at temperature 0.1.
    An error from that second parse propagates to the caller.
    """

    async def ask(conversation: list[Message], temperature: float) -> str:
        # Single non-tool, non-thinking completion; returns "" for empty content.
        reply = await llm.complete(
            messages=conversation,
            tools=None,
            temperature=temperature,
            model=model,
            think=False,
        )
        return reply.content or ""

    system_text = load_expert_prompt(expert_id)
    user_text = f"{rubric_text}\n\n{rendered.as_user_message()}"
    conversation = [
        Message(role="system", content=system_text),
        Message(role="user", content=user_text),
    ]

    first_raw = await ask(conversation, 0.5)
    try:
        return parse_expert_json(first_raw, expert_id)
    except (json.JSONDecodeError, ValueError) as first_err:
        retry_conversation = [
            *conversation,
            Message(role="assistant", content=first_raw),
            Message(
                role="user",
                content=(
                    f"Your previous output was invalid: {first_err}. "
                    "Return ONLY the JSON object matching the required schema, "
                    "no prose, no code fences, no extra fields."
                ),
            ),
        ]
        retry_raw = await ask(retry_conversation, 0.1)
        return parse_expert_json(retry_raw, expert_id)


# ── Run orchestration ────────────────────────────────────────────────────


async def evaluate_session(
    *,
    session: Session,
    feedback_by_index: dict[int, int] | None,
    profile: AgentProfile | None,
    llm: LLMBackend,
    model: str | list[str],
) -> tuple[EvalRunMetadata, list[ExpertResult]]:
    """Score one session with every expert in ``EXPERT_IDS`` concurrently.

    Returns ``(metadata, results)``, with results in ``EXPERT_IDS`` order.
    ``metadata.judge_model`` is the model string, the first entry of a model
    list, or "unknown".

    If any expert still fails after its retry, the exception propagates out of
    ``asyncio.gather`` and the caller decides whether to skip the session.
    (With gather's defaults, the other experts are not cancelled.)
    """
    rubric_text = render_rubric_for_prompt(load_rubric())
    rendered = render_session(session, feedback_by_index, profile)

    if isinstance(model, str):
        judge_model = model
    elif isinstance(model, list) and model:
        judge_model = model[0]
    else:
        judge_model = "unknown"
    metadata = new_run_metadata(judge_model=judge_model)

    results = await asyncio.gather(
        *(
            run_expert(
                expert_id=expert,
                rendered=rendered,
                rubric_text=rubric_text,
                llm=llm,
                model=model,
            )
            for expert in EXPERT_IDS
        )
    )
    return metadata, results


def average_scores(results: Iterable[ExpertResult]) -> EvalScores:
    """Mean across experts. Nullable axes average over non-null values only;
    if every expert returned null for an axis, the mean stays null."""
    sums: dict[str, int] = {a: 0 for a in AXIS_NAMES}
    counts: dict[str, int] = {a: 0 for a in AXIS_NAMES}
    for r in results:
        for a in AXIS_NAMES:
            v = getattr(r.scores, a)
            if v is None:
                continue
            sums[a] += v
            counts[a] += 1
    averaged: dict[str, int | None] = {}
    for a in AXIS_NAMES:
        averaged[a] = round(sums[a] / counts[a]) if counts[a] else None
    return EvalScores(**averaged)


def new_run_metadata(judge_model: str) -> EvalRunMetadata:
    """Create metadata for a new eval run.

    The metadata gets a fresh UUID4 run id, the current UTC time, the given
    judge model, and this module's pinned judge and rubric versions.
    """
    run_id = uuid4()
    stamped_at = datetime.now(timezone.utc)
    return EvalRunMetadata(
        eval_run_id=run_id,
        eval_date=stamped_at,
        judge_model=judge_model,
        judge_version=JUDGE_VERSION,
        rubric_version=RUBRIC_VERSION,
    )


# ── Convenience: parse JSON output that may have stray fences ────────────


def parse_expert_json(raw: str, expected_expert_id: str) -> ExpertResult:
    """Parse one expert's raw reply into a validated ``ExpertResult``.

    Tolerates surrounding whitespace and a single wrapping code fence
    (optionally with a language tag such as ```json). Apart from that, the
    text must be a bare JSON object.

    Raises:
        json.JSONDecodeError: if the unfenced text is not valid JSON.
        ValueError: if the JSON is not an object, or its ``expert_id`` does
            not equal ``expected_expert_id``. Schema errors come from
            ``ExpertResult.model_validate``.
    """
    text = raw.strip()
    if text.startswith("```"):
        # Drop the opening fence (with optional language tag) and the closing fence.
        first_nl = text.find("\n")
        text = text[first_nl + 1:] if first_nl != -1 else text[3:]
        if text.rstrip().endswith("```"):
            text = text.rstrip()[:-3]
    data = json.loads(text)
    # A JSON array/scalar would otherwise hit AttributeError on .get(). That
    # error isn't caught by run_expert's retry handler, so the corrective
    # retry would be skipped.
    if not isinstance(data, dict):
        raise ValueError(f"expected a JSON object, got {type(data).__name__}")
    if data.get("expert_id") != expected_expert_id:
        raise ValueError(
            f"expert_id mismatch: expected {expected_expert_id!r}, got {data.get('expert_id')!r}"
        )
    return ExpertResult.model_validate(data)