From 59a1c59dbccfd7b07aa9fbf9e358c63717510a48 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Mon, 11 May 2026 15:40:40 -0700 Subject: [PATCH 1/6] agentic retrieval init Signed-off-by: Mahika Wason --- .../src/nemo_retriever/agentic/README.md | 140 +++++++ .../src/nemo_retriever/agentic/__init__.py | 21 + .../src/nemo_retriever/agentic/retrieval.py | 368 ++++++++++++++++++ .../graph/react_agent_operator.py | 90 ++++- .../graph/selection_agent_operator.py | 59 +++ .../src/nemo_retriever/pipeline/__main__.py | 198 ++++++++-- nemo_retriever/tests/test_agentic_eval.py | 201 ++++++++++ 7 files changed, 1043 insertions(+), 34 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/agentic/README.md create mode 100644 nemo_retriever/src/nemo_retriever/agentic/__init__.py create mode 100644 nemo_retriever/src/nemo_retriever/agentic/retrieval.py create mode 100644 nemo_retriever/tests/test_agentic_eval.py diff --git a/nemo_retriever/src/nemo_retriever/agentic/README.md b/nemo_retriever/src/nemo_retriever/agentic/README.md new file mode 100644 index 0000000000..6fad0fe2b0 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/agentic/README.md @@ -0,0 +1,140 @@ +# Agentic Retrieval Mode + +Agentic retrieval mode is a retrieval strategy for the main NeMo Retriever +pipeline. It is not a separate evaluation benchmark. The evaluation mode still +answers "how do we score results?", while retrieval mode answers "how do we +produce ranked results?". + +The first integration supports: + +```bash +--evaluation-mode recall --retrieval-mode agentic +``` + +In this mode, the pipeline ingests documents and uploads them to the configured +vector database exactly as it does today. The difference starts at evaluation +time: instead of one standard dense retrieval pass, an LLM-driven graph +retrieval pipeline searches the same vector database and produces ranked +results that are scored with recall-style metrics. 
+ +## Graph Pipeline + +The agentic retriever composes the existing graph operators: + +```mermaid +flowchart LR + QueryCsv[Query CSV] --> Normalize[Normalize Queries] + Normalize --> ReactAgent[ReActAgentOperator] + ReactAgent --> RetrieverTool[Retriever Tool] + RetrieverTool --> VDB[Vector DB] + ReactAgent --> RRFAggregator[RRFAggregatorOperator] + RRFAggregator --> SelectionAgent[SelectionAgentOperator] + SelectionAgent --> RankedResults[Ranked Results] + RankedResults --> Metrics[Recall Metrics] +``` + +`ReActAgentOperator` runs an LLM-driven ReAct loop per query. The agent can +think, issue retrieval subqueries, inspect retrieved candidates, and decide +when it has enough evidence. + +`RRFAggregatorOperator` combines candidates from multiple retrieval steps using +reciprocal rank fusion. This gives more weight to documents that appear near +the top across multiple search attempts. + +`SelectionAgentOperator` runs a final LLM-based selection pass over the fused +candidate set and emits the ranked document IDs used for scoring. + +## CLI Integration + +The main CLI adds a retrieval strategy option: + +```bash +--retrieval-mode standard|agentic +``` + +`--evaluation-mode` remains evaluation-focused: + +```bash +--evaluation-mode recall|beir|qa +``` + +Supported combinations in the first integration: + +- `--evaluation-mode=recall --retrieval-mode=standard` +- `--evaluation-mode=recall --retrieval-mode=agentic` +- `--evaluation-mode=qa --retrieval-mode=standard` + +Unsupported initially: + +- `--evaluation-mode=qa --retrieval-mode=agentic` +- BEIR through the generic pipeline path remains unchanged and unavailable, as + it is in the existing pipeline. + +## Agentic Options + +`--agentic-llm-model` sets the chat model used by both `ReActAgentOperator` and +`SelectionAgentOperator`. When `--retrieval-mode=agentic`, it must be provided either with this option or through the `NEMO_RETRIEVER_AGENTIC_LLM_MODEL` environment variable. + +`--agentic-invoke-url` optionally sets the OpenAI-compatible chat completions +endpoint used by the agent operators; it can also be set through the `NEMO_RETRIEVER_AGENTIC_INVOKE_URL` environment variable. 
If omitted, the operators use their +default endpoint. + +`--agentic-react-max-steps` controls the maximum ReAct loop iterations per +query. The default is `10`. + +## Wrapped Standard Retrieval + +Every agent `retrieve` tool call delegates to the existing +`nemo_retriever.retriever.Retriever`. That means agentic mode searches the same +vector database populated by ingestion and reuses the same retrieval settings +where possible. + +Existing options reused by the wrapped retriever: + +- `--api-key`: authentication for agentic LLM calls and remote services unless + a more specific key is provided. +- `--embed-model-name`, `--embed-invoke-url`, `--local-query-embed-backend`, + `--local-hf-batch-size`: query embedding configuration. +- `--reranker`, `--reranker-model-name`, `--reranker-invoke-url`, + `--reranker-api-key`, `--local-reranker-backend`: optional reranking inside + the wrapped retriever. + +The first integration intentionally keeps the lower-level agentic retrieval +parameters fixed: + +- retriever top-k: `10` +- target top-k: `10` +- selection top-k: `10` +- query concurrency: `1` +- text truncation: `500` +- max tokens: provider default +- parallel tool calls: disabled + +## Examples + +Local in-process run: + +```bash +retriever pipeline run ./data/bo767 \ + --run-mode inprocess \ + --evaluation-mode recall \ + --retrieval-mode agentic \ + --query-csv ./data/bo767_query_gt.csv \ + --agentic-llm-model meta/llama-3.3-70b-instruct \ + --api-key os.environ/NVIDIA_API_KEY +``` + +Batch run with remote embedding and agent endpoints: + +```bash +retriever pipeline run ./data/bo767 \ + --run-mode batch \ + --evaluation-mode recall \ + --retrieval-mode agentic \ + --query-csv ./data/bo767_query_gt.csv \ + --embed-invoke-url http://localhost:8000/v1 \ + --agentic-invoke-url http://localhost:9000/v1/chat/completions \ + --agentic-llm-model meta/llama-3.3-70b-instruct \ + --agentic-react-max-steps 5 \ + --api-key os.environ/NVIDIA_API_KEY +``` diff --git 
a/nemo_retriever/src/nemo_retriever/agentic/__init__.py b/nemo_retriever/src/nemo_retriever/agentic/__init__.py new file mode 100644 index 0000000000..f0536322bc --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/agentic/__init__.py @@ -0,0 +1,21 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Agentic retrieval utilities.""" + +from nemo_retriever.agentic.retrieval import ( + AgenticRetrievalConfig, + AgenticRetriever, + build_beir_run_from_agentic_result, + build_qrels, + run_agentic_recall_evaluation, +) + +__all__ = [ + "AgenticRetrievalConfig", + "AgenticRetriever", + "build_beir_run_from_agentic_result", + "build_qrels", + "run_agentic_recall_evaluation", +] diff --git a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py new file mode 100644 index 0000000000..cef60e9321 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py @@ -0,0 +1,368 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Graph-backed agentic retrieval mode. + +The implementation is intentionally additive: it composes the existing graph +operators and wraps :class:`nemo_retriever.retriever.Retriever` without changing +the standard retrieval path. 
+""" + +from __future__ import annotations + +import logging +import threading +import time +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Optional, Sequence + +import pandas as pd + +from nemo_retriever.graph.abstract_operator import AbstractOperator +from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL +from nemo_retriever.recall.beir import compute_beir_metrics +from nemo_retriever.recall.core import ( + _hit_to_audio_segment_key, + _normalize_pdf_name, + _normalize_query_df, +) +from nemo_retriever.retriever import Retriever + +logger = logging.getLogger(__name__) + +AGENTIC_TOP_K = 10 +AGENTIC_NUM_CONCURRENT = 1 +AGENTIC_TEXT_TRUNCATION = 500 +AGENTIC_PARALLEL_TOOL_CALLS = False +AGENTIC_RRF_K = 60 +AGENTIC_REACT_MAX_STEPS = 10 + + +class AgenticQueryInputOperator(AbstractOperator): + """Adapt ``Retriever(graph=...)`` input DataFrames to agentic query schema.""" + + def preprocess(self, data: Any, **kwargs: Any) -> pd.DataFrame: + _ = kwargs + if not isinstance(data, pd.DataFrame): + raise TypeError(f"AgenticQueryInputOperator expects a pd.DataFrame, got {type(data).__name__}.") + return data.copy() + + def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + _ = kwargs + out = data.copy() + if "query_text" not in out.columns: + if "query" in out.columns: + out["query_text"] = out["query"].astype(str) + elif "text" in out.columns: + out["query_text"] = out["text"].astype(str) + else: + raise ValueError("Agentic query graph input requires 'query_text', 'query', or 'text'.") + if "query_id" not in out.columns: + out["query_id"] = [str(idx) for idx in range(len(out.index))] + return out[["query_id", "query_text"]] + + def postprocess(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: + _ = kwargs + return data + + +class AgenticSelectionOutputOperator(AbstractOperator): + """Convert final agentic selection DataFrame to ``Retriever`` hit-list output.""" + + def preprocess(self, data: 
Any, **kwargs: Any) -> pd.DataFrame: + _ = kwargs + if not isinstance(data, pd.DataFrame): + raise TypeError(f"AgenticSelectionOutputOperator expects a pd.DataFrame, got {type(data).__name__}.") + return data.copy() + + def process(self, data: pd.DataFrame, **kwargs: Any) -> list[list[dict[str, Any]]]: + _ = kwargs + if data.empty: + return [] + required = {"query_id", "doc_id", "rank"} + missing = required - set(data.columns) + if missing: + raise ValueError(f"Agentic selection output missing required columns: {sorted(missing)}") + + hits: list[list[dict[str, Any]]] = [] + for _query_id, group in data.groupby("query_id", sort=False): + query_hits: list[dict[str, Any]] = [] + for _, row in group.sort_values("rank").iterrows(): + hit = row.to_dict() + doc_id = str(hit.get("doc_id", "")) + if doc_id and not hit.get("pdf_page"): + hit["pdf_page"] = doc_id + query_hits.append(hit) + hits.append(query_hits) + return hits + + def postprocess(self, data: list[list[dict[str, Any]]], **kwargs: Any) -> list[list[dict[str, Any]]]: + _ = kwargs + return data + + +@dataclass(frozen=True) +class AgenticRetrievalConfig: + """Configuration for graph-backed agentic retrieval.""" + + vdb_op: str = "lancedb" + vdb_kwargs: dict[str, Any] = field(default_factory=dict) + query_embedder: str = VL_EMBED_MODEL + embedding_endpoint: Optional[str] = None + embedding_api_key: str = "" + embedding_use_grpc: Optional[bool] = None + local_hf_batch_size: int = 32 + local_query_embed_backend: str = "hf" + reranker: Optional[str] = None + reranker_endpoint: Optional[str] = None + reranker_api_key: str = "" + local_reranker_backend: str = "vllm" + embed_modality: str = "text" + llm_model: str = "" + invoke_url: Optional[str] = None + api_key: Optional[str] = None + react_max_steps: int = AGENTIC_REACT_MAX_STEPS + + def __post_init__(self) -> None: + if not str(self.llm_model).strip(): + raise ValueError("Agentic retrieval requires a non-empty llm_model.") + if int(self.react_max_steps) < 1: + raise 
ValueError("react_max_steps must be >= 1.") + + +class AgenticRetriever: + """Run graph-backed agentic retrieval over query IDs and query texts.""" + + def __init__(self, cfg: AgenticRetrievalConfig, *, match_mode: str = "pdf_page") -> None: + self._cfg = cfg + self._match_mode = str(match_mode) + self._retriever = Retriever( + vdb_kwargs={ + "vdb_op": str(cfg.vdb_op), + "vdb_kwargs": dict(cfg.vdb_kwargs or {}), + }, + embed_kwargs={ + "model_name": str(cfg.query_embedder or VL_EMBED_MODEL), + "embed_model_name": str(cfg.query_embedder or VL_EMBED_MODEL), + "embedding_endpoint": cfg.embedding_endpoint, + "api_key": cfg.embedding_api_key, + "input_type": "query", + "local_ingest_embed_backend": str(cfg.local_query_embed_backend), + "inference_batch_size": int(cfg.local_hf_batch_size), + "embed_inference_batch_size": int(cfg.local_hf_batch_size), + }, + top_k=AGENTIC_TOP_K, + rerank=bool(cfg.reranker), + rerank_kwargs={ + "model_name": cfg.reranker or VL_RERANK_MODEL, + "invoke_url": cfg.reranker_endpoint, + "api_key": cfg.reranker_api_key, + "local_reranker_backend": str(cfg.local_reranker_backend), + "modality": str(cfg.embed_modality), + }, + ) + self._lock = threading.Lock() + + def retrieve(self, query_ids: Sequence[str], query_texts: Sequence[str]) -> pd.DataFrame: + """Return selected ranked documents for each query. + + The output schema matches ``SelectionAgentOperator``: ``query_id``, + ``doc_id``, ``rank``, and ``message``. 
+ """ + + if len(query_ids) != len(query_texts): + raise ValueError("query_ids and query_texts must have the same length.") + + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + pipeline = ( + AgenticQueryInputOperator() + >> ReActAgentOperator( + invoke_url=_none_if_empty(self._cfg.invoke_url), + llm_model=str(self._cfg.llm_model), + retriever_fn=self._retrieve_for_agent, + retriever_top_k=AGENTIC_TOP_K, + target_top_k=AGENTIC_TOP_K, + user_msg_type="with_results", + max_steps=int(self._cfg.react_max_steps), + api_key=_none_if_empty(self._cfg.api_key), + parallel_tool_calls=AGENTIC_PARALLEL_TOOL_CALLS, + num_concurrent=AGENTIC_NUM_CONCURRENT, + ) + >> RRFAggregatorOperator(k=AGENTIC_RRF_K) + >> SelectionAgentOperator( + invoke_url=_none_if_empty(self._cfg.invoke_url), + llm_model=str(self._cfg.llm_model), + top_k=AGENTIC_TOP_K, + api_key=_none_if_empty(self._cfg.api_key), + parallel_tool_calls=AGENTIC_PARALLEL_TOOL_CALLS, + ) + >> AgenticSelectionOutputOperator() + ) + graph_retriever = Retriever( + graph=pipeline, + top_k=AGENTIC_TOP_K, + embed_kwargs={"text_column": "query_text"}, + ) + raw_hits = graph_retriever.queries([str(query_text) for query_text in query_texts], top_k=AGENTIC_TOP_K) + return _raw_hits_to_agentic_result([str(query_id) for query_id in query_ids], raw_hits) + + def _retrieve_for_agent(self, query_text: str, top_k: int) -> list[dict[str, Any]]: + """Retriever callback used by ``ReActAgentOperator``.""" + + with self._lock: + hits = self._retriever.query(str(query_text), top_k=int(top_k)) + + docs: list[dict[str, Any]] = [] + for hit in hits: + doc_id = _doc_id_for_match_mode(dict(hit), match_mode=self._match_mode) + if not doc_id: + continue + docs.append( + { + "doc_id": doc_id, + "text": str(hit.get("text", ""))[:AGENTIC_TEXT_TRUNCATION], + "score": 
_hit_score(hit), + } + ) + if len(docs) >= int(top_k): + break + return docs + + +def run_agentic_recall_evaluation( + *, + query_csv: Path, + cfg: AgenticRetrievalConfig, + match_mode: str, + ks: Sequence[int] = (1, 5, 10), +) -> tuple[pd.DataFrame, pd.DataFrame, dict[str, dict[str, int]], dict[str, dict[str, float]], dict[str, float]]: + """Run agentic retrieval for a recall query CSV and compute metrics.""" + + df_query = _normalize_query_df(pd.read_csv(Path(query_csv)), match_mode=str(match_mode)) + query_ids = [str(idx) for idx in df_query.index] + query_texts = df_query["query"].astype(str).tolist() + qrels = build_qrels(query_ids, df_query["golden_answer"].astype(str).tolist()) + + start = time.time() + result = AgenticRetriever(cfg, match_mode=str(match_mode)).retrieve(query_ids, query_texts) + elapsed = time.time() - start + if elapsed > 0: + logger.info( + "Agentic retrieval time for %d queries: %.2f seconds (average %.2f queries/second)", + len(query_ids), + elapsed, + len(query_ids) / elapsed, + ) + + run = build_beir_run_from_agentic_result(query_ids, result) + metrics = compute_beir_metrics(qrels, run, ks=ks) + return df_query, result, qrels, run, metrics + + +def build_qrels(query_ids: Sequence[str], gold_keys: Sequence[str]) -> dict[str, dict[str, int]]: + """Build BEIR-style qrels from normalized recall gold keys.""" + + if len(query_ids) != len(gold_keys): + raise ValueError("query_ids and gold_keys must have the same length.") + return {str(query_id): {str(gold_key): 1} for query_id, gold_key in zip(query_ids, gold_keys)} + + +def build_beir_run_from_agentic_result( + query_ids: Sequence[str], + result: pd.DataFrame, +) -> dict[str, dict[str, float]]: + """Convert ``SelectionAgentOperator`` output to BEIR run format.""" + + run: dict[str, dict[str, float]] = {str(query_id): {} for query_id in query_ids} + if result.empty: + return run + + required = {"query_id", "doc_id", "rank"} + missing = required - set(result.columns) + if missing: + raise 
ValueError(f"Agentic result missing required columns: {sorted(missing)}") + + for query_id, group in result.groupby("query_id", sort=False): + ordered = group.sort_values("rank") + n = len(ordered.index) + scores: dict[str, float] = {} + for rank, (_, row) in enumerate(ordered.iterrows(), start=1): + doc_id = str(row["doc_id"]) + if doc_id and doc_id not in scores: + scores[doc_id] = float(n - rank + 1) + run[str(query_id)] = scores + return run + + +def _raw_hits_to_agentic_result(query_ids: Sequence[str], raw_hits: Sequence[Sequence[dict[str, Any]]]) -> pd.DataFrame: + rows: list[dict[str, Any]] = [] + for query_id, hits in zip(query_ids, raw_hits): + for rank, hit in enumerate(hits, start=1): + rows.append( + { + "query_id": str(query_id), + "doc_id": str(hit.get("doc_id") or hit.get("pdf_page") or ""), + "rank": int(hit.get("rank", rank)), + "message": str(hit.get("message", "")), + } + ) + if not rows: + return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message"]) + return pd.DataFrame(rows) + + +def _doc_id_for_match_mode(hit: dict[str, Any], *, match_mode: str) -> str: + if match_mode == "audio_segment": + return _hit_to_audio_segment_key(hit) or "" + if match_mode == "pdf_only": + return _doc_id_from_hit(hit) + return _pdf_page_from_hit(hit) + + +def _pdf_page_from_hit(hit: dict[str, Any]) -> str: + pdf_page = hit.get("pdf_page") + if isinstance(pdf_page, str) and pdf_page.strip(): + return pdf_page.strip() + + source = hit.get("source") or hit.get("source_id") or hit.get("path") + page_number = hit.get("page_number") + if source and page_number is not None: + return f"{Path(str(source)).stem}_{page_number}" + return _doc_id_from_hit(hit) + + +def _doc_id_from_hit(hit: dict[str, Any]) -> str: + for key in ("pdf_basename", "source_id", "path", "source", "doc_id"): + value = hit.get(key) + if isinstance(value, str) and value.strip(): + return _normalize_pdf_name(Path(value).stem) + return "" + + +def _hit_score(hit: dict[str, Any]) -> float: + for 
key in ("_rerank_score", "_score", "score"): + if key in hit: + try: + return float(hit[key]) + except (TypeError, ValueError): + return 0.0 + if "_distance" in hit: + try: + return -float(hit["_distance"]) + except (TypeError, ValueError): + return 0.0 + return 0.0 + + +def _none_if_empty(value: Optional[str]) -> Optional[str]: + if value is None: + return None + stripped = str(value).strip() + if not stripped or stripped.lower() in {"none", "null"}: + return None + return stripped diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py index 94777db5d9..c2bf7fe6ef 100644 --- a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py @@ -22,6 +22,20 @@ logger = logging.getLogger(__name__) +_LOG_PREVIEW_CHARS = 300 +_LOG_DOC_ID_LIMIT = 10 + + +def _preview_text(value: Any, *, limit: int = _LOG_PREVIEW_CHARS) -> str: + text = " ".join(str(value or "").split()) + if len(text) <= limit: + return text + return text[:limit].rstrip() + "..." 
+ + +def _preview_doc_ids(docs: List[Dict[str, Any]], *, limit: int = _LOG_DOC_ID_LIMIT) -> List[str]: + return [str(doc.get("doc_id", "")) for doc in docs[:limit]] + # --------------------------------------------------------------------------- # Prompt rendering (verbatim content of 02_v1.j2, rendered via Python) @@ -502,6 +516,15 @@ def _run_single_query( seen_doc_ids: set[str] = set() step_counter = 0 + logger.info( + "ReActAgentOperator: query=%s start max_steps=%d retriever_top_k=%d target_top_k=%d query=%r", + query_id, + self._max_steps, + self._retriever_top_k, + self._target_top_k, + _preview_text(query_text), + ) + # ------ optional initial retrieval (with_results mode) ------ if with_init_docs: init_docs = self._call_retriever(query_text, seen_doc_ids, api_key) @@ -509,6 +532,13 @@ def _run_single_query( step_counter += 1 for d in init_docs: seen_doc_ids.add(d["doc_id"]) + logger.info( + "ReActAgentOperator: query=%s initial_retrieve docs=%d seen=%d doc_ids=%s", + query_id, + len(init_docs), + len(seen_doc_ids), + _preview_doc_ids(init_docs), + ) doc_content = _docs_to_message_content(init_docs) user_msg_content: List[Dict[str, Any]] = [ @@ -522,7 +552,7 @@ def _run_single_query( # ------ main ReAct loop ------ for _step in range(self._max_steps): - logger.debug("query=%r loop_step=%d seen_docs=%d", query_id, _step, len(seen_doc_ids)) + logger.info("ReActAgentOperator: query=%s step=%d begin seen_docs=%d", query_id, _step, len(seen_doc_ids)) try: response = invoke_chat_completion_step( invoke_url=self._invoke_url, @@ -580,6 +610,20 @@ def _run_single_query( msg = choice["message"] finish_reason = choice.get("finish_reason") tool_calls = msg.get("tool_calls") or [] + if msg.get("content"): + logger.info( + "ReActAgentOperator: query=%s step=%d assistant content=%r", + query_id, + _step, + _preview_text(msg.get("content")), + ) + logger.info( + "ReActAgentOperator: query=%s step=%d finish_reason=%s tool_calls=%s", + query_id, + _step, + finish_reason, + 
[((tc.get("function") or {}).get("name") or "") for tc in tool_calls], + ) # Append assistant turn assistant_turn: Dict[str, Any] = {"role": "assistant"} @@ -590,6 +634,11 @@ def _run_single_query( messages.append(assistant_turn) if finish_reason == "stop" or not tool_calls: + logger.info( + "ReActAgentOperator: query=%s step=%d no tool call; requesting continuation", + query_id, + _step, + ) messages.append({"role": "user", "content": _AUTO_USER_MSG}) continue @@ -609,20 +658,38 @@ def _run_single_query( continue if fn_name == "think": - logger.debug("query=%r step=%d [think] %s", query_id, _step, str(fn_args.get("thought", ""))[:120]) + logger.info( + "ReActAgentOperator: query=%s step=%d think=%r", + query_id, + _step, + _preview_text(fn_args.get("thought")), + ) tool_messages.append( {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."} ) elif fn_name == "retrieve": subquery = str(fn_args.get("query", query_text)) - logger.debug("query=%r step=%d [retrieve] subquery=%r", query_id, _step, subquery) + logger.info( + "ReActAgentOperator: query=%s step=%d retrieve subquery=%r seen_before=%d", + query_id, + _step, + _preview_text(subquery), + len(seen_doc_ids), + ) retrieved = self._call_retriever(subquery, seen_doc_ids, api_key) - logger.debug("query=%r step=%d [retrieve] got %d new docs", query_id, _step, len(retrieved)) retrieval_log.append(retrieved) step_counter += 1 for d in retrieved: seen_doc_ids.add(d["doc_id"]) + logger.info( + "ReActAgentOperator: query=%s step=%d retrieve docs=%d seen_after=%d doc_ids=%s", + query_id, + _step, + len(retrieved), + len(seen_doc_ids), + _preview_doc_ids(retrieved), + ) doc_content = _docs_to_message_content(retrieved) tool_content: List[Dict[str, Any]] = [ {"type": "text", "text": f"Retrieved {len(retrieved)} documents:"} @@ -631,7 +698,13 @@ def _run_single_query( elif fn_name == "final_results": raw_ids: List[str] = fn_args.get("doc_ids", []) - logger.debug("query=%r step=%d [final_results] 
doc_ids=%s", query_id, _step, raw_ids) + logger.info( + "ReActAgentOperator: query=%s step=%d final_results doc_ids=%s message=%r", + query_id, + _step, + raw_ids[:_LOG_DOC_ID_LIMIT] if isinstance(raw_ids, list) else raw_ids, + _preview_text(fn_args.get("message")), + ) if isinstance(raw_ids, list) and raw_ids: final_doc_ids = [str(d) for d in raw_ids] tool_messages.append( @@ -652,6 +725,13 @@ def _run_single_query( if loop_done: break + logger.info( + "ReActAgentOperator: query=%s done retrieval_steps=%d seen_docs=%d final_doc_ids=%s", + query_id, + len(retrieval_log), + len(seen_doc_ids), + final_doc_ids[:_LOG_DOC_ID_LIMIT] if final_doc_ids else [], + ) return _build_output_rows(query_id, query_text, retrieval_log, final_doc_ids) def _call_retriever( diff --git a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py index 14371969a8..bddb2bbb37 100644 --- a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py @@ -21,6 +21,16 @@ logger = logging.getLogger(__name__) +_LOG_PREVIEW_CHARS = 300 +_LOG_DOC_ID_LIMIT = 10 + + +def _preview_text(value: Any, *, limit: int = _LOG_PREVIEW_CHARS) -> str: + text = " ".join(str(value or "").split()) + if len(text) <= limit: + return text + return text[:limit].rstrip() + "..." 
+ # --------------------------------------------------------------------------- # Prompt rendering (verbatim content of 01_v0.j2, rendered via Python) # --------------------------------------------------------------------------- @@ -237,8 +247,21 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: for query_id, group in data.groupby("query_id", sort=False): query_text = str(group["query_text"].iloc[0]) docs = [{"id": str(row["doc_id"]), "text": str(row["text"])} for _, row in group.iterrows()] + logger.info( + "SelectionAgentOperator: query=%s start candidates=%d unique_candidates=%d query=%r", + query_id, + len(docs), + len({doc["id"] for doc in docs}), + _preview_text(query_text), + ) result = self._select_documents(query_text, docs) message = result.get("message", "") + logger.info( + "SelectionAgentOperator: query=%s selected=%s message=%r", + query_id, + result.get("doc_ids", [])[:_LOG_DOC_ID_LIMIT], + _preview_text(message), + ) for rank, doc_id in enumerate(result.get("doc_ids", []), 1): rows.append( { @@ -363,6 +386,12 @@ def _select_documents( """Run the agentic selection loop for a single query.""" valid_ids = list(dict.fromkeys(d["id"] for d in docs)) feasible_k = min(self._top_k, len(valid_ids)) + logger.info( + "SelectionAgentOperator: selecting top_k=%d feasible_k=%d valid_doc_ids=%s", + self._top_k, + feasible_k, + valid_ids[:_LOG_DOC_ID_LIMIT], + ) system_prompt = self._build_system_prompt(feasible_k) tools = self._build_tools(feasible_k, valid_ids) @@ -379,6 +408,12 @@ def _select_documents( extra_body["parallel_tool_calls"] = False for _step in range(self._max_steps): + logger.info( + "SelectionAgentOperator: step=%d begin candidates=%d feasible_k=%d", + _step, + len(valid_ids), + feasible_k, + ) try: response = invoke_chat_completion_step( invoke_url=self._invoke_url, @@ -438,12 +473,24 @@ def _select_documents( assistant_turn: Dict[str, Any] = {"role": "assistant"} if msg.get("content"): assistant_turn["content"] = 
msg["content"] + logger.info( + "SelectionAgentOperator: step=%d assistant content=%r", + _step, + _preview_text(msg.get("content")), + ) tool_calls = msg.get("tool_calls") or [] + logger.info( + "SelectionAgentOperator: step=%d finish_reason=%s tool_calls=%s", + _step, + finish_reason, + [((tc.get("function") or {}).get("name") or "") for tc in tool_calls], + ) if tool_calls: assistant_turn["tool_calls"] = tool_calls messages.append(assistant_turn) if finish_reason == "stop" or not tool_calls: + logger.info("SelectionAgentOperator: step=%d no tool call; asking for final selection", _step) messages.append( { "role": "user", @@ -468,6 +515,11 @@ def _select_documents( continue if fn.get("name") == "think": + logger.info( + "SelectionAgentOperator: step=%d think=%r", + _step, + _preview_text(fn_args.get("thought")), + ) tool_messages.append( {"role": "tool", "tool_call_id": tc_id, "content": "Your thought has been logged."} ) @@ -480,6 +532,13 @@ def _select_documents( except json.JSONDecodeError: raw_doc_ids = [] doc_ids = [d for d in raw_doc_ids if d in valid_id_set][:feasible_k] + logger.info( + "SelectionAgentOperator: step=%d log_selected_documents raw=%s accepted=%s message=%r", + _step, + raw_doc_ids[:_LOG_DOC_ID_LIMIT] if isinstance(raw_doc_ids, list) else raw_doc_ids, + doc_ids[:_LOG_DOC_ID_LIMIT], + _preview_text(fn_args.get("message")), + ) if not doc_ids and raw_doc_ids: logger.warning( "SelectionAgentOperator: LLM returned %d doc_id(s) for query %r " diff --git a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py index f0e4c8113e..c8c647a5f9 100644 --- a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py +++ b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py @@ -99,6 +99,8 @@ ) DEFAULT_VDB_OP = "lancedb" +AGENTIC_LLM_MODEL_ENV = "NEMO_RETRIEVER_AGENTIC_LLM_MODEL" +AGENTIC_INVOKE_URL_ENV = "NEMO_RETRIEVER_AGENTIC_INVOKE_URL" # Help panel labels (keep stable so --help groupings read 
consistently). _PANEL_IO = "I/O and Execution" @@ -818,6 +820,76 @@ def _run_evaluation( return "Audio Recall", time.perf_counter() - evaluation_start, metrics, len(df_query.index), True +def _run_agentic_recall_evaluation( + *, + vdb_op: str, + vdb_kwargs: dict[str, Any], + embed_model_name: str, + embed_invoke_url: Optional[str], + embed_remote_api_key: Optional[str], + embed_modality: str, + query_csv: Path, + recall_match_mode: str, + reranker: Optional[bool], + reranker_model_name: str, + reranker_invoke_url: Optional[str], + reranker_api_key: str, + local_reranker_backend: str, + local_hf_batch_size: int, + local_query_embed_backend: str, + agentic_llm_model: str, + agentic_invoke_url: Optional[str], + agentic_api_key: Optional[str], + agentic_react_max_steps: int, +) -> tuple[str, float, dict[str, float], Optional[int], bool]: + """Run recall evaluation using graph-backed agentic retrieval.""" + + query_csv_path = Path(query_csv) + if not query_csv_path.exists(): + logger.warning("Query CSV not found at %s; skipping agentic recall evaluation.", query_csv_path) + return "Agentic Recall", 0.0, {}, None, False + + from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, run_agentic_recall_evaluation + from nemo_retriever.model import resolve_embed_model + + embed_model = resolve_embed_model(str(embed_model_name)) + cfg = AgenticRetrievalConfig( + vdb_op=str(vdb_op), + vdb_kwargs=dict(vdb_kwargs or {}), + query_embedder=embed_model, + embedding_endpoint=embed_invoke_url, + embedding_api_key=embed_remote_api_key or "", + embedding_use_grpc=False if embed_invoke_url else None, + local_hf_batch_size=int(local_hf_batch_size), + local_query_embed_backend=local_query_embed_backend, + reranker=reranker_model_name if reranker else None, + reranker_endpoint=reranker_invoke_url, + reranker_api_key=reranker_api_key, + local_reranker_backend=local_reranker_backend, + embed_modality=embed_modality, + llm_model=agentic_llm_model, + invoke_url=agentic_invoke_url, + 
api_key=agentic_api_key, + react_max_steps=int(agentic_react_max_steps), + ) + evaluation_start = time.perf_counter() + df_query, _result, _qrels, _run, metrics = run_agentic_recall_evaluation( + query_csv=query_csv_path, + cfg=cfg, + match_mode=recall_match_mode, + ks=(1, 5, 10), + ) + logger.info("Agentic recall gold ids: %s", {qid: list(docs.keys()) for qid, docs in _qrels.items()}) + logger.info("Agentic recall retrieved ids: %s", {qid: list(docs.keys())[:10] for qid, docs in _run.items()}) + logger.info("Agentic recall result columns: %s", list(_result.columns)) + if {"query_id", "doc_id", "rank"}.issubset(_result.columns): + logger.info( + "Agentic recall top result rows:\n%s", + _result[["query_id", "doc_id", "rank"]].head(20).to_string(index=False), + ) + return "Agentic Recall", time.perf_counter() - evaluation_start, metrics, len(df_query.index), True + + # --------------------------------------------------------------------------- # Typer command: `retriever pipeline run` # --------------------------------------------------------------------------- @@ -1219,6 +1291,12 @@ def run( help="Post-ingest evaluation: none (default), audio_recall, beir, or qa.", rich_help_panel=_PANEL_EVAL, ), + retrieval_mode: str = typer.Option( + "standard", + "--retrieval-mode", + help="Retrieval strategy for evaluation: 'standard' (default) or 'agentic' for recall evaluation.", + rich_help_panel=_PANEL_EVAL, + ), query_csv: Path = typer.Option( "./data/bo767_query_gt.csv", "--query-csv", @@ -1267,6 +1345,28 @@ def run( help="Fixed token length for local HF query embeddings; longer queries are truncated.", rich_help_panel=_PANEL_EVAL, ), + agentic_llm_model: Optional[str] = typer.Option( + None, + "--agentic-llm-model", + help=f"Chat model for --retrieval-mode=agentic; may also be set with {AGENTIC_LLM_MODEL_ENV}.", + rich_help_panel=_PANEL_EVAL, + ), + agentic_invoke_url: Optional[str] = typer.Option( + None, + "--agentic-invoke-url", + help=( + "OpenAI-compatible chat 
completions endpoint for --retrieval-mode=agentic; " + f"may also be set with {AGENTIC_INVOKE_URL_ENV}." + ), + rich_help_panel=_PANEL_EVAL, + ), + agentic_react_max_steps: int = typer.Option( + 10, + "--agentic-react-max-steps", + min=1, + help="Maximum ReAct loop iterations per query for --retrieval-mode=agentic.", + rich_help_panel=_PANEL_EVAL, + ), beir_loader: Optional[str] = typer.Option(None, "--beir-loader", rich_help_panel=_PANEL_EVAL), beir_dataset_name: Optional[str] = typer.Option(None, "--beir-dataset-name", rich_help_panel=_PANEL_EVAL), beir_split: str = typer.Option("test", "--beir-split", rich_help_panel=_PANEL_EVAL), @@ -1325,6 +1425,12 @@ def run( raise ValueError(f"Unsupported --audio-split-type: {audio_split_type!r}") if evaluation_mode not in {"none", "audio_recall", "beir", "qa"}: raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") + if retrieval_mode not in {"standard", "agentic"}: + raise ValueError(f"Unsupported --retrieval-mode: {retrieval_mode!r}") + if retrieval_mode == "agentic" and evaluation_mode != "audio_recall": + raise typer.BadParameter( + "--retrieval-mode=agentic is currently supported only with --evaluation-mode=audio_recall." + ) if evaluation_mode == "audio_recall": if input_type != "audio": raise ValueError("--evaluation-mode=audio_recall is only supported with --input-type=audio") @@ -1335,6 +1441,12 @@ def run( "--evaluation-mode=qa requires --eval-config (QA sweep YAML/JSON). " "Use the same file format as `retriever eval run --config` (dataset, retrieval, models, ...)." ) + resolved_agentic_llm_model = (agentic_llm_model or os.environ.get(AGENTIC_LLM_MODEL_ENV) or "").strip() + resolved_agentic_invoke_url = (agentic_invoke_url or os.environ.get(AGENTIC_INVOKE_URL_ENV) or "").strip() or None + if retrieval_mode == "agentic" and not resolved_agentic_llm_model: + raise typer.BadParameter( + f"--retrieval-mode=agentic requires --agentic-llm-model or {AGENTIC_LLM_MODEL_ENV}." 
+ ) if run_mode == "batch": # --quiet implies --no-ray-log-to-driver: Ray flushes worker stdout @@ -1742,6 +1854,7 @@ def _run_ingest() -> Any: "evaluation_secs": float(evaluation_total_time), "total_secs": float(total_time), "evaluation_mode": "qa", + "retrieval_mode": retrieval_mode, "evaluation_metrics": {}, "evaluation_count": None, "recall_details": bool(recall_details), @@ -1778,35 +1891,60 @@ def _run_ingest() -> Any: raise typer.Exit(code=qa_code) return - evaluation_label, evaluation_total_time, evaluation_metrics, evaluation_query_count, ran = _run_evaluation( - evaluation_mode=evaluation_mode, - vdb_op=resolved_vdb_op, - vdb_kwargs=resolved_vdb_kwargs, - embed_model_name=embed_model_name, - embed_invoke_url=embed_invoke_url, - embed_remote_api_key=embed_remote_api_key, - embed_modality=embed_modality, - query_csv=query_csv, - recall_match_mode=recall_match_mode, - audio_match_tolerance_secs=audio_match_tolerance_secs, - reranker=reranker, - reranker_model_name=reranker_model_name, - reranker_invoke_url=reranker_invoke_url, - reranker_api_key=reranker_bearer, - local_reranker_backend=local_reranker_backend, - local_hf_batch_size=local_hf_batch_size, - local_query_max_length=local_query_max_length, - beir_loader=beir_loader, - beir_dataset_name=beir_dataset_name, - beir_split=beir_split, - beir_query_language=beir_query_language, - beir_doc_id_field=beir_doc_id_field, - beir_k=beir_k, - local_query_embed_backend=local_query_embed_backend, - run_mode=run_mode, - service_url=service_url, - service_api_token=service_api_token, - ) + if retrieval_mode == "agentic": + evaluation_label, evaluation_total_time, evaluation_metrics, evaluation_query_count, ran = ( + _run_agentic_recall_evaluation( + vdb_op=resolved_vdb_op, + vdb_kwargs=resolved_vdb_kwargs, + embed_model_name=embed_model_name, + embed_invoke_url=embed_invoke_url, + embed_remote_api_key=embed_remote_api_key, + embed_modality=embed_modality, + query_csv=query_csv, + recall_match_mode=recall_match_mode, 
+ reranker=reranker, + reranker_model_name=reranker_model_name, + reranker_invoke_url=reranker_invoke_url, + reranker_api_key=reranker_bearer, + local_reranker_backend=local_reranker_backend, + local_hf_batch_size=local_hf_batch_size, + local_query_embed_backend=local_query_embed_backend, + agentic_llm_model=resolved_agentic_llm_model, + agentic_invoke_url=resolved_agentic_invoke_url, + agentic_api_key=remote_api_key, + agentic_react_max_steps=agentic_react_max_steps, + ) + ) + else: + evaluation_label, evaluation_total_time, evaluation_metrics, evaluation_query_count, ran = _run_evaluation( + evaluation_mode=evaluation_mode, + vdb_op=resolved_vdb_op, + vdb_kwargs=resolved_vdb_kwargs, + embed_model_name=embed_model_name, + embed_invoke_url=embed_invoke_url, + embed_remote_api_key=embed_remote_api_key, + embed_modality=embed_modality, + query_csv=query_csv, + recall_match_mode=recall_match_mode, + audio_match_tolerance_secs=audio_match_tolerance_secs, + reranker=reranker, + reranker_model_name=reranker_model_name, + reranker_invoke_url=reranker_invoke_url, + reranker_api_key=reranker_bearer, + local_reranker_backend=local_reranker_backend, + local_hf_batch_size=local_hf_batch_size, + local_query_max_length=local_query_max_length, + beir_loader=beir_loader, + beir_dataset_name=beir_dataset_name, + beir_split=beir_split, + beir_query_language=beir_query_language, + beir_doc_id_field=beir_doc_id_field, + beir_k=beir_k, + local_query_embed_backend=local_query_embed_backend, + run_mode=run_mode, + service_url=service_url, + service_api_token=service_api_token, + ) if not ran: no_eval_total_time = time.perf_counter() - ingest_start @@ -1825,6 +1963,7 @@ def _run_ingest() -> Any: "evaluation_secs": 0.0, "total_secs": float(no_eval_total_time), "evaluation_mode": evaluation_mode, + "retrieval_mode": retrieval_mode, "evaluation_metrics": {}, "recall_details": bool(recall_details), "vdb_op": str(resolved_vdb_op), @@ -1857,6 +1996,7 @@ def _run_ingest() -> Any: 
"evaluation_secs": float(evaluation_total_time), "total_secs": float(total_time), "evaluation_mode": evaluation_mode, + "retrieval_mode": retrieval_mode, "evaluation_metrics": dict(evaluation_metrics), "evaluation_count": evaluation_query_count, "recall_details": bool(recall_details), diff --git a/nemo_retriever/tests/test_agentic_eval.py b/nemo_retriever/tests/test_agentic_eval.py new file mode 100644 index 0000000000..05f050ca9a --- /dev/null +++ b/nemo_retriever/tests/test_agentic_eval.py @@ -0,0 +1,201 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-25, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import json +from unittest.mock import patch + +import pandas as pd +import pytest +from typer.testing import CliRunner + + +def _make_tool_call_response(fn_name: str, fn_args: dict, tc_id: str = "call_1") -> dict: + return { + "choices": [ + { + "message": { + "content": None, + "tool_calls": [ + { + "id": tc_id, + "type": "function", + "function": {"name": fn_name, "arguments": json.dumps(fn_args)}, + } + ], + }, + "finish_reason": "tool_calls", + } + ] + } + + +class FakeRetriever: + def __init__(self, **kwargs): + self.kwargs = kwargs + self.graph = kwargs.get("graph") + self.top_k = int(kwargs.get("top_k", 10)) + + def query(self, query: str, *, top_k: int | None = None): + if self.graph is not None: + return self.queries([query], top_k=top_k)[0] + _ = query + hits = [ + { + "source": "/tmp/doc.pdf", + "source_id": "/tmp/doc.pdf", + "page_number": 1, + "pdf_page": "doc_1", + "text": "matching document", + "_score": 0.9, + }, + { + "source": "/tmp/other.pdf", + "source_id": "/tmp/other.pdf", + "page_number": 2, + "pdf_page": "other_2", + "text": "other document", + "_score": 0.1, + }, + ] + return hits[:top_k] + + def queries(self, queries, *, top_k: int | None = None): + if self.graph is None: + return [self.query(query, top_k=top_k) for query in queries] + limit = 
int(top_k) if top_k is not None else self.top_k + df = pd.DataFrame({"query_text": [str(query) for query in queries]}) + graph = self.graph.resolve_for_local_execution() + raw_hits = graph.execute(df)[0] + return [list(hits)[:limit] for hits in raw_hits] + + +def test_build_qrels_requires_aligned_lengths(): + from nemo_retriever.agentic.retrieval import build_qrels + + with pytest.raises(ValueError, match="same length"): + build_qrels(["q1"], ["doc_1", "doc_2"]) + + +def test_build_beir_run_from_agentic_result_orders_by_rank(): + from nemo_retriever.agentic.retrieval import build_beir_run_from_agentic_result + + result = pd.DataFrame( + { + "query_id": ["q1", "q1", "q1"], + "doc_id": ["d2", "d1", "d3"], + "rank": [2, 1, 3], + "message": ["ok", "ok", "ok"], + } + ) + run = build_beir_run_from_agentic_result(["q1", "q2"], result) + + assert list(run["q1"]) == ["d1", "d2", "d3"] + assert run["q1"]["d1"] > run["q1"]["d2"] > run["q1"]["d3"] + assert run["q2"] == {} + + +@patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.agentic.retrieval.Retriever", FakeRetriever) +def test_agentic_retriever_runs_graph_with_wrapped_retriever(mock_react_step, mock_selection_step): + from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, AgenticRetriever + + mock_react_step.return_value = _make_tool_call_response( + "final_results", + {"doc_ids": ["doc_1"], "message": "done", "search_successful": "true"}, + ) + mock_selection_step.return_value = _make_tool_call_response( + "log_selected_documents", + {"doc_ids": ["doc_1"], "message": "doc_1 is best"}, + ) + + cfg = AgenticRetrievalConfig(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions") + result = AgenticRetriever(cfg, match_mode="pdf_page").retrieve(["0"], ["find doc"]) + + assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert 
result["query_id"].tolist() == ["0"] + assert result["doc_id"].tolist() == ["doc_1"] + assert result["rank"].tolist() == [1] + + +@patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.agentic.retrieval.Retriever", FakeRetriever) +def test_run_agentic_recall_evaluation_computes_metrics(mock_react_step, mock_selection_step, tmp_path): + from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, run_agentic_recall_evaluation + + query_csv = tmp_path / "queries.csv" + pd.DataFrame({"query": ["find doc"], "pdf_page": ["doc_1"]}).to_csv(query_csv, index=False) + + mock_react_step.return_value = _make_tool_call_response( + "final_results", + {"doc_ids": ["doc_1"], "message": "done", "search_successful": "true"}, + ) + mock_selection_step.return_value = _make_tool_call_response( + "log_selected_documents", + {"doc_ids": ["doc_1"], "message": "doc_1 is best"}, + ) + + cfg = AgenticRetrievalConfig(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions") + df_query, result, qrels, run, metrics = run_agentic_recall_evaluation( + query_csv=query_csv, + cfg=cfg, + match_mode="pdf_page", + ks=(1, 5, 10), + ) + + assert df_query["golden_answer"].tolist() == ["doc_1"] + assert result["doc_id"].tolist() == ["doc_1"] + assert qrels == {"0": {"doc_1": 1}} + assert run["0"]["doc_1"] == 1.0 + assert metrics["recall@1"] == 1.0 + assert metrics["ndcg@1"] == 1.0 + + +def test_agentic_config_requires_llm_model(): + from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig + + with pytest.raises(ValueError, match="llm_model"): + AgenticRetrievalConfig(llm_model="") + + +def test_pipeline_rejects_agentic_qa_mode(): + from nemo_retriever.pipeline.__main__ import app + + result = CliRunner().invoke( + app, + [ + ".", + "--evaluation-mode", + "qa", + "--retrieval-mode", + "agentic", + "--agentic-llm-model", + "test-model", + 
], + ) + + assert result.exit_code != 0 + assert "--retrieval-mode=agentic is currently supported only with" in result.output + assert "--evaluation-mode=recall" in result.output + + +def test_pipeline_requires_agentic_llm_model(): + from nemo_retriever.pipeline.__main__ import app + + result = CliRunner().invoke( + app, + [ + ".", + "--evaluation-mode", + "recall", + "--retrieval-mode", + "agentic", + ], + ) + + assert result.exit_code != 0 + assert "--retrieval-mode=agentic requires --agentic-llm-model" in result.output From b915c46781e03b684696a99fd5c9a9beedf911d7 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Tue, 12 May 2026 14:37:48 -0700 Subject: [PATCH 2/6] param defaults updated Signed-off-by: Mahika Wason --- .../src/nemo_retriever/agentic/retrieval.py | 6 ++- .../graph/react_agent_operator.py | 39 ++++++++++++++++--- 2 files changed, 37 insertions(+), 8 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py index cef60e9321..c4c2329949 100644 --- a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py +++ b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py @@ -32,9 +32,11 @@ logger = logging.getLogger(__name__) -AGENTIC_TOP_K = 10 +AGENTIC_RETRIEVER_TOP_K = 20 +AGENTIC_TARGET_TOP_K = 10 +AGENTIC_SELECTION_TOP_K = 10 AGENTIC_NUM_CONCURRENT = 1 -AGENTIC_TEXT_TRUNCATION = 500 +AGENTIC_TEXT_TRUNCATION = 2000 AGENTIC_PARALLEL_TOOL_CALLS = False AGENTIC_RRF_K = 60 AGENTIC_REACT_MAX_STEPS = 10 diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py index c2bf7fe6ef..bcb976475d 100644 --- a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py @@ -758,14 +758,17 @@ def _call_retriever( logger.warning("ReActAgentOperator: retriever_fn failed for query %r: %s", query_text, exc, exc_info=True) return [] - # 
Filter already-seen and normalise keys + # Filter already-seen and normalise keys. Track this batch separately + # so duplicate rows from the vector DB do not reduce the effective top-k. results: List[Dict[str, Any]] = [] + batch_doc_ids: set[str] = set() for item in raw: doc_id = str(item.get("doc_id", item.get("id", ""))) text = str(item.get("text", "")) score = float(item.get("score", 0.0)) - if doc_id and doc_id not in seen_doc_ids: + if doc_id and doc_id not in seen_doc_ids and doc_id not in batch_doc_ids: results.append({"doc_id": doc_id, "text": text, "score": score}) + batch_doc_ids.add(doc_id) if len(results) >= self._retriever_top_k: break @@ -827,8 +830,32 @@ def _build_output_rows( ) # If final_results was called, also emit those as a synthetic final step - # (step_idx = len(retrieval_log)) so RRF can optionally weight it. - # These are already covered by the existing steps, so we skip deduplication - # here — RRF will naturally up-weight docs that appeared in final_results - # because they were retrieved in earlier steps. + # (step_idx = len(retrieval_log)) so RRF can weight the agent's final + # judgment in addition to the raw retrieval history. 
+ if final_doc_ids: + first_doc_by_id: Dict[str, Dict[str, Any]] = {} + for step_docs in retrieval_log: + for doc in step_docs: + doc_id = str(doc.get("doc_id", "")) + if doc_id and doc_id not in first_doc_by_id: + first_doc_by_id[doc_id] = doc + + emitted: set[str] = set() + final_step_idx = len(retrieval_log) + for rank, doc_id in enumerate(final_doc_ids, 1): + doc_id = str(doc_id) + if not doc_id or doc_id in emitted: + continue + emitted.add(doc_id) + doc = first_doc_by_id.get(doc_id, {}) + rows.append( + { + "query_id": query_id, + "query_text": query_text, + "step_idx": final_step_idx, + "doc_id": doc_id, + "text": doc.get("text", ""), + "rank": rank, + } + ) return rows From bf7fd58f3e0f0d9c624edc057982125dbd0bfc77 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Tue, 9 Jun 2026 09:27:50 -0700 Subject: [PATCH 3/6] adding Beir evaluation wiring and pinning down defaults Signed-off-by: Mahika Wason --- nemo_retriever/README.md | 26 +++ .../src/nemo_retriever/agentic/retrieval.py | 139 ++++++++++-- .../graph/react_agent_operator.py | 136 +++++++++--- .../graph/rrf_aggregator_operator.py | 24 ++- .../graph/selection_agent_operator.py | 92 +++++++- .../src/nemo_retriever/pipeline/__main__.py | 143 +++++++++--- nemo_retriever/tests/test_agentic_eval.py | 140 +++++++++++- .../tests/test_agentic_operators.py | 203 ++++++++++++++++-- 8 files changed, 800 insertions(+), 103 deletions(-) diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 94f34b2f82..1957034aa7 100644 --- a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -259,6 +259,32 @@ hits = retriever.query(query) {'text': '| Table | 1 |\n| This | table | describes | some | animals, | and | some | activities | they | might | be | doing | in | specific |\n| locations. 
|\n| Animal | Activity | Place |\n| Giraffe | Driving | a | car | At | the | beach |\n| Lion | Putting | on | sunscreen | At | the | park |\n| Cat | Jumping | onto | a | laptop | In | a | home | office |\n| Dog | Chasing | a | squirrel | In | the | front | yard |\n| Chart | 1 |', 'metadata': '{"page_number": 1, "pdf_page": "multimodal_test_1", "page_elements_v3_num_detections": 9, "page_elements_v3_counts_by_label": {"table": 1, "chart": 1, "title": 3, "text": 4}, "ocr_table_detections": 1, "ocr_chart_detections": 1, "ocr_infographic_detections": 0}', 'source': '{"source_id": "/home/dev/projects/NeMo-Retriever/data/multimodal_test.pdf"}', 'page_number': 1, '_distance': 1.614684820175171} ``` +### Agentic BEIR evaluation + +The pipeline CLI can evaluate a LanceDB corpus with agentic retrieval against a +BEIR-style dataset: + +```bash +retriever pipeline run ./data \ + --evaluation-mode beir \ + --retrieval-mode agentic \ + --beir-loader vidore_hf \ + --beir-dataset-name vidore_v3_finance_en \ + --beir-doc-id-field pdf_basename \ + --agentic-llm-model nvidia/llama-3.3-nemotron-super-49b-v1.5 \ + --agentic-invoke-url http:///v1/chat/completions \ + --embed-invoke-url http:///v1 \ + --agentic-reasoning-effort high \ + --agentic-num-concurrent 10 +``` + +Common BEIR options are `--beir-split`, `--beir-query-language`, and +`--beir-doc-id-field`. Agentic controls include `--agentic-react-max-steps` +(default `50`), `--agentic-backend-top-k` (default `20`), +`--agentic-text-truncation` (`0` disables truncation), +`--agentic-reasoning-effort`, and `--agentic-num-concurrent`. Throughput with +high concurrency is bounded by the configured LLM endpoint. + ### Generate a query answer using an LLM The above retrieval results are often feedable directly to an LLM for answer generation.
diff --git a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py index c4c2329949..bb36e59184 100644 --- a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py +++ b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py @@ -22,7 +22,9 @@ from nemo_retriever.graph.abstract_operator import AbstractOperator from nemo_retriever.model import VL_EMBED_MODEL, VL_RERANK_MODEL -from nemo_retriever.recall.beir import compute_beir_metrics +from nemo_retriever.recall.beir import VALID_BEIR_DOC_ID_FIELDS, compute_beir_metrics +from nemo_retriever.recall.beir import _extract_doc_id_from_hit as _beir_doc_id_from_hit +from nemo_retriever.recall.beir import load_beir_dataset from nemo_retriever.recall.core import ( _hit_to_audio_segment_key, _normalize_pdf_name, @@ -32,14 +34,15 @@ logger = logging.getLogger(__name__) -AGENTIC_RETRIEVER_TOP_K = 20 +AGENTIC_RETRIEVER_TOP_K = 10 AGENTIC_TARGET_TOP_K = 10 +AGENTIC_BACKEND_TOP_K = 20 # backend retrieve-pool depth. show-count stays AGENTIC_TARGET_TOP_K=10 AGENTIC_SELECTION_TOP_K = 10 AGENTIC_NUM_CONCURRENT = 1 -AGENTIC_TEXT_TRUNCATION = 2000 +AGENTIC_TEXT_TRUNCATION = 0 AGENTIC_PARALLEL_TOOL_CALLS = False AGENTIC_RRF_K = 60 -AGENTIC_REACT_MAX_STEPS = 10 +AGENTIC_REACT_MAX_STEPS = 50 class AgenticQueryInputOperator(AbstractOperator): @@ -126,20 +129,38 @@ class AgenticRetrievalConfig: invoke_url: Optional[str] = None api_key: Optional[str] = None react_max_steps: int = AGENTIC_REACT_MAX_STEPS + text_truncation: int = AGENTIC_TEXT_TRUNCATION + num_concurrent: int = AGENTIC_NUM_CONCURRENT + # Forwarded verbatim as the OpenAI `reasoning_effort` field on every LLM + # call. Defaults to Path A's validated setting. + reasoning_effort: Optional[str] = "high" + # Backend retrieve-pool depth, Distinct from the per-call show count (AGENTIC_TARGET_TOP_K). 
+ backend_top_k: int = AGENTIC_BACKEND_TOP_K def __post_init__(self) -> None: if not str(self.llm_model).strip(): raise ValueError("Agentic retrieval requires a non-empty llm_model.") if int(self.react_max_steps) < 1: raise ValueError("react_max_steps must be >= 1.") + if int(self.text_truncation) < 0: + raise ValueError("text_truncation must be >= 0.") class AgenticRetriever: """Run graph-backed agentic retrieval over query IDs and query texts.""" - def __init__(self, cfg: AgenticRetrievalConfig, *, match_mode: str = "pdf_page") -> None: + def __init__( + self, + cfg: AgenticRetrievalConfig, + *, + match_mode: str = "pdf_page", + doc_id_field: str | None = None, + ) -> None: self._cfg = cfg self._match_mode = str(match_mode) + self._doc_id_field = str(doc_id_field) if doc_id_field else None + if self._doc_id_field is not None and self._doc_id_field not in VALID_BEIR_DOC_ID_FIELDS: + raise ValueError(f"Unsupported doc_id_field: {self._doc_id_field}") self._retriever = Retriever( vdb_kwargs={ "vdb_op": str(cfg.vdb_op), @@ -155,7 +176,7 @@ def __init__(self, cfg: AgenticRetrievalConfig, *, match_mode: str = "pdf_page") "inference_batch_size": int(cfg.local_hf_batch_size), "embed_inference_batch_size": int(cfg.local_hf_batch_size), }, - top_k=AGENTIC_TOP_K, + top_k=AGENTIC_RETRIEVER_TOP_K, rerank=bool(cfg.reranker), rerank_kwargs={ "model_name": cfg.reranker or VL_RERANK_MODEL, @@ -187,30 +208,39 @@ def retrieve(self, query_ids: Sequence[str], query_texts: Sequence[str]) -> pd.D invoke_url=_none_if_empty(self._cfg.invoke_url), llm_model=str(self._cfg.llm_model), retriever_fn=self._retrieve_for_agent, - retriever_top_k=AGENTIC_TOP_K, - target_top_k=AGENTIC_TOP_K, + retriever_top_k=AGENTIC_RETRIEVER_TOP_K, + target_top_k=AGENTIC_TARGET_TOP_K, user_msg_type="with_results", max_steps=int(self._cfg.react_max_steps), + extended_relevance=True, api_key=_none_if_empty(self._cfg.api_key), parallel_tool_calls=AGENTIC_PARALLEL_TOOL_CALLS, - num_concurrent=AGENTIC_NUM_CONCURRENT, 
+ num_concurrent=int(self._cfg.num_concurrent), + reasoning_effort=self._cfg.reasoning_effort, + backend_top_k=self._cfg.backend_top_k, ) >> RRFAggregatorOperator(k=AGENTIC_RRF_K) >> SelectionAgentOperator( invoke_url=_none_if_empty(self._cfg.invoke_url), llm_model=str(self._cfg.llm_model), - top_k=AGENTIC_TOP_K, + top_k=AGENTIC_SELECTION_TOP_K, api_key=_none_if_empty(self._cfg.api_key), parallel_tool_calls=AGENTIC_PARALLEL_TOOL_CALLS, + extended_relevance=True, # match Path A + text_truncation=int(self._cfg.text_truncation), + reasoning_effort=self._cfg.reasoning_effort, ) >> AgenticSelectionOutputOperator() ) graph_retriever = Retriever( graph=pipeline, - top_k=AGENTIC_TOP_K, + top_k=AGENTIC_SELECTION_TOP_K, embed_kwargs={"text_column": "query_text"}, ) - raw_hits = graph_retriever.queries([str(query_text) for query_text in query_texts], top_k=AGENTIC_TOP_K) + raw_hits = graph_retriever.queries( + [str(query_text) for query_text in query_texts], + top_k=AGENTIC_SELECTION_TOP_K, + ) return _raw_hits_to_agentic_result([str(query_id) for query_id in query_ids], raw_hits) def _retrieve_for_agent(self, query_text: str, top_k: int) -> list[dict[str, Any]]: @@ -220,15 +250,24 @@ def _retrieve_for_agent(self, query_text: str, top_k: int) -> list[dict[str, Any hits = self._retriever.query(str(query_text), top_k=int(top_k)) docs: list[dict[str, Any]] = [] + doc_id_field = getattr(self, "_doc_id_field", None) for hit in hits: - doc_id = _doc_id_for_match_mode(dict(hit), match_mode=self._match_mode) + hit_dict = dict(hit) + doc_id = ( + _beir_doc_id_from_hit(hit_dict, doc_id_field=doc_id_field) + if doc_id_field is not None + else _doc_id_for_match_mode(hit_dict, match_mode=self._match_mode) + ) if not doc_id: continue + text = str(hit_dict.get("text", "")) + if int(self._cfg.text_truncation) > 0: + text = text[: int(self._cfg.text_truncation)] docs.append( { "doc_id": doc_id, - "text": str(hit.get("text", ""))[:AGENTIC_TEXT_TRUNCATION], - "score": _hit_score(hit), + "text": 
text, + "score": _hit_score(hit_dict), } ) if len(docs) >= int(top_k): @@ -266,6 +305,52 @@ def run_agentic_recall_evaluation( return df_query, result, qrels, run, metrics +def run_agentic_beir_evaluation( + *, + loader: str, + dataset_name: str, + cfg: AgenticRetrievalConfig, + split: str = "test", + query_language: str | None = None, + doc_id_field: str = "pdf_basename", + ks: Sequence[int] = (1, 3, 5, 10), +) -> tuple[pd.DataFrame, pd.DataFrame, dict[str, dict[str, int]], dict[str, dict[str, float]], dict[str, float]]: + """Run agentic retrieval using BEIR-style queries and qrels.""" + + beir_dataset = load_beir_dataset( + str(loader), + dataset_name=str(dataset_name), + split=str(split), + query_language=query_language, + doc_id_field=str(doc_id_field), + ) + + start = time.time() + result = AgenticRetriever(cfg, match_mode="pdf_page", doc_id_field=str(doc_id_field)).retrieve( + beir_dataset.query_ids, + beir_dataset.queries, + ) + elapsed = time.time() - start + if elapsed > 0: + logger.info( + "Agentic BEIR retrieval time for %d queries: %.2f seconds (average %.2f queries/second)", + len(beir_dataset.query_ids), + elapsed, + len(beir_dataset.query_ids) / elapsed, + ) + + run = build_beir_run_from_agentic_result(beir_dataset.query_ids, result) + metrics = compute_beir_metrics(beir_dataset.qrels, run, ks=ks) + df_query = pd.DataFrame( + { + "query_id": beir_dataset.query_ids, + "query": beir_dataset.queries, + "golden_answer": [",".join(beir_dataset.qrels.get(qid, {}).keys()) for qid in beir_dataset.query_ids], + } + ) + return df_query, result, beir_dataset.qrels, run, metrics + + def build_qrels(query_ids: Sequence[str], gold_keys: Sequence[str]) -> dict[str, dict[str, int]]: """Build BEIR-style qrels from normalized recall gold keys.""" @@ -303,18 +388,36 @@ def build_beir_run_from_agentic_result( def _raw_hits_to_agentic_result(query_ids: Sequence[str], raw_hits: Sequence[Sequence[dict[str, Any]]]) -> pd.DataFrame: rows: list[dict[str, Any]] = [] - for 
query_id, hits in zip(query_ids, raw_hits): + # The agentic graph (AgenticQueryInputOperator) assigns POSITIONAL query_ids + # "0".."N-1" to its inputs, independent of the caller's real ids. Each hit + # therefore carries its positional index; map that back through `query_ids` + # to recover the caller's real id. This is robust to (a) ThreadPool + # completion-order reordering at num_concurrent>1, (b) queries that produced + # no rows (gaps don't shift anything), and (c) sharded offset ranges where the + # positional index != the caller's real id (e.g. query_ids=["1000".."1049"]). + # For a full sweep with sequential ids ("0".."N-1") this is a no-op vs the + # old positional zip. + n = len(query_ids) + for pos, hits in enumerate(raw_hits): for rank, hit in enumerate(hits, start=1): + raw_qid = hit.get("query_id") + if raw_qid is not None and str(raw_qid).isdigit() and int(raw_qid) < n: + qid = str(query_ids[int(raw_qid)]) + elif pos < n: + qid = str(query_ids[pos]) + else: + qid = str(raw_qid) if raw_qid is not None else "" rows.append( { - "query_id": str(query_id), + "query_id": qid, "doc_id": str(hit.get("doc_id") or hit.get("pdf_page") or ""), "rank": int(hit.get("rank", rank)), "message": str(hit.get("message", "")), + "result_source": str(hit.get("result_source", "")), } ) if not rows: - return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message"]) + return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message", "result_source"]) return pd.DataFrame(rows) diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py index bcb976475d..785c20d357 100644 --- a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py @@ -23,7 +23,7 @@ logger = logging.getLogger(__name__) _LOG_PREVIEW_CHARS = 300 -_LOG_DOC_ID_LIMIT = 10 +_LOG_DOC_ID_LIMIT = 20 def _preview_text(value: Any, *, limit: int = 
_LOG_PREVIEW_CHARS) -> str: @@ -259,6 +259,7 @@ def _make_final_results_tool_spec(top_k: Optional[int]) -> Dict[str, Any]: "doc_ids": { "type": "array", "items": {"type": "string"}, + "minItems": 1, "description": ( "List of document IDs that are relevant to the user's query sorted descending " "by their level of relevance to the user's query. I.e., the first document is " @@ -410,6 +411,8 @@ def __init__( api_key: Optional[str] = None, max_tokens: Optional[int] = None, parallel_tool_calls: bool = True, + reasoning_effort: Optional[str] = None, + backend_top_k: Optional[int] = None, ) -> None: super().__init__() self._invoke_url = invoke_url or self._NVIDIA_BUILD_ENDPOINT @@ -425,6 +428,17 @@ def __init__( self._api_key = api_key self._max_tokens = max_tokens self._parallel_tool_calls = parallel_tool_calls + self._reasoning_effort = reasoning_effort + self._backend_top_k = backend_top_k + + def _build_extra_body(self) -> Optional[Dict[str, Any]]: + """Assemble per-call extra payload fields (parallel_tool_calls, reasoning_effort).""" + extra: Dict[str, Any] = {} + if not self._parallel_tool_calls: + extra["parallel_tool_calls"] = False + if self._reasoning_effort: + extra["reasoning_effort"] = self._reasoning_effort + return extra or None # ------------------------------------------------------------------ # AbstractOperator interface @@ -453,28 +467,32 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: qid, qtxt = query_rows[0] rows.extend(self._run_single_query(qid, qtxt, api_key)) else: + # Collect per-query results keyed by query_id, then re-emit in the ORIGINAL + # input order. as_completed() yields futures in nondeterministic completion + # order; emitting in that order would make downstream groupby(sort=False) + # output order depend on which query finished first. Re-ordering here keeps + # the operator output deterministic regardless of concurrency. 
+ results_by_qid: Dict[str, List[Dict[str, Any]]] = {} with ThreadPoolExecutor(max_workers=min(self._num_concurrent, len(query_rows))) as executor: futures = { executor.submit(self._run_single_query, qid, qtxt, api_key): (qid, qtxt) for qid, qtxt in query_rows } for future in as_completed(futures): + qid, qtxt = futures[future] try: - rows.extend(future.result()) + results_by_qid[qid] = future.result() except TimeoutError as exc: - qid, qtxt = futures[future] logger.warning("ReActAgentOperator: query %r timed out: %s", qid, exc, exc_info=True) except RuntimeError as exc: - qid, qtxt = futures[future] logger.warning("ReActAgentOperator: query %r retries exhausted: %s", qid, exc, exc_info=True) except requests.RequestException as exc: - qid, qtxt = futures[future] logger.warning("ReActAgentOperator: query %r HTTP error: %s", qid, exc, exc_info=True) except (json.JSONDecodeError, ValueError) as exc: - qid, qtxt = futures[future] logger.warning("ReActAgentOperator: query %r data error: %s", qid, exc, exc_info=True) except Exception as exc: # catches unexpected worker errors not covered above - qid, qtxt = futures[future] logger.warning("ReActAgentOperator: query %r failed: %s", qid, exc, exc_info=True) + for qid, _qtxt in query_rows: + rows.extend(results_by_qid.get(qid, [])) if not rows: return pd.DataFrame(columns=["query_id", "query_text", "step_idx", "doc_id", "text", "rank"]) @@ -562,7 +580,7 @@ def _run_single_query( tools=tools, tool_choice="auto", max_tokens=self._max_tokens, - extra_body={"parallel_tool_calls": False} if not self._parallel_tool_calls else None, + extra_body=self._build_extra_body(), ) except TimeoutError as exc: logger.warning( @@ -705,16 +723,25 @@ def _run_single_query( raw_ids[:_LOG_DOC_ID_LIMIT] if isinstance(raw_ids, list) else raw_ids, _preview_text(fn_args.get("message")), ) - if isinstance(raw_ids, list) and raw_ids: - final_doc_ids = [str(d) for d in raw_ids] - tool_messages.append( - { - "role": "tool", - "tool_call_id": tc_id, - 
"content": "The results have been successfully logged and the interaction ended.", - } - ) - loop_done = True + validation_error = self._validate_final_results_args(fn_args) + if validation_error is None: + final_doc_ids = list(raw_ids) + tool_messages.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": "The results have been successfully logged and the interaction ended.", + } + ) + loop_done = True + else: + tool_messages.append( + { + "role": "tool", + "tool_call_id": tc_id, + "content": f"Error: {validation_error}", + } + ) else: tool_messages.append( @@ -742,6 +769,11 @@ def _call_retriever( ) -> List[Dict[str, Any]]: """Call retriever_fn, over-fetching to ensure new results after dedup.""" fetch_k = self._retriever_top_k + len(seen_doc_ids) + # Optional fixed ceiling on backend depth, matching Path A's --retriever-top-k + # cap. Once the agent has seen the whole capped pool, retrieves return no new + # docs, so the prompt stops growing (prevents context-window overflow). + if self._backend_top_k: + fetch_k = min(fetch_k, int(self._backend_top_k)) try: raw = self._retriever_fn(query_text, fetch_k) except TimeoutError as exc: @@ -758,22 +790,74 @@ def _call_retriever( logger.warning("ReActAgentOperator: retriever_fn failed for query %r: %s", query_text, exc, exc_info=True) return [] - # Filter already-seen and normalise keys. Track this batch separately - # so duplicate rows from the vector DB do not reduce the effective top-k. + # Walk the ranked results, normalising keys and de-duplicating within the + # batch. Already-seen docs are always re-presented as short stubs, matching + # Path A's retrieve_with_guarantees, and do not count toward top_k. 
results: List[Dict[str, Any]] = [] batch_doc_ids: set[str] = set() + new_count = 0 for item in raw: doc_id = str(item.get("doc_id", item.get("id", ""))) - text = str(item.get("text", "")) + if not doc_id or doc_id in batch_doc_ids: + continue score = float(item.get("score", 0.0)) - if doc_id and doc_id not in seen_doc_ids and doc_id not in batch_doc_ids: - results.append({"doc_id": doc_id, "text": text, "score": score}) + if doc_id in seen_doc_ids: batch_doc_ids.add(doc_id) - if len(results) >= self._retriever_top_k: + results.append( + { + "doc_id": doc_id, + "text": ( + "This document was retrieved before. See the earlier retrieval " + f"results for its content (id: {doc_id})." + ), + "score": score, + } + ) + continue + batch_doc_ids.add(doc_id) + results.append( + { + "doc_id": doc_id, + "text": str(item.get("text", "")), + "score": score, + } + ) + new_count += 1 + if new_count >= self._retriever_top_k: break return results + def _validate_final_results_args(self, fn_args: Dict[str, Any]) -> Optional[str]: + """Validate final_results tool args outside the prompt/schema.""" + message = fn_args.get("message") + if not isinstance(message, str): + return f"`message` must be a string. Got `{type(message)}` type." + + doc_ids = fn_args.get("doc_ids") + if not isinstance(doc_ids, list): + return f"`doc_ids` must be a list. Got `{type(doc_ids)}` type." + if len(doc_ids) == 0: + return "`doc_ids` cannot be empty. You must choose at least one relevant document." + if not all(isinstance(doc_id, str) for doc_id in doc_ids): + return "Items in `doc_ids` must be of type string (i.e., python's `str` type)." + + search_successful = fn_args.get("search_successful") + if not isinstance(search_successful, str): + return f"`search_successful` must be a string. Got `{type(search_successful)}` type." + if search_successful not in {"true", "false", "partial"}: + return ( + f"`search_successful` must be one of `true`, `false`, or `partial`. Got `{search_successful}` instead." 
+ ) + + if self._enforce_top_k and len(doc_ids) != self._target_top_k: + return ( + f"`doc_ids` must contain exactly {self._target_top_k} documents. " + f"But got {len(doc_ids)} document IDs instead." + ) + + return None + def _resolve_api_key(self) -> Optional[str]: api_key = self._api_key if api_key is not None and api_key.strip().startswith("os.environ/"): @@ -825,6 +909,8 @@ def _build_output_rows( "step_idx": step_idx, "doc_id": doc.get("doc_id", ""), "text": doc.get("text", ""), + "has_valid_final_results": final_doc_ids is not None, + "is_final_result": False, "rank": rank, } ) @@ -855,6 +941,8 @@ def _build_output_rows( "step_idx": final_step_idx, "doc_id": doc_id, "text": doc.get("text", ""), + "has_valid_final_results": True, + "is_final_result": True, "rank": rank, } ) diff --git a/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py index 39053bc3da..b9988efc00 100644 --- a/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/rrf_aggregator_operator.py @@ -102,6 +102,12 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: rrf_scores: Dict[str, float] = defaultdict(float) first_text: Dict[str, str] = {} + react_final_rank: Dict[str, int] = {} + has_valid_final_results = ( + bool(qgroup.get("has_valid_final_results", False).astype(bool).any()) + if "has_valid_final_results" in qgroup.columns + else False + ) # Process each step's ranked list for _step_idx, sgroup in qgroup.groupby("step_idx", sort=True): @@ -112,6 +118,10 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: rrf_scores[doc_id] += 1.0 / (rank + k) if doc_id not in first_text: first_text[doc_id] = str(row["text"]) + if bool(row.get("is_final_result", False)): + previous = react_final_rank.get(doc_id) + if previous is None or rank < previous: + react_final_rank[doc_id] = rank for doc_id, score in 
sorted(rrf_scores.items(), key=lambda kv: kv[1], reverse=True): rows.append( @@ -121,11 +131,23 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: "doc_id": doc_id, "rrf_score": score, "text": first_text.get(doc_id, ""), + "has_valid_final_results": has_valid_final_results, + "react_final_rank": react_final_rank.get(doc_id), } ) if not rows: - return pd.DataFrame(columns=["query_id", "query_text", "doc_id", "rrf_score", "text"]) + return pd.DataFrame( + columns=[ + "query_id", + "query_text", + "doc_id", + "rrf_score", + "text", + "has_valid_final_results", + "react_final_rank", + ] + ) return pd.DataFrame(rows) diff --git a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py index bddb2bbb37..6f53342388 100644 --- a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py @@ -22,7 +22,7 @@ logger = logging.getLogger(__name__) _LOG_PREVIEW_CHARS = 300 -_LOG_DOC_ID_LIMIT = 10 +_LOG_DOC_ID_LIMIT = 20 def _preview_text(value: Any, *, limit: int = _LOG_PREVIEW_CHARS) -> str: @@ -31,6 +31,7 @@ def _preview_text(value: Any, *, limit: int = _LOG_PREVIEW_CHARS) -> str: return text return text[:limit].rstrip() + "..." 
+ # --------------------------------------------------------------------------- # Prompt rendering (verbatim content of 01_v0.j2, rendered via Python) # --------------------------------------------------------------------------- @@ -198,8 +199,10 @@ def __init__( text_truncation: int = 2000, parallel_tool_calls: bool = True, base_url: Optional[str] = None, + reasoning_effort: Optional[str] = None, ) -> None: super().__init__() + self._reasoning_effort = reasoning_effort self._llm_model = llm_model self._top_k = top_k self._api_key = api_key @@ -246,7 +249,16 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: for query_id, group in data.groupby("query_id", sort=False): query_text = str(group["query_text"].iloc[0]) - docs = [{"id": str(row["doc_id"]), "text": str(row["text"])} for _, row in group.iterrows()] + ordered_group = group + if "rrf_score" in group.columns: + ordered_group = group.sort_values("rrf_score", ascending=False) + docs = [ + { + "id": str(row["doc_id"]), + "text": str(row["text"]), + } + for _, row in ordered_group.iterrows() + ] logger.info( "SelectionAgentOperator: query=%s start candidates=%d unique_candidates=%d query=%r", query_id, @@ -254,26 +266,49 @@ def process(self, data: pd.DataFrame, **kwargs: Any) -> pd.DataFrame: len({doc["id"] for doc in docs}), _preview_text(query_text), ) - result = self._select_documents(query_text, docs) - message = result.get("message", "") + preferred_doc_ids, message, result_source = self._preferred_doc_ids(ordered_group) + if preferred_doc_ids is None: + result = self._select_documents(query_text, docs) + message = result.get("message", "") + doc_ids = list(result.get("doc_ids", [])) + result_source = "selection_agent" + else: + doc_ids = preferred_doc_ids + if not doc_ids: + if preferred_doc_ids is None: + doc_ids = ordered_group["doc_id"].astype(str).drop_duplicates().head(int(self._top_k)).tolist() + message = ( + f"{message} Falling back to top {len(doc_ids)} RRF-ranked candidates." 
+ if message + else f"Falling back to top {len(doc_ids)} RRF-ranked candidates." + ) + result_source = "candidate_ranking" + logger.warning( + "SelectionAgentOperator: query=%s selection failed; " + "falling back to candidate ranking doc_ids=%s", + query_id, + doc_ids[:_LOG_DOC_ID_LIMIT], + ) logger.info( - "SelectionAgentOperator: query=%s selected=%s message=%r", + "SelectionAgentOperator: query=%s result_source=%s selected=%s message=%r", query_id, - result.get("doc_ids", [])[:_LOG_DOC_ID_LIMIT], + result_source, + doc_ids[:_LOG_DOC_ID_LIMIT], _preview_text(message), ) - for rank, doc_id in enumerate(result.get("doc_ids", []), 1): + for rank, doc_id in enumerate(doc_ids, 1): rows.append( { "query_id": query_id, "doc_id": doc_id, "rank": rank, "message": message, + "result_source": result_source, } ) if not rows: - return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message"]) + return pd.DataFrame(columns=["query_id", "doc_id", "rank", "message", "result_source"]) return pd.DataFrame(rows) @@ -357,6 +392,38 @@ def _build_tools(self, top_k: int, valid_doc_ids: List[str]) -> List[Dict[str, A }, ] + def _preferred_doc_ids(self, ordered_group: pd.DataFrame) -> tuple[List[str] | None, str, str]: + """Apply retrieval-bench-style source priority before invoking selection.""" + doc_ids = self._react_final_doc_ids(ordered_group) + if doc_ids is not None: + return doc_ids, "Using ReAct final_results.", "final_results" + + if "rrf_score" in ordered_group.columns: + doc_ids = ordered_group["doc_id"].astype(str).drop_duplicates().head(int(self._top_k)).tolist() + if doc_ids: + return doc_ids, "Using RRF ranking.", "rrf" + + return None, "", "" + + def _react_final_doc_ids(self, ordered_group: pd.DataFrame) -> List[str] | None: + if "has_valid_final_results" in ordered_group.columns and not bool( + ordered_group["has_valid_final_results"].astype(bool).any() + ): + return None + if "react_final_rank" not in ordered_group.columns: + return None + final_rows = 
ordered_group[ordered_group["react_final_rank"].notna()].copy() + if final_rows.empty: + return [] if "has_valid_final_results" in ordered_group.columns else None + final_rows["react_final_rank"] = final_rows["react_final_rank"].astype(int) + doc_ids: List[str] = [] + for doc_id in final_rows.sort_values("react_final_rank")["doc_id"].astype(str): + if doc_id and doc_id not in doc_ids: + doc_ids.append(doc_id) + if len(doc_ids) >= int(self._top_k): + break + return doc_ids + def _build_user_message(self, query_text: str, docs: List[Dict[str, Any]]) -> Dict[str, Any]: """Format query + candidate documents as a multi-part user message.""" content: List[Dict[str, Any]] = [ @@ -372,8 +439,11 @@ def _build_user_message(self, query_text: str, docs: List[Dict[str, Any]]) -> Di content.append({"type": "text", "text": f"Doc ID: {doc_id}"}) text = doc.get("text", "").strip() if text: - truncated = text[: self._text_truncation] - if len(text) > self._text_truncation: + if self._text_truncation > 0: + truncated = text[: self._text_truncation] + else: + truncated = text + if self._text_truncation > 0 and len(text) > self._text_truncation: truncated += "..." 
content.append({"type": "text", "text": f"Doc Text: {truncated}"}) return {"role": "user", "content": content} @@ -406,6 +476,8 @@ def _select_documents( extra_body: Dict[str, Any] = {} if not self._parallel_tool_calls: extra_body["parallel_tool_calls"] = False + if self._reasoning_effort: + extra_body["reasoning_effort"] = self._reasoning_effort for _step in range(self._max_steps): logger.info( diff --git a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py index c8c647a5f9..e61fda1d70 100644 --- a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py +++ b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py @@ -688,7 +688,7 @@ def _run_evaluation( embed_invoke_url: Optional[str], embed_remote_api_key: Optional[str], embed_modality: str, - query_csv: Path, + query_csv: Optional[Path], recall_match_mode: str, audio_match_tolerance_secs: float, reranker: Optional[bool], @@ -820,15 +820,16 @@ def _run_evaluation( return "Audio Recall", time.perf_counter() - evaluation_start, metrics, len(df_query.index), True -def _run_agentic_recall_evaluation( +def _run_agentic_evaluation( *, + evaluation_mode: str, vdb_op: str, vdb_kwargs: dict[str, Any], embed_model_name: str, embed_invoke_url: Optional[str], embed_remote_api_key: Optional[str], embed_modality: str, - query_csv: Path, + query_csv: Optional[Path], recall_match_mode: str, reranker: Optional[bool], reranker_model_name: str, @@ -841,15 +842,24 @@ def _run_agentic_recall_evaluation( agentic_invoke_url: Optional[str], agentic_api_key: Optional[str], agentic_react_max_steps: int, + agentic_backend_top_k: int, + agentic_text_truncation: int, + agentic_reasoning_effort: Optional[str], + agentic_num_concurrent: int, + beir_loader: Optional[str], + beir_dataset_name: Optional[str], + beir_split: str, + beir_query_language: Optional[str], + beir_doc_id_field: str, + beir_k: list[int], ) -> tuple[str, float, dict[str, float], Optional[int], bool]: - """Run 
recall evaluation using graph-backed agentic retrieval.""" - - query_csv_path = Path(query_csv) - if not query_csv_path.exists(): - logger.warning("Query CSV not found at %s; skipping agentic recall evaluation.", query_csv_path) - return "Agentic Recall", 0.0, {}, None, False + """Run evaluation using graph-backed agentic retrieval.""" - from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, run_agentic_recall_evaluation + from nemo_retriever.agentic.retrieval import ( + AgenticRetrievalConfig, + run_agentic_beir_evaluation, + run_agentic_recall_evaluation, + ) from nemo_retriever.model import resolve_embed_model embed_model = resolve_embed_model(str(embed_model_name)) @@ -871,23 +881,54 @@ def _run_agentic_recall_evaluation( invoke_url=agentic_invoke_url, api_key=agentic_api_key, react_max_steps=int(agentic_react_max_steps), + backend_top_k=int(agentic_backend_top_k), + text_truncation=int(agentic_text_truncation), + reasoning_effort=agentic_reasoning_effort, + num_concurrent=int(agentic_num_concurrent), ) evaluation_start = time.perf_counter() - df_query, _result, _qrels, _run, metrics = run_agentic_recall_evaluation( - query_csv=query_csv_path, - cfg=cfg, - match_mode=recall_match_mode, - ks=(1, 5, 10), - ) - logger.info("Agentic recall gold ids: %s", {qid: list(docs.keys()) for qid, docs in _qrels.items()}) - logger.info("Agentic recall retrieved ids: %s", {qid: list(docs.keys())[:10] for qid, docs in _run.items()}) - logger.info("Agentic recall result columns: %s", list(_result.columns)) + if evaluation_mode == "beir": + if not beir_loader: + raise ValueError("--beir-loader is required when --evaluation-mode=beir") + if not beir_dataset_name: + raise ValueError("--beir-dataset-name is required when --evaluation-mode=beir") + df_query, _result, _qrels, _run, metrics = run_agentic_beir_evaluation( + loader=str(beir_loader), + dataset_name=str(beir_dataset_name), + cfg=cfg, + split=str(beir_split), + query_language=beir_query_language, + 
doc_id_field=str(beir_doc_id_field), + ks=tuple(beir_k) if beir_k else (1, 3, 5, 10), + ) + evaluation_label = "Agentic BEIR" + elif evaluation_mode == "audio_recall": + if query_csv is None: + logger.warning("No query CSV configured; skipping agentic recall evaluation.") + return "Agentic Recall", 0.0, {}, None, False + query_csv_path = Path(query_csv) + if not query_csv_path.exists(): + logger.warning("Query CSV not found at %s; skipping agentic recall evaluation.", query_csv_path) + return "Agentic Recall", 0.0, {}, None, False + df_query, _result, _qrels, _run, metrics = run_agentic_recall_evaluation( + query_csv=query_csv_path, + cfg=cfg, + match_mode=recall_match_mode, + ks=(1, 5, 10), + ) + evaluation_label = "Agentic Recall" + else: + raise ValueError(f"Unsupported agentic evaluation mode: {evaluation_mode!r}") + logger.info("%s gold ids: %s", evaluation_label, {qid: list(docs.keys()) for qid, docs in _qrels.items()}) + logger.info("%s retrieved ids: %s", evaluation_label, {qid: list(docs.keys())[:10] for qid, docs in _run.items()}) + logger.info("%s result columns: %s", evaluation_label, list(_result.columns)) if {"query_id", "doc_id", "rank"}.issubset(_result.columns): logger.info( - "Agentic recall top result rows:\n%s", + "%s top result rows:\n%s", + evaluation_label, _result[["query_id", "doc_id", "rank"]].head(20).to_string(index=False), ) - return "Agentic Recall", time.perf_counter() - evaluation_start, metrics, len(df_query.index), True + return evaluation_label, time.perf_counter() - evaluation_start, metrics, len(df_query.index), True # --------------------------------------------------------------------------- @@ -1297,8 +1338,8 @@ def run( help="Retrieval strategy for evaluation: 'standard' (default) or 'agentic' for recall evaluation.", rich_help_panel=_PANEL_EVAL, ), - query_csv: Path = typer.Option( - "./data/bo767_query_gt.csv", + query_csv: Optional[Path] = typer.Option( + Path("./data/bo767_query_gt.csv"), "--query-csv", path_type=Path, 
rich_help_panel=_PANEL_EVAL, @@ -1361,12 +1402,39 @@ def run( rich_help_panel=_PANEL_EVAL, ), agentic_react_max_steps: int = typer.Option( - 10, + 50, "--agentic-react-max-steps", min=1, help="Maximum ReAct loop iterations per query for --retrieval-mode=agentic.", rich_help_panel=_PANEL_EVAL, ), + agentic_backend_top_k: int = typer.Option( + 20, + "--agentic-backend-top-k", + min=1, + help=("Backend retrieve-pool depth per agentic query, distinct from the per-call candidate show count."), + rich_help_panel=_PANEL_EVAL, + ), + agentic_text_truncation: int = typer.Option( + 0, + "--agentic-text-truncation", + min=0, + help="Maximum characters of each candidate text shown to agentic LLMs; 0 disables truncation.", + rich_help_panel=_PANEL_EVAL, + ), + agentic_reasoning_effort: Optional[str] = typer.Option( + "high", + "--agentic-reasoning-effort", + help="OpenAI-compatible reasoning_effort value forwarded to agentic LLM calls.", + rich_help_panel=_PANEL_EVAL, + ), + agentic_num_concurrent: int = typer.Option( + 1, + "--agentic-num-concurrent", + min=1, + help="Maximum number of agentic queries to run concurrently.", + rich_help_panel=_PANEL_EVAL, + ), beir_loader: Optional[str] = typer.Option(None, "--beir-loader", rich_help_panel=_PANEL_EVAL), beir_dataset_name: Optional[str] = typer.Option(None, "--beir-dataset-name", rich_help_panel=_PANEL_EVAL), beir_split: str = typer.Option("test", "--beir-split", rich_help_panel=_PANEL_EVAL), @@ -1427,9 +1495,10 @@ def run( raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") if retrieval_mode not in {"standard", "agentic"}: raise ValueError(f"Unsupported --retrieval-mode: {retrieval_mode!r}") - if retrieval_mode == "agentic" and evaluation_mode != "audio_recall": + if retrieval_mode == "agentic" and evaluation_mode not in {"audio_recall", "beir"}: raise typer.BadParameter( - "--retrieval-mode=agentic is currently supported only with --evaluation-mode=audio_recall." 
+ "--retrieval-mode=agentic is currently supported only with --evaluation-mode=audio_recall or " + "--evaluation-mode=beir." ) if evaluation_mode == "audio_recall": if input_type != "audio": @@ -1442,11 +1511,18 @@ def run( "Use the same file format as `retriever eval run --config` (dataset, retrieval, models, ...)." ) resolved_agentic_llm_model = (agentic_llm_model or os.environ.get(AGENTIC_LLM_MODEL_ENV) or "").strip() - resolved_agentic_invoke_url = (agentic_invoke_url or os.environ.get(AGENTIC_INVOKE_URL_ENV) or "").strip() or None + resolved_agentic_invoke_url = ( + agentic_invoke_url or os.environ.get(AGENTIC_INVOKE_URL_ENV) or "" + ).strip() or None if retrieval_mode == "agentic" and not resolved_agentic_llm_model: raise typer.BadParameter( f"--retrieval-mode=agentic requires --agentic-llm-model or {AGENTIC_LLM_MODEL_ENV}." ) + if evaluation_mode == "beir": + if not beir_loader: + raise typer.BadParameter("--beir-loader is required when --evaluation-mode=beir.") + if not beir_dataset_name: + raise typer.BadParameter("--beir-dataset-name is required when --evaluation-mode=beir.") if run_mode == "batch": # --quiet implies --no-ray-log-to-driver: Ray flushes worker stdout @@ -1893,7 +1969,8 @@ def _run_ingest() -> Any: if retrieval_mode == "agentic": evaluation_label, evaluation_total_time, evaluation_metrics, evaluation_query_count, ran = ( - _run_agentic_recall_evaluation( + _run_agentic_evaluation( + evaluation_mode=evaluation_mode, vdb_op=resolved_vdb_op, vdb_kwargs=resolved_vdb_kwargs, embed_model_name=embed_model_name, @@ -1913,6 +1990,16 @@ def _run_ingest() -> Any: agentic_invoke_url=resolved_agentic_invoke_url, agentic_api_key=remote_api_key, agentic_react_max_steps=agentic_react_max_steps, + agentic_backend_top_k=agentic_backend_top_k, + agentic_text_truncation=agentic_text_truncation, + agentic_reasoning_effort=agentic_reasoning_effort, + agentic_num_concurrent=agentic_num_concurrent, + beir_loader=beir_loader, + beir_dataset_name=beir_dataset_name, 
+ beir_split=beir_split, + beir_query_language=beir_query_language, + beir_doc_id_field=beir_doc_id_field, + beir_k=beir_k, ) ) else: diff --git a/nemo_retriever/tests/test_agentic_eval.py b/nemo_retriever/tests/test_agentic_eval.py index 05f050ca9a..abe1b9fe25 100644 --- a/nemo_retriever/tests/test_agentic_eval.py +++ b/nemo_retriever/tests/test_agentic_eval.py @@ -103,9 +103,10 @@ def test_build_beir_run_from_agentic_result_orders_by_rank(): def test_agentic_retriever_runs_graph_with_wrapped_retriever(mock_react_step, mock_selection_step): from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, AgenticRetriever + final_ids = ["doc_1"] + [f"extra_{i}" for i in range(9)] mock_react_step.return_value = _make_tool_call_response( "final_results", - {"doc_ids": ["doc_1"], "message": "done", "search_successful": "true"}, + {"doc_ids": final_ids, "message": "done", "search_successful": "true"}, ) mock_selection_step.return_value = _make_tool_call_response( "log_selected_documents", @@ -115,10 +116,10 @@ def test_agentic_retriever_runs_graph_with_wrapped_retriever(mock_react_step, mo cfg = AgenticRetrievalConfig(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions") result = AgenticRetriever(cfg, match_mode="pdf_page").retrieve(["0"], ["find doc"]) - assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] - assert result["query_id"].tolist() == ["0"] - assert result["doc_id"].tolist() == ["doc_1"] - assert result["rank"].tolist() == [1] + assert list(result.columns) == ["query_id", "doc_id", "rank", "message", "result_source"] + assert result["query_id"].tolist() == ["0"] * 10 + assert result["doc_id"].tolist()[0] == "doc_1" + assert result["rank"].tolist() == list(range(1, 11)) @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") @@ -130,9 +131,10 @@ def test_run_agentic_recall_evaluation_computes_metrics(mock_react_step, mock_se query_csv = tmp_path / "queries.csv" pd.DataFrame({"query": 
["find doc"], "pdf_page": ["doc_1"]}).to_csv(query_csv, index=False) + final_ids = ["doc_1"] + [f"extra_{i}" for i in range(9)] mock_react_step.return_value = _make_tool_call_response( "final_results", - {"doc_ids": ["doc_1"], "message": "done", "search_successful": "true"}, + {"doc_ids": final_ids, "message": "done", "search_successful": "true"}, ) mock_selection_step.return_value = _make_tool_call_response( "log_selected_documents", @@ -148,13 +150,126 @@ def test_run_agentic_recall_evaluation_computes_metrics(mock_react_step, mock_se ) assert df_query["golden_answer"].tolist() == ["doc_1"] - assert result["doc_id"].tolist() == ["doc_1"] + assert result["doc_id"].tolist()[0] == "doc_1" assert qrels == {"0": {"doc_1": 1}} - assert run["0"]["doc_1"] == 1.0 + assert run["0"]["doc_1"] == 10.0 assert metrics["recall@1"] == 1.0 assert metrics["ndcg@1"] == 1.0 +@patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") +@patch("nemo_retriever.agentic.retrieval.Retriever", FakeRetriever) +def test_run_agentic_beir_evaluation_loads_queries_and_qrels(mock_react_step, mock_selection_step): + from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig, run_agentic_beir_evaluation + from nemo_retriever.recall.beir import BeirDataset + + final_ids = ["doc"] + [f"extra_{i}" for i in range(9)] + mock_react_step.return_value = _make_tool_call_response( + "final_results", + {"doc_ids": final_ids, "message": "done", "search_successful": "true"}, + ) + mock_selection_step.return_value = _make_tool_call_response( + "log_selected_documents", + {"doc_ids": ["doc"], "message": "doc is best"}, + ) + + beir_dataset = BeirDataset( + dataset_name="vidore_v3_finance_en", + query_ids=["q1"], + queries=["find doc"], + qrels={"q1": {"doc": 1}}, + ) + cfg = AgenticRetrievalConfig(llm_model="test-model", invoke_url="http://localhost/v1/chat/completions") + + with 
patch("nemo_retriever.agentic.retrieval.load_beir_dataset", return_value=beir_dataset) as mock_loader: + df_query, result, qrels, run, metrics = run_agentic_beir_evaluation( + loader="vidore_hf", + dataset_name="vidore_v3_finance_en", + cfg=cfg, + doc_id_field="pdf_basename", + ks=(1, 5, 10), + ) + + mock_loader.assert_called_once() + assert df_query["query_id"].tolist() == ["q1"] + assert result["doc_id"].tolist()[0] == "doc" + assert qrels == {"q1": {"doc": 1}} + assert run["q1"]["doc"] == 10.0 + assert metrics["recall@1"] == 1.0 + + +def test_pipeline_agentic_beir_wires_config_options(): + from nemo_retriever.pipeline.__main__ import _run_agentic_evaluation + + captured = {} + + def fake_run_agentic_beir_evaluation(**kwargs): + captured.update(kwargs) + return ( + pd.DataFrame({"query_id": ["q1"]}), + pd.DataFrame({"query_id": ["q1"], "doc_id": ["doc"], "rank": [1]}), + {"q1": {"doc": 1}}, + {"q1": {"doc": 10.0}}, + {"recall@1": 1.0}, + ) + + with ( + patch("nemo_retriever.model.resolve_embed_model", return_value="resolved-embed"), + patch( + "nemo_retriever.agentic.retrieval.run_agentic_beir_evaluation", side_effect=fake_run_agentic_beir_evaluation + ), + ): + label, _elapsed, metrics, query_count, ran = _run_agentic_evaluation( + evaluation_mode="beir", + vdb_op="lancedb", + vdb_kwargs={"uri": "db", "table_name": "tbl"}, + embed_model_name="embed", + embed_invoke_url="http://embed/v1", + embed_remote_api_key="embed-key", + embed_modality="text", + query_csv=None, + recall_match_mode="pdf_page", + reranker=False, + reranker_model_name="reranker", + reranker_invoke_url=None, + reranker_api_key="", + local_reranker_backend="vllm", + local_hf_batch_size=4, + local_query_embed_backend="hf", + agentic_llm_model="llm", + agentic_invoke_url="http://llm/v1/chat/completions", + agentic_api_key="llm-key", + agentic_react_max_steps=51, + agentic_backend_top_k=23, + agentic_text_truncation=99, + agentic_reasoning_effort="high", + agentic_num_concurrent=7, + 
beir_loader="vidore_hf", + beir_dataset_name="vidore_v3_finance_en", + beir_split="test", + beir_query_language=None, + beir_doc_id_field="pdf_basename", + beir_k=[1, 5], + ) + + cfg = captured["cfg"] + assert label == "Agentic BEIR" + assert metrics["recall@1"] == 1.0 + assert query_count == 1 + assert ran is True + assert captured["loader"] == "vidore_hf" + assert captured["dataset_name"] == "vidore_v3_finance_en" + assert captured["doc_id_field"] == "pdf_basename" + assert captured["ks"] == (1, 5) + assert cfg.query_embedder == "resolved-embed" + assert cfg.react_max_steps == 51 + assert cfg.backend_top_k == 23 + assert cfg.text_truncation == 99 + assert cfg.reasoning_effort == "high" + assert cfg.num_concurrent == 7 + + def test_agentic_config_requires_llm_model(): from nemo_retriever.agentic.retrieval import AgenticRetrievalConfig @@ -180,7 +295,8 @@ def test_pipeline_rejects_agentic_qa_mode(): assert result.exit_code != 0 assert "--retrieval-mode=agentic is currently supported only with" in result.output - assert "--evaluation-mode=recall" in result.output + assert "--evaluation-mode=audio_recall" in result.output + assert "--evaluation-mode=beir" in result.output def test_pipeline_requires_agentic_llm_model(): @@ -191,7 +307,11 @@ def test_pipeline_requires_agentic_llm_model(): [ ".", "--evaluation-mode", - "recall", + "audio_recall", + "--input-type", + "audio", + "--recall-match-mode", + "audio_segment", "--retrieval-mode", "agentic", ], diff --git a/nemo_retriever/tests/test_agentic_operators.py b/nemo_retriever/tests/test_agentic_operators.py index 78ae3dd1e2..ee60b9d0bf 100644 --- a/nemo_retriever/tests/test_agentic_operators.py +++ b/nemo_retriever/tests/test_agentic_operators.py @@ -75,6 +75,17 @@ def test_output_schema(self): result = op.run(self._make_input()) assert set(result.columns) >= {"query_id", "query_text", "doc_id", "rrf_score", "text"} + def test_carries_react_final_rank(self): + from nemo_retriever.graph.rrf_aggregator_operator import 
RRFAggregatorOperator + + df = self._make_input() + df["is_final_result"] = [False, False, True, False, False, False] + op = RRFAggregatorOperator(k=60) + result = op.run(df) + + q1 = result[result["query_id"] == "q1"].set_index("doc_id") + assert int(q1.loc["d1", "react_final_rank"]) == 1 + def test_missing_column_raises(self): from nemo_retriever.graph.rrf_aggregator_operator import RRFAggregatorOperator @@ -181,10 +192,11 @@ def test_happy_path_selects_docs(self, mock_step): ) result = op.run(self._make_input()) - assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert set(result.columns) >= {"query_id", "doc_id", "rank", "message", "result_source"} assert result["query_id"].tolist() == ["q1", "q1"] assert result["doc_id"].tolist() == ["d1", "d2"] assert result["rank"].tolist() == [1, 2] + assert result["result_source"].tolist() == ["selection_agent", "selection_agent"] @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") def test_think_then_select(self, mock_step): @@ -228,6 +240,82 @@ def capture_and_respond(**kwargs): assert "RELEVANCE_DEFINITION" in captured_prompts[0] + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_final_results_policy_skips_selection_agent(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=2, + ) + df = pd.DataFrame( + { + "query_id": ["q1", "q1", "q1"], + "query_text": ["What causes inflation?"] * 3, + "doc_id": ["d1", "d2", "d3"], + "text": ["doc one", "doc two", "doc three"], + "rrf_score": [0.1, 0.9, 0.8], + "react_final_rank": [2, None, 1], + } + ) + result = op.run(df) + + assert result["doc_id"].tolist() == ["d3", "d1"] + assert result["result_source"].tolist() == ["final_results", "final_results"] + mock_step.assert_not_called() + + 
@patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_result_policy_uses_rrf_before_selection_when_no_final_results(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=2, + ) + df = pd.DataFrame( + { + "query_id": ["q1", "q1", "q1"], + "query_text": ["What causes inflation?"] * 3, + "doc_id": ["d1", "d2", "d3"], + "text": ["doc one", "doc two", "doc three"], + "rrf_score": [0.1, 0.9, 0.8], + "react_final_rank": [None, None, None], + } + ) + result = op.run(df) + + assert result["doc_id"].tolist() == ["d2", "d3"] + assert result["result_source"].tolist() == ["rrf", "rrf"] + mock_step.assert_not_called() + + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") + def test_empty_final_results_is_not_valid_and_falls_back_to_rrf(self, mock_step): + from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator + + op = SelectionAgentOperator( + llm_model="test-model", + invoke_url="http://localhost/v1/chat/completions", + top_k=2, + ) + df = pd.DataFrame( + { + "query_id": ["q1", "q1"], + "query_text": ["What causes inflation?"] * 2, + "doc_id": ["d1", "d2"], + "text": ["doc one", "doc two"], + "rrf_score": [0.9, 0.8], + "has_valid_final_results": [False, False], + "react_final_rank": [None, None], + } + ) + result = op.run(df) + + assert result["doc_id"].tolist() == ["d1", "d2"] + assert result["result_source"].tolist() == ["rrf", "rrf"] + mock_step.assert_not_called() + # --------------------------------------------------------------------------- # ReActAgentOperator — mock retriever_fn + invoke_chat_completion_step @@ -305,6 +393,65 @@ def test_with_results_mode_initial_retrieval(self, mock_step): assert retriever.call_count >= 1 assert 0 in result["step_idx"].values + def 
test_backend_top_k_caps_fetch_depth_and_replays_seen_docs(self): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + calls = [] + docs = [ + {"doc_id": "d1", "text": "already seen", "score": 0.9}, + {"doc_id": "d2", "text": "new two", "score": 0.8}, + {"doc_id": "d3", "text": "new three", "score": 0.7}, + {"doc_id": "d4", "text": "outside backend cap", "score": 0.6}, + ] + + def retriever_fn(query_text, top_k): + calls.append((query_text, top_k)) + return docs[:top_k] + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=retriever_fn, + retriever_top_k=2, + backend_top_k=3, + ) + + result = op._call_retriever("inflation", {"d1"}, api_key=None) + + assert calls == [("inflation", 3)] + assert [doc["doc_id"] for doc in result] == ["d1", "d2", "d3"] + assert "retrieved before" in result[0]["text"] + assert result[1]["text"] == "new two" + + @pytest.mark.parametrize( + ("fn_args", "target_top_k", "enforce_top_k"), + [ + ({"doc_ids": [1], "message": "bad id type", "search_successful": "true"}, 1, False), + ({"doc_ids": [], "message": "empty", "search_successful": "false"}, 1, False), + ({"doc_ids": ["d1"], "message": "wrong count", "search_successful": "true"}, 2, True), + ({"doc_ids": ["d1"], "message": "bad status", "search_successful": "yes"}, 1, False), + ], + ) + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_invalid_final_results_are_rejected(self, mock_step, fn_args, target_top_k, enforce_top_k): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + mock_step.return_value = _make_tool_call_response("final_results", fn_args) + retriever = MagicMock(return_value=[{"doc_id": "d1", "text": "monetary policy"}]) + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=retriever, + user_msg_type="with_results", + target_top_k=target_top_k, + 
enforce_top_k=enforce_top_k, + ) + result = op.run(self._make_input()) + + assert not result["is_final_result"].astype(bool).any() + assert not result["has_valid_final_results"].astype(bool).any() + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") def test_output_row_structure(self, mock_step): from nemo_retriever.graph.react_agent_operator import ReActAgentOperator @@ -321,6 +468,7 @@ def test_output_row_structure(self, mock_step): llm_model="test-model", retriever_fn=self._make_retriever(), user_msg_type="simple", + target_top_k=1, ) result = op.run(self._make_input()) @@ -328,6 +476,35 @@ def test_output_row_structure(self, mock_step): assert result["step_idx"].dtype in (int, "int64") assert result["doc_id"].notna().all() + @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") + def test_no_final_results_falls_back_to_retrieval_log(self, mock_step): + from nemo_retriever.graph.react_agent_operator import ReActAgentOperator + + mock_step.side_effect = [ + _make_tool_call_response("retrieve", {"query": "inflation monetary policy"}), + _make_tool_call_response("think", {"thought": "still reasoning"}), + ] + retriever = MagicMock( + return_value=[ + {"doc_id": "d1", "text": "monetary policy"}, + {"doc_id": "d2", "text": "supply chains"}, + ] + ) + + op = ReActAgentOperator( + invoke_url="http://localhost/v1/chat/completions", + llm_model="test-model", + retriever_fn=retriever, + user_msg_type="simple", + target_top_k=2, + max_steps=2, + ) + result = op.run(self._make_input()) + + assert result["doc_id"].tolist() == ["d1", "d2"] + assert result["rank"].tolist() == [1, 2] + assert retriever.call_count == 1 + @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") @patch("nemo_retriever.graph.react_agent_operator.invoke_chat_completion_step") def test_pipeline_end_to_end_with_mocks(self, mock_react_step, mock_selection_step): @@ -345,7 +522,7 @@ def 
test_pipeline_end_to_end_with_mocks(self, mock_react_step, mock_selection_st mock_react_step.side_effect = [ _make_tool_call_response("retrieve", {"query": "inflation"}), _make_tool_call_response( - "final_results", {"doc_ids": ["d1", "d2"], "message": "ok", "search_successful": "true"} + "final_results", {"doc_ids": ["d1"], "message": "ok", "search_successful": "true"} ), ] # Selection: immediately log_selected_documents @@ -375,7 +552,7 @@ def retriever_fn(query_text, top_k): query_df = pd.DataFrame({"query_id": ["q1"], "query_text": ["What causes inflation?"]}) result = InprocessExecutor(pipeline).ingest(query_df) - assert list(result.columns) == ["query_id", "doc_id", "rank", "message"] + assert set(result.columns) >= {"query_id", "doc_id", "rank", "message", "result_source"} assert result["query_id"].tolist() == ["q1"] assert result["rank"].tolist() == [1] @@ -569,10 +746,9 @@ def test_non_dataframe_raises(self): class TestSelectionAgentMaxSteps: @patch("nemo_retriever.graph.selection_agent_operator.invoke_chat_completion_step") - def test_max_steps_exhausted_returns_empty(self, mock_step): + def test_rrf_candidates_skip_selection_agent(self, mock_step): from nemo_retriever.graph.selection_agent_operator import SelectionAgentOperator - # LLM only ever calls think — never log_selected_documents mock_step.return_value = _make_tool_call_response("think", {"thought": "still thinking..."}) op = SelectionAgentOperator( @@ -583,14 +759,17 @@ def test_max_steps_exhausted_returns_empty(self, mock_step): ) df = pd.DataFrame( { - "query_id": ["q1", "q1"], - "query_text": ["What causes inflation?"] * 2, - "doc_id": ["d1", "d2"], - "text": ["doc one", "doc two"], + "query_id": ["q1", "q1", "q1"], + "query_text": ["What causes inflation?"] * 3, + "doc_id": ["d1", "d2", "d3"], + "text": ["doc one", "doc two", "doc three"], + "rrf_score": [0.2, 0.9, 0.5], } ) result = op.run(df) - assert len(result) == 0 - assert list(result.columns) == ["query_id", "doc_id", "rank", 
"message"] - assert mock_step.call_count == 3 + assert result["doc_id"].tolist() == ["d2", "d3"] + assert result["rank"].tolist() == [1, 2] + assert result["message"].tolist() == ["Using RRF ranking."] * 2 + assert result["result_source"].tolist() == ["rrf", "rrf"] + mock_step.assert_not_called() From 9aaca6a135616cdce1ba708601ab99e9073a9c86 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Tue, 9 Jun 2026 09:59:58 -0700 Subject: [PATCH 4/6] cleanup Signed-off-by: Mahika Wason --- nemo_retriever/README.md | 9 ++++++--- .../src/nemo_retriever/pipeline/__main__.py | 3 ++- nemo_retriever/tests/test_agentic_eval.py | 19 +++++++++++++++++++ 3 files changed, 27 insertions(+), 4 deletions(-) diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 1957034aa7..4f2a21323e 100644 --- a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -266,6 +266,8 @@ BEIR-style dataset: ```bash retriever pipeline run ./data \ + --vdb-op lancedb \ + --vdb-kwargs-json '{"uri":"","table_name":""}' \ --evaluation-mode beir \ --retrieval-mode agentic \ --beir-loader vidore_hf \ @@ -279,9 +281,10 @@ retriever pipeline run ./data \ ``` Common BEIR options are `--beir-split`, `--beir-query-language`, and -`--beir-doc-id-field`. Agentic controls include `--agentic-react-max-steps` -(default `50`), `--agentic-backend-top-k` (default `20`), and -`--agentic-text-truncation` (`0` disables truncation), +`--beir-doc-id-field`. Use `--vdb-kwargs-json` to point evaluation at the +LanceDB URI and table for the indexed corpus. Agentic controls include +`--agentic-react-max-steps` (default `50`), `--agentic-backend-top-k` (default +`20`), and `--agentic-text-truncation` (`0` disables truncation), `--agentic-reasoning-effort`, and `--agentic-num-concurrent`. Throughput with high concurrency is bounded by the configured LLM endpoint. 
diff --git a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py index e61fda1d70..836b4616d8 100644 --- a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py +++ b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py @@ -1494,7 +1494,8 @@ def run( if evaluation_mode not in {"none", "audio_recall", "beir", "qa"}: raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") if retrieval_mode not in {"standard", "agentic"}: - raise ValueError(f"Unsupported --retrieval-mode: {retrieval_mode!r}") + logger.warning("Unsupported --retrieval-mode=%r; falling back to 'standard'.", retrieval_mode) + retrieval_mode = "standard" if retrieval_mode == "agentic" and evaluation_mode not in {"audio_recall", "beir"}: raise typer.BadParameter( "--retrieval-mode=agentic is currently supported only with --evaluation-mode=audio_recall or " diff --git a/nemo_retriever/tests/test_agentic_eval.py b/nemo_retriever/tests/test_agentic_eval.py index abe1b9fe25..304a43f550 100644 --- a/nemo_retriever/tests/test_agentic_eval.py +++ b/nemo_retriever/tests/test_agentic_eval.py @@ -299,6 +299,25 @@ def test_pipeline_rejects_agentic_qa_mode(): assert "--evaluation-mode=beir" in result.output +def test_pipeline_invalid_retrieval_mode_falls_back_to_standard(): + from nemo_retriever.pipeline.__main__ import app + + result = CliRunner().invoke( + app, + [ + ".", + "--evaluation-mode", + "qa", + "--retrieval-mode", + "unknown", + ], + ) + + assert result.exit_code != 0 + assert "falling back to 'standard'" in result.output + assert "--evaluation-mode=qa requires --eval-config" in result.output + + def test_pipeline_requires_agentic_llm_model(): from nemo_retriever.pipeline.__main__ import app From 9e12581bcd079026ea119312fb99084390183ad7 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Tue, 9 Jun 2026 19:40:33 +0000 Subject: [PATCH 5/6] cleanup Signed-off-by: Mahika Wason --- nemo_retriever/README.md | 92 +++++++++++---- 
.../src/nemo_retriever/pipeline/__main__.py | 25 +++- nemo_retriever/tests/test_agentic_eval.py | 110 ++++++++++++++++++ .../tests/test_graph_pipeline_cli.py | 6 +- 4 files changed, 203 insertions(+), 30 deletions(-) diff --git a/nemo_retriever/README.md b/nemo_retriever/README.md index 4f2a21323e..cb0118219c 100644 --- a/nemo_retriever/README.md +++ b/nemo_retriever/README.md @@ -259,34 +259,84 @@ hits = retriever.query(query) {'text': '| Table | 1 |\n| This | table | describes | some | animals, | and | some | activities | they | might | be | doing | in | specific |\n| locations. |\n| Animal | Activity | Place |\n| Giraffe | Driving | a | car | At | the | beach |\n| Lion | Putting | on | sunscreen | At | the | park |\n| Cat | Jumping | onto | a | laptop | In | a | home | office |\n| Dog | Chasing | a | squirrel | In | the | front | yard |\n| Chart | 1 |', 'metadata': '{"page_number": 1, "pdf_page": "multimodal_test_1", "page_elements_v3_num_detections": 9, "page_elements_v3_counts_by_label": {"table": 1, "chart": 1, "title": 3, "text": 4}, "ocr_table_detections": 1, "ocr_chart_detections": 1, "ocr_infographic_detections": 0}', 'source': '{"source_id": "/home/dev/projects/NeMo-Retriever/data/multimodal_test.pdf"}', 'page_number': 1, '_distance': 1.614684820175171} ``` -### Agentic BEIR evaluation +### Agentic retrieval evaluation -The pipeline CLI can evaluate a LanceDB corpus with agentic retrieval against a -BEIR-style dataset: +Agentic retrieval swaps the single dense-retrieval pass for an LLM-driven ReAct +loop: the agent issues several retrieval sub-queries, fuses the candidates with +reciprocal rank fusion, and selects a final ranking. You evaluate it the same way +you evaluate standard retrieval — by scoring the ranked results against ground +truth. + +`retriever pipeline run` ingests your corpus exactly as in +[Ingest a test corpus (CLI)](#ingest-a-test-corpus-cli), then scores agentic +retrieval against that ground truth. 
Add `--retrieval-mode agentic` and use +`--agentic-llm-model` to name the chat model that powers the agent. The simplest form scores +against your own query CSV (columns `query` and `golden_answer`), so no dataset +loader is needed: ```bash retriever pipeline run ./data \ --vdb-op lancedb \ - --vdb-kwargs-json '{"uri":"","table_name":""}' \ - --evaluation-mode beir \ + --vdb-kwargs-json '{"uri":"lancedb","table_name":"nemo-retriever"}' \ + --evaluation-mode recall \ --retrieval-mode agentic \ - --beir-loader vidore_hf \ - --beir-dataset-name vidore_v3_finance_en \ - --beir-doc-id-field pdf_basename \ - --agentic-llm-model nvidia/llama-3.3-nemotron-super-49b-v1.5 \ + --query-csv ./queries.csv \ + --recall-match-mode pdf_page \ + --agentic-llm-model nvidia/llama-3.3-nemotron-super-49b-v1.5 +``` + +`--recall-match-mode` is `pdf_page` or `pdf_only`, depending on whether +`golden_answer` names a page or a whole document. + +#### Optional extras + +- **Remote inference (no local GPU)** — drive the agent and embedder through NIM + endpoints instead of local models: + ```bash --agentic-invoke-url http:///v1/chat/completions \ - --embed-invoke-url http:///v1 \ - --agentic-reasoning-effort high \ - --agentic-num-concurrent 10 -``` - -Common BEIR options are `--beir-split`, `--beir-query-language`, and -`--beir-doc-id-field`. Use `--vdb-kwargs-json` to point evaluation at the -LanceDB URI and table for the indexed corpus. Agentic controls include -`--agentic-react-max-steps` (default `50`), `--agentic-backend-top-k` (default -`20`), and `--agentic-text-truncation` (`0` disables truncation), -`--agentic-reasoning-effort`, and `--agentic-num-concurrent`. Throughput with -high concurrency is bounded by the configured LLM endpoint. + --embed-invoke-url http:///v1 + ``` +- **BEIR-style datasets** — score against a registered benchmark instead of a + query CSV. 
HuggingFace-hosted sets (the `vidore_hf` loader) need the `datasets` + package, which the `benchmarks` extra provides: + ```bash + uv pip install "nemo-retriever[local,benchmarks]==26.05-RC1" + ``` + ```bash + --evaluation-mode beir \ + --beir-loader vidore_hf \ + --beir-dataset-name vidore_v3_finance_en \ + --beir-doc-id-field pdf_basename + ``` + Built-in loaders include `vidore_hf` (HuggingFace download) and + `financebench_json`. `--beir-split` and `--beir-query-language` select the + split and language. +- **Image + text corpora** — for page-image benchmarks, ingest rendered pages so + the agent retrieves over page images, matching the + [ViDoRe Harness Sweep](#vidore-harness-sweep): + ```bash + --embed-model-name nvidia/llama-nemotron-embed-vl-1b-v2 \ + --embed-modality text_image \ + --embed-granularity page \ + --extract-page-as-image \ + --extract-infographics + ``` +- **Tune the agent** — each flag controls a different stage of the loop: + - `--agentic-react-max-steps` (default `50`) — how many think → retrieve rounds + the agent may take per query before it has to answer. + - `--agentic-backend-top-k` (default `20`) — how many candidates each retrieval + call pulls from the vector DB (the pool the agent reasons over and fuses). + - `--agentic-text-truncation` (default `0`) — max characters of each candidate's + text shown to the agent; `0` sends the full text. + - `--agentic-reasoning-effort` (default `high`) — the OpenAI-compatible + reasoning depth (`low`/`medium`/`high`) requested per LLM call. + - `--agentic-num-concurrent` (default `1`) — how many queries are evaluated in + parallel, bounded by the LLM endpoint's throughput. +- **Logging** — per-query agent progress is logged at `INFO` by default (`--quiet` + suppresses it, `--debug` adds detail). There is no default log file — output + goes to the console; pass `--log-file ./run.log` to also write it to a file. 
+ Pass `--runtime-metrics-dir ./out` to write a JSON summary of the metrics and + timing alongside the run. 
### Generate a query answer using an LLM The above retrieval results are often feedable directly to an LLM for answer generation. diff --git a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py index 836b4616d8..1b1751eaaa 100644 --- a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py +++ b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py @@ -902,7 +902,7 @@ def _run_agentic_evaluation( ks=tuple(beir_k) if beir_k else (1, 3, 5, 10), ) evaluation_label = "Agentic BEIR" - elif evaluation_mode == "audio_recall": + elif evaluation_mode in ("recall", "audio_recall"): if query_csv is None: logger.warning("No query CSV configured; skipping agentic recall evaluation.") return "Agentic Recall", 0.0, {}, None, False @@ -1329,7 +1329,7 @@ def run( evaluation_mode: str = typer.Option( "none", "--evaluation-mode", - help="Post-ingest evaluation: none (default), audio_recall, beir, or qa.", + help="Post-ingest evaluation: none (default), recall, audio_recall, beir, or qa.", rich_help_panel=_PANEL_EVAL, ), retrieval_mode: str = typer.Option( @@ -1491,21 +1491,34 @@ def run( _reject_service_incompatible_flags(ctx) if audio_split_type not in {"size", "time", "frame"}: raise ValueError(f"Unsupported --audio-split-type: {audio_split_type!r}") - if evaluation_mode not in {"none", "audio_recall", "beir", "qa"}: + if evaluation_mode not in {"none", "audio_recall", "beir", "qa", "recall"}: raise ValueError(f"Unsupported --evaluation-mode: {evaluation_mode!r}") if retrieval_mode not in {"standard", "agentic"}: logger.warning("Unsupported --retrieval-mode=%r; falling back to 'standard'.", retrieval_mode) retrieval_mode = "standard" - if retrieval_mode == "agentic" and evaluation_mode not in {"audio_recall", "beir"}: + if retrieval_mode == "agentic" and evaluation_mode not in {"recall", "audio_recall", "beir"}: raise typer.BadParameter( - "--retrieval-mode=agentic is currently supported only with 
--evaluation-mode=audio_recall or " - "--evaluation-mode=beir." + "--retrieval-mode=agentic is currently supported only with --evaluation-mode=recall, " + "--evaluation-mode=audio_recall, or --evaluation-mode=beir." ) if evaluation_mode == "audio_recall": if input_type != "audio": raise ValueError("--evaluation-mode=audio_recall is only supported with --input-type=audio") if recall_match_mode != "audio_segment": raise ValueError("--evaluation-mode=audio_recall requires --recall-match-mode=audio_segment") + if evaluation_mode == "recall": + # Generic agentic recall over a query CSV (no per-dataset BEIR loader needed). + # Standard retrieval's recall path is audio-only, so this mode is agentic-only. + if retrieval_mode != "agentic": + raise typer.BadParameter( + "--evaluation-mode=recall is currently supported only with --retrieval-mode=agentic; " + "use --evaluation-mode=beir or audio_recall for standard retrieval." + ) + if recall_match_mode not in {"pdf_page", "pdf_only"}: + raise typer.BadParameter( + "--evaluation-mode=recall requires --recall-match-mode=pdf_page or pdf_only " + "(the query CSV's golden_answer column maps to PDF page/document IDs)." + ) if evaluation_mode == "qa" and eval_config is None: raise typer.BadParameter( "--evaluation-mode=qa requires --eval-config (QA sweep YAML/JSON). 
" diff --git a/nemo_retriever/tests/test_agentic_eval.py b/nemo_retriever/tests/test_agentic_eval.py index 304a43f550..2a9baa82fa 100644 --- a/nemo_retriever/tests/test_agentic_eval.py +++ b/nemo_retriever/tests/test_agentic_eval.py @@ -338,3 +338,113 @@ def test_pipeline_requires_agentic_llm_model(): assert result.exit_code != 0 assert "--retrieval-mode=agentic requires --agentic-llm-model" in result.output + + +def test_pipeline_agentic_recall_wires_query_csv(tmp_path): + from nemo_retriever.pipeline.__main__ import _run_agentic_evaluation + + captured = {} + + def fake_run_agentic_recall_evaluation(**kwargs): + captured.update(kwargs) + return ( + pd.DataFrame({"query_id": ["q1"]}), + pd.DataFrame({"query_id": ["q1"], "doc_id": ["doc"], "rank": [1]}), + {"q1": {"doc": 1}}, + {"q1": {"doc": 10.0}}, + {"recall@1": 1.0}, + ) + + query_csv = tmp_path / "queries.csv" + query_csv.write_text("query,golden_answer\nwhat is x,doc\n", encoding="utf-8") + + with ( + patch("nemo_retriever.model.resolve_embed_model", return_value="resolved-embed"), + patch( + "nemo_retriever.agentic.retrieval.run_agentic_recall_evaluation", + side_effect=fake_run_agentic_recall_evaluation, + ), + ): + label, _elapsed, metrics, query_count, ran = _run_agentic_evaluation( + evaluation_mode="recall", + vdb_op="lancedb", + vdb_kwargs={"uri": "db", "table_name": "tbl"}, + embed_model_name="embed", + embed_invoke_url="http://embed/v1", + embed_remote_api_key="embed-key", + embed_modality="text", + query_csv=query_csv, + recall_match_mode="pdf_page", + reranker=False, + reranker_model_name="reranker", + reranker_invoke_url=None, + reranker_api_key="", + local_reranker_backend="vllm", + local_hf_batch_size=4, + local_query_embed_backend="hf", + agentic_llm_model="llm", + agentic_invoke_url="http://llm/v1/chat/completions", + agentic_api_key="llm-key", + agentic_react_max_steps=50, + agentic_backend_top_k=20, + agentic_text_truncation=0, + agentic_reasoning_effort="high", + agentic_num_concurrent=1, + 
beir_loader=None, + beir_dataset_name=None, + beir_split="test", + beir_query_language=None, + beir_doc_id_field="pdf_basename", + beir_k=[1, 5, 10], + ) + + assert label == "Agentic Recall" + assert ran is True + assert metrics["recall@1"] == 1.0 + assert query_count == 1 + assert captured["query_csv"] == query_csv + assert captured["match_mode"] == "pdf_page" + + +def test_pipeline_recall_agentic_requires_pdf_match_mode(): + from nemo_retriever.pipeline.__main__ import app + + # Default --recall-match-mode is audio_segment, which is invalid for recall. + result = CliRunner().invoke( + app, + [ + ".", + "--evaluation-mode", + "recall", + "--retrieval-mode", + "agentic", + "--agentic-llm-model", + "test-model", + "--query-csv", + "queries.csv", + ], + ) + + assert result.exit_code != 0 + assert "--evaluation-mode=recall requires" in result.output + assert "pdf_only" in result.output + + +def test_pipeline_recall_requires_agentic(): + from nemo_retriever.pipeline.__main__ import app + + result = CliRunner().invoke( + app, + [ + ".", + "--evaluation-mode", + "recall", + "--retrieval-mode", + "standard", + "--recall-match-mode", + "pdf_page", + ], + ) + + assert result.exit_code != 0 + assert "--evaluation-mode=recall is currently supported only" in result.output diff --git a/nemo_retriever/tests/test_graph_pipeline_cli.py b/nemo_retriever/tests/test_graph_pipeline_cli.py index 8d59b721f8..cbef965dd9 100644 --- a/nemo_retriever/tests/test_graph_pipeline_cli.py +++ b/nemo_retriever/tests/test_graph_pipeline_cli.py @@ -425,16 +425,16 @@ def test_graph_pipeline_cli_allows_default_evaluation_for_pdf_inputs(tmp_path, m assert isinstance(fake_ingestor.file_patterns, list) -def test_graph_pipeline_cli_rejects_invalid_recall_mode(tmp_path) -> None: +def test_graph_pipeline_cli_rejects_invalid_evaluation_mode(tmp_path) -> None: dataset_dir = tmp_path / "dataset" dataset_dir.mkdir() (dataset_dir / "sample.pdf").write_text("placeholder", encoding="utf-8") - result = 
RUNNER.invoke(batch_pipeline.app, [str(dataset_dir), "--evaluation-mode", "recall"]) + result = RUNNER.invoke(batch_pipeline.app, [str(dataset_dir), "--evaluation-mode", "bogus"]) assert result.exit_code != 0 assert result.exception is not None - assert "Unsupported --evaluation-mode: 'recall'" in str(result.exception) + assert "Unsupported --evaluation-mode: 'bogus'" in str(result.exception) def test_graph_pipeline_cli_rejects_audio_recall_for_pdf_inputs(tmp_path) -> None: From 8c0af288361e6353c0ef8141627e285e45785c65 Mon Sep 17 00:00:00 2001 From: Mahika Wason Date: Tue, 9 Jun 2026 20:26:40 +0000 Subject: [PATCH 6/6] added review fixes Signed-off-by: Mahika Wason --- nemo_retriever/src/nemo_retriever/agentic/__init__.py | 2 ++ nemo_retriever/src/nemo_retriever/agentic/retrieval.py | 9 ++++++++- .../src/nemo_retriever/graph/react_agent_operator.py | 6 ++++-- .../src/nemo_retriever/graph/selection_agent_operator.py | 3 ++- 4 files changed, 16 insertions(+), 4 deletions(-) diff --git a/nemo_retriever/src/nemo_retriever/agentic/__init__.py b/nemo_retriever/src/nemo_retriever/agentic/__init__.py index f0536322bc..dd93570767 100644 --- a/nemo_retriever/src/nemo_retriever/agentic/__init__.py +++ b/nemo_retriever/src/nemo_retriever/agentic/__init__.py @@ -9,6 +9,7 @@ AgenticRetriever, build_beir_run_from_agentic_result, build_qrels, + run_agentic_beir_evaluation, run_agentic_recall_evaluation, ) @@ -17,5 +18,6 @@ "AgenticRetriever", "build_beir_run_from_agentic_result", "build_qrels", + "run_agentic_beir_evaluation", "run_agentic_recall_evaluation", ] diff --git a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py index bb36e59184..dfc025b6c0 100644 --- a/nemo_retriever/src/nemo_retriever/agentic/retrieval.py +++ b/nemo_retriever/src/nemo_retriever/agentic/retrieval.py @@ -244,7 +244,14 @@ def retrieve(self, query_ids: Sequence[str], query_texts: Sequence[str]) -> pd.D return 
_raw_hits_to_agentic_result([str(query_id) for query_id in query_ids], raw_hits) def _retrieve_for_agent(self, query_text: str, top_k: int) -> list[dict[str, Any]]: - """Retriever callback used by ``ReActAgentOperator``.""" + """Retriever callback used by ``ReActAgentOperator``. + + Retrieval is serialized across concurrent ReAct workers via ``self._lock`` + because the shared ``Retriever`` is not assumed thread-safe. This caps the + retrieval hop at single-threaded throughput; it is intentional and not the + bottleneck, since per-query cost is dominated by the multi-step LLM calls, + which still run concurrently under ``num_concurrent > 1``. + """ with self._lock: hits = self._retriever.query(str(query_text), top_k=int(top_k)) diff --git a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py index 785c20d357..1a04dad94f 100644 --- a/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/react_agent_operator.py @@ -629,7 +629,8 @@ def _run_single_query( finish_reason = choice.get("finish_reason") tool_calls = msg.get("tool_calls") or [] if msg.get("content"): - logger.info( + # Agent reasoning can quote document text/PII; keep content at DEBUG. + logger.debug( "ReActAgentOperator: query=%s step=%d assistant content=%r", query_id, _step, @@ -676,7 +677,8 @@ def _run_single_query( continue if fn_name == "think": - logger.info( + # Agent thoughts can quote document text/PII; keep content at DEBUG. 
+ logger.debug( "ReActAgentOperator: query=%s step=%d think=%r", query_id, _step, diff --git a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py index 6f53342388..a92974cfea 100644 --- a/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py +++ b/nemo_retriever/src/nemo_retriever/graph/selection_agent_operator.py @@ -545,7 +545,8 @@ def _select_documents( assistant_turn: Dict[str, Any] = {"role": "assistant"} if msg.get("content"): assistant_turn["content"] = msg["content"] - logger.info( + # Agent reasoning can quote document text/PII; keep content at DEBUG. + logger.debug( "SelectionAgentOperator: step=%d assistant content=%r", _step, _preview_text(msg.get("content")),