Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 30 additions & 13 deletions nemo_retriever/src/nemo_retriever/adapters/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,13 @@
run_ingest_workflow,
)
from nemo_retriever.adapters.cli.query_workflow import query_documents
from nemo_retriever.query.options import (
QueryEmbedOptions,
QueryRerankOptions,
QueryRequest,
QueryRetrievalOptions,
QueryStorageOptions,
)
from nemo_retriever.vdb.records import RetrievalHit
from nemo_retriever.version import get_version_info

Expand Down Expand Up @@ -783,19 +790,29 @@ def query_command(
try:
with _quiet_capture():
hits = query_documents(
query,
top_k=top_k,
candidate_k=candidate_k,
page_dedup=page_dedup,
content_types=content_types,
lancedb_uri=lancedb_uri,
table_name=table_name,
embed_invoke_url=embed_invoke_url,
embed_model_name=embed_model_name,
reranker_invoke_url=reranker_invoke_url,
reranker_model_name=reranker_model_name,
reranker_backend=reranker_backend,
rerank=rerank,
QueryRequest(
query=query,
retrieval=QueryRetrievalOptions(
top_k=top_k,
candidate_k=candidate_k,
page_dedup=page_dedup,
content_types=content_types,
),
embed=QueryEmbedOptions(
embed_invoke_url=embed_invoke_url,
embed_model_name=embed_model_name,
),
rerank=QueryRerankOptions(
enabled=rerank,
reranker_invoke_url=reranker_invoke_url,
reranker_model_name=reranker_model_name,
reranker_backend=reranker_backend,
),
storage=QueryStorageOptions(
lancedb_uri=lancedb_uri,
table_name=table_name,
),
)
)
except _ROOT_CLI_ERRORS as exc:
typer.echo(f"Error: {exc}", err=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,88 +4,11 @@

from __future__ import annotations

from typing import Any, Sequence

from nemo_retriever.params import build_embed_option_kwargs
from nemo_retriever.retriever import Retriever
from nemo_retriever.utils.remote_auth import resolve_remote_api_key
from nemo_retriever.query.options import QueryRequest
from nemo_retriever.query.workflow import query_documents as run_query_documents
from nemo_retriever.vdb.records import RetrievalHit

_LOCAL_VL_RERANK_MODEL = "nvidia/llama-nemotron-rerank-vl-1b-v2"


def _build_rerank_kwargs(
    reranker_invoke_url: str | None,
    reranker_model_name: str | None = None,
    reranker_backend: str | None = None,
) -> dict[str, str]:
    """Assemble the keyword arguments handed to the rerank stage.

    The shape mirrors :func:`build_embed_option_kwargs`: a non-blank
    ``reranker_invoke_url`` selects the remote NIM path, anything else selects
    the local GPU reranker.

    Args:
        reranker_invoke_url: Remote rerank endpoint. ``None`` or a
            whitespace-only string means "run the reranker locally".
        reranker_model_name: Optional model override. On the local path it
            falls back to ``_LOCAL_VL_RERANK_MODEL`` (the VL default that pairs
            with the local VL embedder); on the remote path it is only
            included when provided.
        reranker_backend: Local inference backend (``"vllm"`` or ``"hf"``).
            Only consulted on the local path; ``None`` defers to the library
            default in ``_default_rerank_actor_kwargs``.

    Returns:
        A dict with ``rerank_invoke_url`` (plus optional ``model_name`` and
        ``api_key``) for the remote path, or ``model_name`` (plus optional
        ``local_reranker_backend``) for the local path.
    """
    remote_url = reranker_invoke_url.strip() if reranker_invoke_url else ""

    if not remote_url:
        # Local GPU reranker - VL by default to pair with the local VL embedder.
        # ``NemotronRerankGPUActor`` loads the model once per actor; the rerank
        # model is ~2 GB and coexists with the vLLM embedder (which respects
        # ``gpu_memory_utilization=0.45``).
        local_kwargs: dict[str, str] = {"model_name": reranker_model_name or _LOCAL_VL_RERANK_MODEL}
        if reranker_backend:
            local_kwargs["local_reranker_backend"] = reranker_backend
        return local_kwargs

    remote_kwargs: dict[str, str] = {"rerank_invoke_url": remote_url}
    if reranker_model_name:
        remote_kwargs["model_name"] = reranker_model_name
    # The API key is attached only when one can be resolved.
    resolved_key = resolve_remote_api_key()
    if resolved_key is not None:
        remote_kwargs["api_key"] = resolved_key
    return remote_kwargs


def query_documents(
    query: str,
    *,
    top_k: int = 10,
    candidate_k: int | None = None,
    page_dedup: bool = False,
    content_types: str | Sequence[str] | None = None,
    lancedb_uri: str = "lancedb",
    table_name: str = "nemo-retriever",
    embed_invoke_url: str | None = None,
    embed_model_name: str | None = None,
    reranker_invoke_url: str | None = None,
    reranker_model_name: str | None = None,
    reranker_backend: str | None = None,
    rerank: bool = False,
) -> list[RetrievalHit]:
    """Run the minimal SDK query path used by the root CLI.

    A ``Retriever`` is constructed from the storage, embedding and (optional)
    rerank settings, then queried once with ``query``.

    Reranking is opt-in: pass ``rerank=True`` (or any of the rerank-related
    args via the CLI, which implicitly set ``rerank=True``) to enable. When
    ``rerank`` is false the ``reranker_*`` arguments are not consulted.

    Returns:
        Whatever ``Retriever.query`` returns for the given query.
    """
    embed_options = build_embed_option_kwargs(embed_invoke_url, embed_model_name)

    constructor_kwargs: dict[str, Any] = {
        "top_k": top_k,
        "vdb_kwargs": {"uri": lancedb_uri, "table_name": table_name},
    }
    # Optional sections are added only when they carry information so the
    # retriever's own defaults apply otherwise.
    if embed_options:
        constructor_kwargs["embed_kwargs"] = embed_options
    if rerank:
        rerank_options = _build_rerank_kwargs(reranker_invoke_url, reranker_model_name, reranker_backend)
        constructor_kwargs["rerank"] = True
        if rerank_options:
            constructor_kwargs["rerank_kwargs"] = rerank_options

    return Retriever(**constructor_kwargs).query(
        query,
        candidate_k=candidate_k,
        page_dedup=page_dedup,
        content_types=content_types,
    )
def query_documents(request: QueryRequest) -> list[RetrievalHit]:
    """Run the typed root query workflow.

    Thin CLI-side adapter: the request is passed unchanged to
    ``nemo_retriever.query.workflow.query_documents`` (imported here as
    ``run_query_documents``) and its result is returned as-is.
    """
    hits = run_query_documents(request)
    return hits
5 changes: 5 additions & 0 deletions nemo_retriever/src/nemo_retriever/query/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

"""Core query planning and execution package."""
45 changes: 45 additions & 0 deletions nemo_retriever/src/nemo_retriever/query/options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Sequence


@dataclass(frozen=True)
class QueryRetrievalOptions:
    """Immutable retrieval settings for a single query.

    The query workflow passes ``top_k`` to the ``Retriever`` constructor and
    forwards the remaining fields as keyword arguments to ``Retriever.query``.
    """

    # Passed to ``Retriever(top_k=...)``.
    top_k: int = 10
    # Forwarded to ``Retriever.query``; ``None`` leaves the choice to the retriever.
    candidate_k: int | None = None
    # Forwarded to ``Retriever.query`` unchanged.
    page_dedup: bool = False
    # Either one string or a sequence of strings; forwarded to
    # ``Retriever.query`` without any parsing at this layer.
    content_types: str | Sequence[str] | None = None


@dataclass(frozen=True)
class QueryEmbedOptions:
    """Immutable embedding settings for a single query.

    Both fields are handed to ``build_embed_option_kwargs`` by the query
    workflow; leaving them as ``None`` passes ``None`` through.
    """

    # Endpoint of the embedding service, if one should be used.
    embed_invoke_url: str | None = None
    # Name of the embedding model, if it should be overridden.
    embed_model_name: str | None = None


@dataclass(frozen=True)
class QueryRerankOptions:
    """Immutable rerank settings for a single query.

    Reranking is opt-in: the query workflow only reads the ``reranker_*``
    fields when ``enabled`` is true.
    """

    # Master switch; when false the remaining fields are not consulted.
    enabled: bool = False
    # A non-blank value selects the remote rerank endpoint; otherwise the
    # local reranker path is used.
    reranker_invoke_url: str | None = None
    # Optional model override for either path.
    reranker_model_name: str | None = None
    # Only read on the local path, where it is passed along as
    # ``local_reranker_backend``.
    reranker_backend: str | None = None


@dataclass(frozen=True)
class QueryStorageOptions:
    """Immutable vector-store location for a single query.

    The query workflow maps these onto the retriever's ``vdb_kwargs`` as
    ``uri`` and ``table_name`` respectively.
    """

    # Becomes ``vdb_kwargs["uri"]``.
    lancedb_uri: str = "lancedb"
    # Becomes ``vdb_kwargs["table_name"]``.
    table_name: str = "nemo-retriever"


@dataclass(frozen=True)
class QueryRequest:
    """Immutable description of one query against the retriever.

    Only ``query`` is required. Each option group defaults to a freshly
    constructed instance of its options class, so an omitted group carries
    that class's own defaults.
    """

    # The query text handed to ``Retriever.query``.
    query: str
    # Result-shaping options (top_k, candidate_k, page_dedup, content_types).
    retrieval: QueryRetrievalOptions = field(default_factory=QueryRetrievalOptions)
    # Embedding endpoint / model selection.
    embed: QueryEmbedOptions = field(default_factory=QueryEmbedOptions)
    # Rerank switch and rerank endpoint / model / backend selection.
    rerank: QueryRerankOptions = field(default_factory=QueryRerankOptions)
    # Vector-store URI and table name.
    storage: QueryStorageOptions = field(default_factory=QueryStorageOptions)
Comment thread
jioffe502 marked this conversation as resolved.
63 changes: 63 additions & 0 deletions nemo_retriever/src/nemo_retriever/query/workflow.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import Any

from nemo_retriever.params import build_embed_option_kwargs
from nemo_retriever.query.options import QueryRequest, QueryRerankOptions
from nemo_retriever.retriever import Retriever
from nemo_retriever.utils.remote_auth import resolve_remote_api_key
from nemo_retriever.vdb.records import RetrievalHit

_LOCAL_VL_RERANK_MODEL = "nvidia/llama-nemotron-rerank-vl-1b-v2"


def _build_rerank_kwargs(options: QueryRerankOptions) -> dict[str, str]:
    """Build kwargs for the rerank stage from typed rerank options.

    A non-blank ``options.reranker_invoke_url`` selects the remote path;
    otherwise the local reranker is configured.

    Args:
        options: Rerank settings. ``options.enabled`` is not inspected here;
            the caller decides whether reranking happens at all.

    Returns:
        Remote path: ``rerank_invoke_url`` (whitespace-stripped), plus
        ``model_name`` when a model name is set and ``api_key`` when
        ``resolve_remote_api_key()`` returns a value.
        Local path: ``model_name`` (falling back to
        ``_LOCAL_VL_RERANK_MODEL``), plus ``local_reranker_backend`` when a
        backend is set.
    """
    endpoint = options.reranker_invoke_url.strip() if options.reranker_invoke_url else ""

    if not endpoint:
        # Local path: the backend option only applies here.
        local_kwargs: dict[str, str] = {"model_name": options.reranker_model_name or _LOCAL_VL_RERANK_MODEL}
        if options.reranker_backend:
            local_kwargs["local_reranker_backend"] = options.reranker_backend
        return local_kwargs

    remote_kwargs: dict[str, str] = {"rerank_invoke_url": endpoint}
    if options.reranker_model_name:
        remote_kwargs["model_name"] = options.reranker_model_name
    # The API key is attached only when one can be resolved.
    resolved_key = resolve_remote_api_key()
    if resolved_key is not None:
        remote_kwargs["api_key"] = resolved_key
    return remote_kwargs


def _build_retriever_kwargs(request: QueryRequest) -> dict[str, Any]:
    """Translate a ``QueryRequest`` into ``Retriever`` constructor kwargs.

    ``top_k`` and ``vdb_kwargs`` are always present. ``embed_kwargs`` is added
    only when ``build_embed_option_kwargs`` yields a non-empty result, and
    ``rerank`` / ``rerank_kwargs`` only when ``request.rerank.enabled`` is set.
    """
    embed_options = build_embed_option_kwargs(request.embed.embed_invoke_url, request.embed.embed_model_name)

    storage = request.storage
    constructor_kwargs: dict[str, Any] = {
        "top_k": request.retrieval.top_k,
        "vdb_kwargs": {"uri": storage.lancedb_uri, "table_name": storage.table_name},
    }
    if embed_options:
        constructor_kwargs["embed_kwargs"] = embed_options

    # Reranking is opt-in; without it the rerank options are never read.
    if not request.rerank.enabled:
        return constructor_kwargs

    rerank_options = _build_rerank_kwargs(request.rerank)
    constructor_kwargs["rerank"] = True
    if rerank_options:
        constructor_kwargs["rerank_kwargs"] = rerank_options
    return constructor_kwargs


def query_documents(request: QueryRequest) -> list[RetrievalHit]:
    """Run the SDK query path used by the root CLI.

    Builds a ``Retriever`` from the request (see ``_build_retriever_kwargs``)
    and issues a single query, forwarding ``candidate_k``, ``page_dedup`` and
    ``content_types`` from ``request.retrieval`` as keyword arguments.

    Returns:
        Whatever ``Retriever.query`` returns for ``request.query``.
    """
    constructor_kwargs = _build_retriever_kwargs(request)
    retrieval = request.retrieval
    return Retriever(**constructor_kwargs).query(
        request.query,
        candidate_k=retrieval.candidate_k,
        page_dedup=retrieval.page_dedup,
        content_types=retrieval.content_types,
    )
124 changes: 124 additions & 0 deletions nemo_retriever/tests/test_query_workflow_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES.
# All rights reserved.
# SPDX-License-Identifier: Apache-2.0

from __future__ import annotations

from typing import Any

import nemo_retriever.query.workflow as query_workflow
from nemo_retriever.query.options import (
QueryEmbedOptions,
QueryRerankOptions,
QueryRequest,
QueryRetrievalOptions,
QueryStorageOptions,
)


def test_query_request_builds_retriever_kwargs_without_rerank(monkeypatch) -> None:
    """Without embed or rerank options, only ``top_k`` and ``vdb_kwargs`` reach ``Retriever``."""
    captured_init_kwargs: list[dict[str, Any]] = []

    class FakeRetriever:
        """Stand-in that records constructor kwargs and returns no hits."""

        def __init__(self, **kwargs: Any) -> None:
            captured_init_kwargs.append(kwargs)

        def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]:
            return []

    monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever)

    retrieval = QueryRetrievalOptions(top_k=3)
    storage = QueryStorageOptions(lancedb_uri="/tmp/lancedb", table_name="docs")
    request = QueryRequest(query="deployment?", retrieval=retrieval, storage=storage)

    hits = query_workflow.query_documents(request)

    expected_init_kwargs = {
        "top_k": 3,
        "vdb_kwargs": {"uri": "/tmp/lancedb", "table_name": "docs"},
    }
    assert hits == []
    assert captured_init_kwargs == [expected_init_kwargs]


def test_query_request_builds_retriever_kwargs_with_embed_and_remote_rerank(monkeypatch) -> None:
    """Embed options and an enabled remote reranker both show up in the ``Retriever`` kwargs."""
    captured_init_kwargs: list[dict[str, Any]] = []
    # The expected ``rerank_kwargs`` below include this key as ``api_key``.
    monkeypatch.setenv("NVIDIA_API_KEY", "nvapi-test")

    class FakeRetriever:
        """Stand-in that records constructor kwargs and returns no hits."""

        def __init__(self, **kwargs: Any) -> None:
            captured_init_kwargs.append(kwargs)

        def query(self, query: str, **_kwargs: Any) -> list[dict[str, Any]]:
            return []

    monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever)

    embed_options = QueryEmbedOptions(
        embed_invoke_url="http://embed:8000/v1/embeddings",
        embed_model_name="nvidia/llama-nemotron-embed-1b-v2",
    )
    rerank_options = QueryRerankOptions(
        enabled=True,
        reranker_invoke_url="http://rerank:8000/v1/ranking",
    )
    request = QueryRequest(query="deployment?", embed=embed_options, rerank=rerank_options)

    hits = query_workflow.query_documents(request)

    expected_embed_kwargs = {
        "embed_invoke_url": "http://embed:8000/v1/embeddings",
        "embedding_endpoint": "http://embed:8000/v1/embeddings",
        "model_name": "nvidia/llama-nemotron-embed-1b-v2",
        "embed_model_name": "nvidia/llama-nemotron-embed-1b-v2",
    }
    expected_rerank_kwargs = {
        "rerank_invoke_url": "http://rerank:8000/v1/ranking",
        "api_key": "nvapi-test",
    }
    expected_init_kwargs = {
        "top_k": 10,
        "vdb_kwargs": {"uri": "lancedb", "table_name": "nemo-retriever"},
        "embed_kwargs": expected_embed_kwargs,
        "rerank": True,
        "rerank_kwargs": expected_rerank_kwargs,
    }
    assert hits == []
    assert captured_init_kwargs == [expected_init_kwargs]


def test_query_documents_uses_typed_request(monkeypatch) -> None:
    """Retrieval options are split between the ``Retriever`` constructor and ``Retriever.query``."""
    captured_init_kwargs: list[dict[str, Any]] = []
    captured_queries: list[tuple[str, dict[str, Any]]] = []
    canned_hits = [{"text": "passage", "source": "doc.pdf", "page_number": 1}]

    class FakeRetriever:
        """Stand-in that records constructor and query arguments and returns one canned hit."""

        def __init__(self, **kwargs: Any) -> None:
            captured_init_kwargs.append(kwargs)

        def query(self, query: str, **kwargs: Any) -> list[dict[str, Any]]:
            captured_queries.append((query, kwargs))
            return [{"text": "passage", "source": "doc.pdf", "page_number": 1}]

    monkeypatch.setattr(query_workflow, "Retriever", FakeRetriever)

    retrieval = QueryRetrievalOptions(
        top_k=1,
        candidate_k=3,
        page_dedup=True,
        content_types="text,table",
    )
    request = QueryRequest(query="deployment?", retrieval=retrieval)

    hits = query_workflow.query_documents(request)

    assert hits == canned_hits
    # ``top_k`` goes to the constructor; the storage defaults fill ``vdb_kwargs``.
    assert captured_init_kwargs == [{"top_k": 1, "vdb_kwargs": {"uri": "lancedb", "table_name": "nemo-retriever"}}]
    # The remaining retrieval options travel with the query call itself.
    expected_query_kwargs = {
        "candidate_k": 3,
        "page_dedup": True,
        "content_types": "text,table",
    }
    assert captured_queries == [("deployment?", expected_query_kwargs)]
Loading
Loading