diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py index afda7c12f..1f61669d8 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py @@ -43,6 +43,13 @@ run_ingest_workflow, ) from nemo_retriever.adapters.cli.query_workflow import query_documents +from nemo_retriever.query.options import ( + QueryEmbedOptions, + QueryRerankOptions, + QueryRequest, + QueryRetrievalOptions, + QueryStorageOptions, +) from nemo_retriever.vdb.records import RetrievalHit from nemo_retriever.version import get_version_info @@ -783,19 +790,29 @@ def query_command( try: with _quiet_capture(): hits = query_documents( - query, - top_k=top_k, - candidate_k=candidate_k, - page_dedup=page_dedup, - content_types=content_types, - lancedb_uri=lancedb_uri, - table_name=table_name, - embed_invoke_url=embed_invoke_url, - embed_model_name=embed_model_name, - reranker_invoke_url=reranker_invoke_url, - reranker_model_name=reranker_model_name, - reranker_backend=reranker_backend, - rerank=rerank, + QueryRequest( + query=query, + retrieval=QueryRetrievalOptions( + top_k=top_k, + candidate_k=candidate_k, + page_dedup=page_dedup, + content_types=content_types, + ), + embed=QueryEmbedOptions( + embed_invoke_url=embed_invoke_url, + embed_model_name=embed_model_name, + ), + rerank=QueryRerankOptions( + enabled=rerank, + reranker_invoke_url=reranker_invoke_url, + reranker_model_name=reranker_model_name, + reranker_backend=reranker_backend, + ), + storage=QueryStorageOptions( + lancedb_uri=lancedb_uri, + table_name=table_name, + ), + ) ) except _ROOT_CLI_ERRORS as exc: typer.echo(f"Error: {exc}", err=True) diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/query_workflow.py b/nemo_retriever/src/nemo_retriever/adapters/cli/query_workflow.py index 121d51689..d9fea1140 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/query_workflow.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/query_workflow.py @@ -4,88 +4,11 @@ from __future__ import annotations -from typing import Any, Sequence - -from nemo_retriever.params import build_embed_option_kwargs -from nemo_retriever.retriever import Retriever -from nemo_retriever.utils.remote_auth import resolve_remote_api_key +from nemo_retriever.query.options import QueryRequest +from nemo_retriever.query.workflow import query_documents as run_query_documents from nemo_retriever.vdb.records import RetrievalHit -_LOCAL_VL_RERANK_MODEL = "nvidia/llama-nemotron-rerank-vl-1b-v2" - - -def _build_rerank_kwargs( - reranker_invoke_url: str | None, - reranker_model_name: str | None = None, - reranker_backend: str | None = None, -) -> dict[str, str]: - """Build kwargs for the rerank stage. Mirrors :func:`build_embed_option_kwargs`: - if ``reranker_invoke_url`` is given the remote NIM path is configured; - otherwise the local GPU reranker runs with ``reranker_model_name`` (or the - matching VL default to pair with the local VL embedder). - - ``reranker_backend`` only applies to the local path and selects the local - inference backend (``"vllm"`` or ``"hf"``); ``None`` defers to the library - default in ``_default_rerank_actor_kwargs``. - """ - reranker_url = (reranker_invoke_url or "").strip() - if reranker_url: - rerank_kwargs: dict[str, str] = {"rerank_invoke_url": reranker_url} - if reranker_model_name: - rerank_kwargs["model_name"] = reranker_model_name - api_key = resolve_remote_api_key() - if api_key is not None: - rerank_kwargs["api_key"] = api_key - return rerank_kwargs - - # Local GPU reranker - VL by default to pair with the local VL embedder. - # ``NemotronRerankGPUActor`` loads the model once per actor; the rerank - # model is ~2 GB and coexists with the vLLM embedder (which respects - # ``gpu_memory_utilization=0.45``). - local: dict[str, str] = {"model_name": reranker_model_name or _LOCAL_VL_RERANK_MODEL} - if reranker_backend: - local["local_reranker_backend"] = reranker_backend - return local - - -def query_documents( - query: str, - *, - top_k: int = 10, - candidate_k: int | None = None, - page_dedup: bool = False, - content_types: str | Sequence[str] | None = None, - lancedb_uri: str = "lancedb", - table_name: str = "nemo-retriever", - embed_invoke_url: str | None = None, - embed_model_name: str | None = None, - reranker_invoke_url: str | None = None, - reranker_model_name: str | None = None, - reranker_backend: str | None = None, - rerank: bool = False, -) -> list[RetrievalHit]: - """Run the minimal SDK query path used by the root CLI. - - Reranking is opt-in: pass ``rerank=True`` (or any of the rerank-related - args via the CLI, which implicitly set ``rerank=True``) to enable. - """ - embed_kwargs = build_embed_option_kwargs(embed_invoke_url, embed_model_name) - retriever_kwargs: dict[str, Any] = { - "top_k": top_k, - "vdb_kwargs": {"uri": lancedb_uri, "table_name": table_name}, - } - if embed_kwargs: - retriever_kwargs["embed_kwargs"] = embed_kwargs - if rerank: - rerank_kwargs = _build_rerank_kwargs(reranker_invoke_url, reranker_model_name, reranker_backend) - retriever_kwargs["rerank"] = True - if rerank_kwargs: - retriever_kwargs["rerank_kwargs"] = rerank_kwargs - retriever = Retriever(**retriever_kwargs) - return retriever.query( - query, - candidate_k=candidate_k, - page_dedup=page_dedup, - content_types=content_types, - ) +def query_documents(request: QueryRequest) -> list[RetrievalHit]: + """Run the typed root query workflow.""" + return run_query_documents(request) diff --git a/nemo_retriever/src/nemo_retriever/query/__init__.py b/nemo_retriever/src/nemo_retriever/query/__init__.py new file mode 100644 index 000000000..82fdb4d56 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/query/__init__.py @@ -0,0 +1,5 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Core query planning and execution package.""" diff --git a/nemo_retriever/src/nemo_retriever/query/options.py b/nemo_retriever/src/nemo_retriever/query/options.py new file mode 100644 index 000000000..331ac8237 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/query/options.py @@ -0,0 +1,45 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Sequence + + +@dataclass(frozen=True) +class QueryRetrievalOptions: + top_k: int = 10 + candidate_k: int | None = None + page_dedup: bool = False + content_types: str | Sequence[str] | None = None + + +@dataclass(frozen=True) +class QueryEmbedOptions: + embed_invoke_url: str | None = None + embed_model_name: str | None = None + + +@dataclass(frozen=True) +class QueryRerankOptions: + enabled: bool = False + reranker_invoke_url: str | None = None + reranker_model_name: str | None = None + reranker_backend: str | None = None + + +@dataclass(frozen=True) +class QueryStorageOptions: + lancedb_uri: str = "lancedb" + table_name: str = "nemo-retriever" + + +@dataclass(frozen=True) +class QueryRequest: + query: str + retrieval: QueryRetrievalOptions = field(default_factory=QueryRetrievalOptions) + embed: QueryEmbedOptions = field(default_factory=QueryEmbedOptions) + rerank: QueryRerankOptions = field(default_factory=QueryRerankOptions) + storage: QueryStorageOptions = field(default_factory=QueryStorageOptions) diff --git a/nemo_retriever/src/nemo_retriever/query/workflow.py b/nemo_retriever/src/nemo_retriever/query/workflow.py new file mode 100644 index 000000000..b54154435 --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/query/workflow.py @@ -0,0 +1,63 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +from nemo_retriever.params import build_embed_option_kwargs +from nemo_retriever.query.options import QueryRequest, QueryRerankOptions +from nemo_retriever.retriever import Retriever +from nemo_retriever.utils.remote_auth import resolve_remote_api_key +from nemo_retriever.vdb.records import RetrievalHit + +_LOCAL_VL_RERANK_MODEL = "nvidia/llama-nemotron-rerank-vl-1b-v2" + + +def _build_rerank_kwargs(options: QueryRerankOptions) -> dict[str, str]: + """Build kwargs for the rerank stage using the existing root query behavior.""" + reranker_url = (options.reranker_invoke_url or "").strip() + if reranker_url: + rerank_kwargs: dict[str, str] = {"rerank_invoke_url": reranker_url} + if options.reranker_model_name: + rerank_kwargs["model_name"] = options.reranker_model_name + api_key = resolve_remote_api_key() + if api_key is not None: + rerank_kwargs["api_key"] = api_key + return rerank_kwargs + + local: dict[str, str] = {"model_name": options.reranker_model_name or _LOCAL_VL_RERANK_MODEL} + if options.reranker_backend: + local["local_reranker_backend"] = options.reranker_backend + return local + + +def _build_retriever_kwargs(request: QueryRequest) -> dict[str, Any]: + embed_kwargs = build_embed_option_kwargs(request.embed.embed_invoke_url, request.embed.embed_model_name) + retriever_kwargs: dict[str, Any] = { + "top_k": request.retrieval.top_k, + "vdb_kwargs": { + "uri": request.storage.lancedb_uri, + "table_name": request.storage.table_name, + }, + } + if embed_kwargs: + retriever_kwargs["embed_kwargs"] = embed_kwargs + if request.rerank.enabled: + rerank_kwargs = _build_rerank_kwargs(request.rerank) + retriever_kwargs["rerank"] = True + if rerank_kwargs: + retriever_kwargs["rerank_kwargs"] = rerank_kwargs + return retriever_kwargs + + +def query_documents(request: QueryRequest) -> list[RetrievalHit]: + """Run the SDK query path used by the root CLI.""" + retriever = Retriever(**_build_retriever_kwargs(request)) + return retriever.query( + request.query, + candidate_k=request.retrieval.candidate_k, + page_dedup=request.retrieval.page_dedup, + content_types=request.retrieval.content_types, + ) diff --git a/nemo_retriever/tests/test_query_workflow_options.py b/nemo_retriever/tests/test_query_workflow_options.py new file mode 100644 index 000000000..db1ef2a75 --- /dev/null +++ b/nemo_retriever/tests/test_query_workflow_options.py @@ -0,0 +1,124 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +from typing import Any + +import nemo_retriever.query.workflow as query_workflow +from nemo_retriever.query.options import ( + QueryEmbedOptions, + QueryRerankOptions, + QueryRequest, + QueryRetrievalOptions, + QueryStorageOptions, +) + + +def test_query_request_builds_retriever_kwargs_without_rerank(monkeypatch) -> None: + retriever_calls: list[dict[str, Any]] = [] + + class FakeRetriever: + def __init__(self, **kwargs: Any) -> None: + retriever_calls.append(kwargs) + + def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: + return [] + + monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + request = QueryRequest( + query="deployment?", + retrieval=QueryRetrievalOptions(top_k=3), + storage=QueryStorageOptions(lancedb_uri="/tmp/lancedb", table_name="docs"), + ) + + assert query_workflow.query_documents(request) == [] + assert retriever_calls == [ + { + "top_k": 3, + "vdb_kwargs": {"uri": "/tmp/lancedb", "table_name": "docs"}, + } + ] + + +def test_query_request_builds_retriever_kwargs_with_embed_and_remote_rerank(monkeypatch) -> None: + retriever_calls: list[dict[str, Any]] = [] + monkeypatch.setenv("NVIDIA_API_KEY", "nvapi-test") + + class FakeRetriever: + def __init__(self, **kwargs: Any) -> None: + retriever_calls.append(kwargs) + + def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: + return [] + + monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + request = QueryRequest( + query="deployment?", + embed=QueryEmbedOptions( + embed_invoke_url="http://embed:8000/v1/embeddings", + embed_model_name="nvidia/llama-nemotron-embed-1b-v2", + ), + rerank=QueryRerankOptions( + enabled=True, + reranker_invoke_url="http://rerank:8000/v1/ranking", + ), + ) + + assert query_workflow.query_documents(request) == [] + assert retriever_calls == [ + { + "top_k": 10, + "vdb_kwargs": {"uri": "lancedb", "table_name": "nemo-retriever"}, + "embed_kwargs": { + "embed_invoke_url": "http://embed:8000/v1/embeddings", + "embedding_endpoint": "http://embed:8000/v1/embeddings", + "model_name": "nvidia/llama-nemotron-embed-1b-v2", + "embed_model_name": "nvidia/llama-nemotron-embed-1b-v2", + }, + "rerank": True, + "rerank_kwargs": { + "rerank_invoke_url": "http://rerank:8000/v1/ranking", + "api_key": "nvapi-test", + }, + } + ] + + +def test_query_documents_uses_typed_request(monkeypatch) -> None: + retriever_calls: list[dict[str, Any]] = [] + query_calls: list[tuple[str, dict[str, Any]]] = [] + + class FakeRetriever: + def __init__(self, **kwargs: Any) -> None: + retriever_calls.append(kwargs) + + def query(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: + query_calls.append((query, kwargs)) + return [{"text": "passage", "source": "doc.pdf", "page_number": 1}] + + monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + + request = QueryRequest( + query="deployment?", + retrieval=QueryRetrievalOptions( + top_k=1, + candidate_k=3, + page_dedup=True, + content_types="text,table", + ), + ) + + assert query_workflow.query_documents(request) == [{"text": "passage", "source": "doc.pdf", "page_number": 1}] + assert retriever_calls == [{"top_k": 1, "vdb_kwargs": {"uri": "lancedb", "table_name": "nemo-retriever"}}] + assert query_calls == [ + ( + "deployment?", + { + "candidate_k": 3, + "page_dedup": True, + "content_types": "text,table", + }, + ) + ] diff --git a/nemo_retriever/tests/test_root_query_cli.py b/nemo_retriever/tests/test_root_query_cli.py index 1fb316e52..943b6dc67 100644 --- a/nemo_retriever/tests/test_root_query_cli.py +++ b/nemo_retriever/tests/test_root_query_cli.py @@ -10,7 +10,7 @@ from typer.testing import CliRunner -import nemo_retriever.adapters.cli.query_workflow as query_workflow +import nemo_retriever.query.workflow as query_core RUNNER = CliRunner() @@ -49,7 +49,7 @@ def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: query_calls.append(query) return hits - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke( cli_main.app, @@ -88,7 +88,7 @@ def query(self, query: str, **kwargs: Any) -> list[dict[str, Any]]: {"text": "text row", "metadata": {"type": "text"}, "page_number": 1, "source": "doc.pdf"}, ] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke( cli_main.app, @@ -124,7 +124,7 @@ def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: query_calls.append(query) return [] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke( cli_main.app, @@ -169,7 +169,7 @@ def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: query_calls.append(query) return [] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke( cli_main.app, @@ -208,7 +208,7 @@ def __init__(self, **kwargs: Any) -> None: def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: return [] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke(cli_main.app, ["query", "hello", "--rerank"]) @@ -234,7 +234,7 @@ def __init__(self, **kwargs: Any) -> None: def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: return [] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke(cli_main.app, ["query", "hello"]) @@ -256,7 +256,7 @@ def __init__(self, **kwargs: Any) -> None: def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]: return [] - monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever) + monkeypatch.setattr(query_core, "Retriever", FakeRetriever) result = RUNNER.invoke( cli_main.app,