Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 119 additions & 33 deletions nemo_retriever/src/nemo_retriever/skill_eval/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,12 +92,13 @@ def _preflight_judge_endpoint(api_base: str, timeout: float = 5.0) -> None:
)


def _build_judge(cfg: dict) -> Optional[Any]:
"""Construct an ``LLMJudge`` from ``cfg['judge']`` or return ``None``.
def _build_judge(cfg: dict, *, manifest_path: Path) -> Optional[Any]:
"""Construct a ``JudgeContext`` from ``cfg['judge']`` or return ``None``.

Skips silently (with a console note) when the API key env var is unset, so
runs work end-to-end without network access. Import is deferred so the
``litellm`` extra isn't required when judging is disabled.
Skips silently (with a console note) when the API key env var is unset,
so runs work end-to-end without network access. Prompt paths default
to the manifest's parent directory if not overridden in the config;
the file must exist for any mode the run actually needs.
"""
judge_cfg = cfg.get("judge") or {}
if not judge_cfg.get("enabled", True):
Expand All @@ -109,25 +110,65 @@ def _build_judge(cfg: dict) -> Optional[Any]:
typer.echo(f"Judge disabled: ${api_key_env} is not set in the environment.")
return None
try:
from nemo_retriever.llm.clients.judge import LLMJudge
from nemo_retriever.llm.clients.litellm import LiteLLMClient
except ImportError as exc:
typer.echo(f"Judge disabled: failed to import LLMJudge ({exc}). Install nemo-retriever[llm].")
typer.echo(f"Judge disabled: failed to import LiteLLMClient ({exc}). Install nemo-retriever[llm].")
return None
from nemo_retriever.skill_eval.runner import JudgeContext

api_base = judge_cfg.get("api_base")
if api_base:
_preflight_judge_endpoint(str(api_base))
judge_kwargs: dict[str, Any] = {
"model": str(judge_cfg.get("model", "nvidia_nim/nvidia/llama-3.3-nemotron-super-49b-v1.5")),
"api_base": api_base,
"api_key": api_key,
}
if judge_cfg.get("temperature") is not None:
judge_kwargs["temperature"] = float(judge_cfg["temperature"])
if judge_cfg.get("max_tokens") is not None:
judge_kwargs["max_tokens"] = int(judge_cfg["max_tokens"])
judge = LLMJudge.from_kwargs(**judge_kwargs)
typer.echo(f"Judge enabled: model={judge.model}")
return judge

client = LiteLLMClient.from_kwargs(
model=str(judge_cfg.get("model", "nvidia_nim/nvidia/llama-3.3-nemotron-super-49b-v1.5")),
api_base=api_base,
api_key=api_key,
temperature=float(judge_cfg.get("temperature", 0.1)),
max_tokens=int(judge_cfg.get("max_tokens", 4096)),
)

legacy_judge = None
if judge_cfg.get("legacy_enabled", True):
try:
from nemo_retriever.llm.clients.judge import LLMJudge
except ImportError as exc:
typer.echo(f"Legacy judge disabled: failed to import LLMJudge ({exc}).")
else:
legacy_judge = LLMJudge.from_kwargs(
model=str(judge_cfg.get("model", "nvidia_nim/nvidia/llama-3.3-nemotron-super-49b-v1.5")),
api_base=api_base,
api_key=api_key,
temperature=float(judge_cfg.get("temperature", 0.1)),
max_tokens=int(judge_cfg.get("max_tokens", 4096)),
)

manifest_dir = Path(manifest_path).expanduser().resolve().parent
simple_path = judge_cfg.get("simple_prompt_path")
scenario_path = judge_cfg.get("scenario_prompt_path")
simple_resolved = (
Path(str(simple_path)).expanduser().resolve() if simple_path else manifest_dir / "llm_scorer_prompt.md"
)
scenario_resolved = (
Path(str(scenario_path)).expanduser().resolve()
if scenario_path
else manifest_dir / "llm_scenario_scorer_prompt.md"
)
ctx = JudgeContext(
client=client,
simple_prompt_path=str(simple_resolved) if simple_resolved.is_file() else None,
scenario_prompt_path=str(scenario_resolved) if scenario_resolved.is_file() else None,
legacy_judge=legacy_judge,
)
typer.echo(
"Judge enabled: model={m} simple_prompt={s} scenario_prompt={c} legacy={legacy}".format(
m=client.transport.model,
s=ctx.simple_prompt_path or "(missing)",
c=ctx.scenario_prompt_path or "(missing)",
legacy="on" if ctx.legacy_judge is not None else "off",
)
)
return ctx


def _build_trace_summarizer(cfg: dict) -> Optional[Any]:
Expand Down Expand Up @@ -345,6 +386,15 @@ def run_command(
"Defaults to config.query_parallelism, then 1 (linear session)."
),
),
limit_queries: Optional[int] = typer.Option(
None,
"--limit-queries",
min=1,
help=(
"Cap queries per domain to the first N entries (deterministic). "
"Useful for smoke-tests; omit for the full sweep."
),
),
) -> None:
"""Run the benchmark across the dataset's domains x selected conditions."""
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
Expand Down Expand Up @@ -378,6 +428,10 @@ def run_command(
raise typer.Exit(code=2)
by_domain = {d: by_domain[d] for d in wanted}

if limit_queries is not None:
by_domain = {d: es[:limit_queries] for d, es in by_domain.items()}
typer.echo(f"--limit-queries={limit_queries}: capping each domain to its first {limit_queries} entries.")

domain_order = sorted(by_domain.keys())
typer.echo(f"Domains in this run: {domain_order} ({sum(len(v) for v in by_domain.values())} entries total)")

Expand All @@ -403,7 +457,7 @@ def run_command(
raise typer.Exit(code=2)
testdata_prefixes = tuple(str(p) for p in testdata_prefixes_raw)

judge = _build_judge(cfg)
judge = _build_judge(cfg, manifest_path=Path(str(manifest_path)).expanduser().resolve())
summarizer = _build_trace_summarizer(cfg)

base_dir = str(artifacts_root) if artifacts_root else None
Expand Down Expand Up @@ -474,7 +528,16 @@ def run_command(
for r in results:
save_trial(r, session_dir)
kind = "setup" if r.is_setup else f"entry_id={r.entry_id} query_id={r.query_id}"
judge_str = "" if r.is_setup or r.judge_score is None else f" judge={r.judge_score}"
judge_parts: list[str] = []
if not r.is_setup:
if r.judge_score is not None:
judge_parts.append(f"judge={r.judge_score}")
elif any(v is not None for v in r.judge_subscores.values()):
n = sum(1 for v in r.judge_subscores.values() if v is not None)
judge_parts.append(f"judge_mode={r.judge_mode}/sub_n={n}")
if r.legacy_judge_score is not None:
judge_parts.append(f"legacy={r.legacy_judge_score}")
judge_str = (" " + " ".join(judge_parts)) if judge_parts else ""
cost_str = f"${r.total_cost_usd:.3f}" if r.cost_available else "n/a"
trace_str = f" trace={r.compact_trace_path}" if r.compact_trace_path else ""
typer.echo(
Expand All @@ -500,21 +563,34 @@ def run_command(
if judge is not None:
typer.echo("\nLLM-as-judge scores (mean over query turns, 0-5 scale):")
for cond in selected:
scored: list[int] = []
simple_scored: list[int] = []
scenario_scored = 0
legacy_scored: list[int] = []
errored = 0
for domain in domain_order:
for r in results_by_key.get((agent, cond, domain), []):
if r.is_setup:
continue
has_subscores = any(v is not None for v in r.judge_subscores.values())
if r.judge_score is not None:
scored.append(int(r.judge_score))
simple_scored.append(int(r.judge_score))
elif has_subscores:
scenario_scored += 1
elif r.judge_error:
errored += 1
if scored:
mean_score = sum(scored) / len(scored)
typer.echo(f" {agent}/{cond}: mean={mean_score:.2f} n={len(scored)} errors={errored}")
else:
typer.echo(f" {agent}/{cond}: no scores errors={errored} (check judge config / litellm install)")
if r.legacy_judge_score is not None:
legacy_scored.append(int(r.legacy_judge_score))
parts: list[str] = []
if simple_scored:
parts.append(f"simple mean={sum(simple_scored) / len(simple_scored):.2f} n={len(simple_scored)}")
if scenario_scored:
parts.append(f"scenario_scored={scenario_scored}")
if legacy_scored:
parts.append(f"legacy mean={sum(legacy_scored) / len(legacy_scored):.2f} n={len(legacy_scored)}")
parts.append(f"errors={errored}")
if not simple_scored and not scenario_scored and not legacy_scored:
parts.append("(check judge config / litellm install)")
typer.echo(f" {agent}/{cond}: " + " ".join(parts))

json_path, md_path = write_summary(
session_dir=session_dir,
Expand All @@ -537,8 +613,9 @@ def _needs_rescore(trial: dict[str, Any]) -> bool:
judge_error = trial.get("judge_error") or ""
if judge_error in UNSCORABLE_JUDGE_ERRORS:
return False
score = trial.get("judge_score")
if score is None:
sub_scores = trial.get("judge_subscores") or {}
scored = trial.get("judge_score") is not None or any(v is not None for v in sub_scores.values())
if not scored:
return True
if judge_error:
return True
Expand Down Expand Up @@ -621,7 +698,7 @@ def rescore_command(
entries = load_eval_manifest(Path(str(manifest_path)).expanduser().resolve())
entries_by_id = {e.entry_id: e for e in entries}

judge = _build_judge(cfg)
judge = _build_judge(cfg, manifest_path=Path(str(manifest_path)).expanduser().resolve())
if judge is None:
typer.echo("Error: judge is not configured (see messages above). Cannot rescore.", err=True)
raise typer.Exit(code=2)
Expand Down Expand Up @@ -659,15 +736,24 @@ def rescore_command(
result.judge_score = None
result.judge_reasoning = ""
result.judge_error = ""
result.judge_mode = ""
result.judge_subscores = {}
result.judge_flags = {}
result.judge_lists = {}

_apply_judge(judge, entry, result)

raw.update(asdict(result))
path.write_text(json.dumps(raw, indent=2) + "\n", encoding="utf-8")

if result.judge_score is not None:
scored_ok = result.judge_score is not None or any(v is not None for v in result.judge_subscores.values())
if scored_ok:
rescored += 1
typer.echo(f" {path.name}: entry_id={result.entry_id} judge={result.judge_score}")
if result.judge_mode == "simple" and result.judge_score is not None:
typer.echo(f" {path.name}: entry_id={result.entry_id} judge={result.judge_score}")
else:
n_sub = sum(1 for v in result.judge_subscores.values() if v is not None)
typer.echo(f" {path.name}: entry_id={result.entry_id} mode={result.judge_mode} sub_scores={n_sub}")
elif result.judge_error in UNSCORABLE_JUDGE_ERRORS:
unscorable += 1
typer.echo(f" {path.name}: entry_id={result.entry_id} unscorable ({result.judge_error})")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,20 @@ judge:
api_key_env: NVIDIA_API_KEY
temperature: 0.1
max_tokens: 4096
# On-disk path to the simple QA judge prompt (batch_1's ``llm_scorer_prompt.md``).
# When null, the CLI falls back to ``llm_scorer_prompt.md`` in the manifest's parent directory.
# Required only if the manifest contains entries WITHOUT a ``scoring_mode`` field.
simple_prompt_path: null
# On-disk path to the scenario-aware judge prompt (batch_2's ``llm_scenario_scorer_prompt.md``).
# When null, the CLI falls back to ``llm_scenario_scorer_prompt.md`` in the manifest's parent directory.
# Required only if the manifest contains entries WITH a ``scoring_mode`` field.
scenario_prompt_path: null
# When true (default), ALSO run the legacy hardcoded LLMJudge on every
# scorable trial (one with both a non-empty ground_truth_answer and a
# non-empty final_answer). This produces an additional 1-5 score that is
# directly comparable to runs that pre-date the scenario-aware judge. Set to
# false to skip the second judge call and halve judge LLM cost.
legacy_enabled: true

# ---------------------------------------------------------------------------
# Tool-use summarizer
Expand Down
17 changes: 17 additions & 0 deletions nemo_retriever/src/nemo_retriever/skill_eval/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,13 @@ class DatasetEntry(BaseModel):
ground_truth_answer: str = ""
domain: str = ""
domain_label: str = ""
scoring_mode: str = ""
category: str = ""
phase: str = ""
expected_action: str = ""
expected_output_shape: str = ""
validation_signal: str = ""
raw_answers: list[str] = []


def _select_prompt(candidates: list[dict[str, Any]], selected_variant: int | None) -> str:
Expand Down Expand Up @@ -121,6 +128,9 @@ def load_eval_manifest(path: Path) -> list[DatasetEntry]:
continue
pages.append(GroundTruthPage(doc_id=str(doc_id), page_number=int(page), score=int(p.get("score") or 1)))

raw_answers = item.get("raw_answers") or []
if not isinstance(raw_answers, list):
raw_answers = []
entries.append(
DatasetEntry(
entry_id=idx,
Expand All @@ -132,6 +142,13 @@ def load_eval_manifest(path: Path) -> list[DatasetEntry]:
ground_truth_answer=str(item.get("answer") or ""),
domain=domain,
domain_label=domain_label,
scoring_mode=str(item.get("scoring_mode") or ""),
category=str(item.get("category") or ""),
phase=str(item.get("phase") or ""),
expected_action=str(item.get("expected_action") or ""),
expected_output_shape=str(item.get("expected_output_shape") or ""),
validation_signal=str(item.get("validation_signal") or ""),
raw_answers=[str(x) for x in raw_answers],
)
)
return entries
Expand Down
Loading
Loading