From 137c57a8ec370ba3fef62f8290d3b43bee768dbf Mon Sep 17 00:00:00 2001 From: jioffe502 Date: Fri, 5 Jun 2026 20:08:25 +0000 Subject: [PATCH] Move service ingest construction into ingest core --- .../adapters/cli/ingest_workflow.py | 35 ++ .../src/nemo_retriever/adapters/cli/main.py | 388 ++++++++---- .../src/nemo_retriever/ingest/service.py | 564 ++++++++++++++++++ .../src/nemo_retriever/pipeline/__main__.py | 185 +----- nemo_retriever/tests/test_pipeline_helpers.py | 27 +- .../tests/test_root_cli_workflow.py | 146 +++++ 6 files changed, 1060 insertions(+), 285 deletions(-) create mode 100644 nemo_retriever/src/nemo_retriever/ingest/service.py diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/ingest_workflow.py b/nemo_retriever/src/nemo_retriever/adapters/cli/ingest_workflow.py index 4fe4e471a..147153984 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/ingest_workflow.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/ingest_workflow.py @@ -4,10 +4,16 @@ from __future__ import annotations +from dataclasses import asdict from typing import Any from nemo_retriever.ingest.execution import execute_ingest_plan from nemo_retriever.ingest.plan import ResolvedIngestPlan +from nemo_retriever.ingest.service import ( + ServiceIngestRequest, + execute_service_ingest_request, + service_split_config_for_request, +) from nemo_retriever.ingest_manifest import format_branch_summary _DRY_RUN_SECRET_FIELD_PATTERNS = ("api_key", "password", "secret", "credential", "bearer") @@ -88,6 +94,23 @@ def _ingest_plan_to_dry_run_data(plan: ResolvedIngestPlan) -> dict[str, Any]: } +def service_ingest_request_to_dry_run_data(request: ServiceIngestRequest) -> dict[str, Any]: + """Return the JSON payload printed by ``retriever ingest --run-mode service --dry-run``.""" + return { + "dry_run": True, + "run_mode": "service", + "documents": list(request.documents), + "input_type": request.input_type, + "service": _strip_secret_values(asdict(request.connection)), + "extract": _params_to_dry_run_dict(request.extract_params), + "split_config": _params_to_dry_run_dict(service_split_config_for_request(request)), + "dedup": _params_to_dry_run_dict(request.dedup_params), + "caption": _params_to_dry_run_dict(request.caption_params), + "embed": _params_to_dry_run_dict(request.embed_params), + "store": _params_to_dry_run_dict(request.store_params), + } + + def run_ingest_workflow( plan: ResolvedIngestPlan, *, @@ -98,3 +121,15 @@ def run_ingest_workflow( return _ingest_plan_to_dry_run_data(plan) return execute_ingest_plan(plan).to_summary_dict() + + +def run_service_ingest_workflow( + request: ServiceIngestRequest, + *, + dry_run: bool = False, +) -> dict[str, Any]: + """Apply root ingest workflow policy to an already-resolved service request.""" + if dry_run: + return service_ingest_request_to_dry_run_data(request) + + return execute_service_ingest_request(request).to_summary_dict() diff --git a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py index afda7c12f..5cc62cd27 100644 --- a/nemo_retriever/src/nemo_retriever/adapters/cli/main.py +++ b/nemo_retriever/src/nemo_retriever/adapters/cli/main.py @@ -30,7 +30,6 @@ IngestPlanRequest, IngestProfileValue, IngestRuntimeOptions, - IngestRunModeValue, IngestSourceOptions, IngestStorageOptions, LocalIngestEmbedBackendValue, @@ -39,8 +38,22 @@ TableOutputFormatValue, resolve_ingest_plan, ) +from nemo_retriever.params import IngestorRunMode +from nemo_retriever.ingest.service import ( + ServiceIngestCaptionOptions, + ServiceIngestChunkOptions, + ServiceIngestConnectionOptions, + ServiceIngestDedupOptions, + ServiceIngestEmbedOptions, + ServiceIngestExtractOptions, + ServiceIngestImageStoreOptions, + ServiceIngestPlanRequest, + ServiceIngestSourceOptions, + resolve_service_ingest_request, +) from nemo_retriever.adapters.cli.ingest_workflow import ( run_ingest_workflow, + run_service_ingest_workflow, ) from nemo_retriever.adapters.cli.query_workflow import query_documents from nemo_retriever.vdb.records import RetrievalHit @@ -82,7 +95,77 @@ except Exception: logger.debug("Skipping '%s' sub-command (import failed)", _name) -_ROOT_CLI_ERRORS = (OSError, RuntimeError, ValueError, ValidationError) +_ROOT_CLI_ERRORS = (OSError, RuntimeError, ValueError, ValidationError, typer.BadParameter) + + +_ROOT_SERVICE_INCOMPATIBLE_FLAGS: tuple[tuple[str, str], ...] = ( + ("--lancedb-uri", "lancedb_uri"), + ("--table-name", "table_name"), + ("--overwrite/--append", "overwrite"), + ("--api-key", "api_key"), + ("--page-elements-invoke-url", "page_elements_invoke_url"), + ("--ocr-invoke-url", "ocr_invoke_url"), + ("--ocr-lang", "ocr_lang"), + ("--graphic-elements-invoke-url", "graphic_elements_invoke_url"), + ("--table-structure-invoke-url", "table_structure_invoke_url"), + ("--caption-invoke-url", "caption_invoke_url"), + ("--caption-model-name", "caption_model_name"), + ("--embed-invoke-url", "embed_invoke_url"), + ("--embed-model-name", "embed_model_name"), + ("--local-ingest-embed-backend", "local_ingest_embed_backend"), + ("--ray-address", "ray_address"), + ("--ray-log-to-driver/--no-ray-log-to-driver", "ray_log_to_driver"), + ("--pdf-split-batch-size", "pdf_split_batch_size"), + ("--pdf-extract-workers", "pdf_extract_workers"), + ("--pdf-extract-batch-size", "pdf_extract_batch_size"), + ("--pdf-extract-cpus-per-task", "pdf_extract_cpus_per_task"), + ("--page-elements-workers", "page_elements_workers"), + ("--page-elements-batch-size", "page_elements_batch_size"), + ("--page-elements-cpus-per-actor", "page_elements_cpus_per_actor"), + ("--page-elements-gpus-per-actor", "page_elements_gpus_per_actor"), + ("--ocr-workers", "ocr_workers"), + ("--ocr-batch-size", "ocr_batch_size"), + ("--ocr-cpus-per-actor", "ocr_cpus_per_actor"), + ("--ocr-gpus-per-actor", "ocr_gpus_per_actor"), + ("--table-structure-workers", "table_structure_workers"), + ("--table-structure-batch-size", "table_structure_batch_size"), + ("--table-structure-cpus-per-actor", "table_structure_cpus_per_actor"), + ("--table-structure-gpus-per-actor", "table_structure_gpus_per_actor"), + ("--nemotron-parse-workers", "nemotron_parse_workers"), + ("--nemotron-parse-batch-size", "nemotron_parse_batch_size"), + ("--nemotron-parse-gpus-per-actor", "nemotron_parse_gpus_per_actor"), + ("--embed-workers", "embed_workers"), + ("--embed-batch-size", "embed_batch_size"), + ("--embed-cpus-per-actor", "embed_cpus_per_actor"), + ("--embed-gpus-per-actor", "embed_gpus_per_actor"), + ("--segment-audio/--no-segment-audio", "segment_audio"), + ("--audio-split-type", "audio_split_type"), + ("--audio-split-interval", "audio_split_interval"), + ("--video-extract-audio/--no-video-extract-audio", "video_extract_audio"), + ("--video-extract-frames/--no-video-extract-frames", "video_extract_frames"), + ("--video-frame-fps", "video_frame_fps"), + ("--video-frame-dedup/--no-video-frame-dedup", "video_frame_dedup"), + ("--video-frame-text-dedup/--no-video-frame-text-dedup", "video_frame_text_dedup"), + ("--video-frame-text-dedup-max-dropped-frames", "video_frame_text_dedup_max_dropped_frames"), + ("--video-av-fuse/--no-video-av-fuse", "video_av_fuse"), +) + + +def _reject_root_service_incompatible_flags(ctx: typer.Context) -> None: + user_set: list[str] = [] + for cli_flag, param_name in _ROOT_SERVICE_INCOMPATIBLE_FLAGS: + source = ctx.get_parameter_source(param_name) + if getattr(source, "name", None) in {"COMMANDLINE", "ENVIRONMENT"}: + user_set.append(cli_flag) + if not user_set: + return + raise typer.BadParameter( + "--run-mode=service delegates pipeline configuration to the " + "retriever service; the following flag(s) cannot be set on the " + "client and would be silently dropped: " + ", ".join(user_set) + ". " + "Remove them, or use --run-mode batch/inprocess to apply them locally. " + "Server-side pipeline configuration lives in retriever-service.yaml." + ) def _query_cli_hit(hit: RetrievalHit) -> dict[str, object]: @@ -169,6 +252,7 @@ def main() -> None: @app.command("ingest") def ingest_command( + ctx: typer.Context, documents: list[str] = typer.Argument( ..., help="One or more files, directories, or globs. Supported file types are detected automatically.", @@ -180,10 +264,33 @@ def ingest_command( ), lancedb_uri: str = typer.Option("lancedb", "--lancedb-uri", help="LanceDB database URI."), table_name: str = typer.Option("nemo-retriever", "--table-name", help="LanceDB table name."), - run_mode: IngestRunModeValue = typer.Option( + run_mode: IngestorRunMode = typer.Option( "inprocess", "--run-mode", - help="Execution mode for the SDK ingestor. Defaults to inprocess; use batch for Ray Data scale-out.", + help=( + "Execution mode for ingest: inprocess (default), batch for Ray Data scale-out, " + "or service for a remote retriever service." + ), + ), + service_url: str = typer.Option( + "http://localhost:7670", + "--service-url", + help="Base URL of the retriever service (used only when --run-mode=service).", + ), + service_concurrency: int = typer.Option( + 8, + "--service-concurrency", + min=1, + help="Maximum concurrent document uploads to the service (used only when --run-mode=service).", + ), + service_api_token: str | None = typer.Option( + None, + "--service-api-token", + envvar="NEMO_RETRIEVER_API_TOKEN", + help=( + "Bearer token for authenticating with the retriever service " + "(used only when --run-mode=service). Falls back to $NEMO_RETRIEVER_API_TOKEN." + ), ), dry_run: bool = typer.Option( False, @@ -589,113 +696,166 @@ def ingest_command( capture = _quiet_capture() if quiet else contextlib.nullcontext() try: with capture: - ingest_plan = resolve_ingest_plan( - IngestPlanRequest( - source=IngestSourceOptions(documents=documents, profile=profile), - runtime=IngestRuntimeOptions( - run_mode=run_mode, - ray_address=ray_address, - ray_log_to_driver=ray_log_to_driver, - ), - extract=IngestExtractOptions( - method=method, - dpi=dpi, - extract_text=extract_text, - extract_images=extract_images, - extract_tables=extract_tables, - extract_charts=extract_charts, - extract_infographics=extract_infographics, - extract_page_as_image=extract_page_as_image, - use_page_elements=use_page_elements, - use_graphic_elements=use_graphic_elements, - use_table_structure=use_table_structure, - page_elements_invoke_url=page_elements_invoke_url, - ocr_invoke_url=ocr_invoke_url, - ocr_version=ocr_version, - ocr_lang=ocr_lang, - graphic_elements_invoke_url=graphic_elements_invoke_url, - table_structure_invoke_url=table_structure_invoke_url, - table_output_format=table_output_format, - extract_api_key=api_key, - batch=IngestExtractBatchOptions( - pdf_split_batch_size=pdf_split_batch_size, - pdf_extract_workers=pdf_extract_workers, - pdf_extract_batch_size=pdf_extract_batch_size, - pdf_extract_cpus_per_task=pdf_extract_cpus_per_task, - page_elements_workers=page_elements_workers, - page_elements_batch_size=page_elements_batch_size, - page_elements_cpus_per_actor=page_elements_cpus_per_actor, - page_elements_gpus_per_actor=page_elements_gpus_per_actor, - ocr_workers=ocr_workers, - ocr_batch_size=ocr_batch_size, - ocr_cpus_per_actor=ocr_cpus_per_actor, - ocr_gpus_per_actor=ocr_gpus_per_actor, - table_structure_workers=table_structure_workers, - table_structure_batch_size=table_structure_batch_size, - table_structure_cpus_per_actor=table_structure_cpus_per_actor, - table_structure_gpus_per_actor=table_structure_gpus_per_actor, - nemotron_parse_workers=nemotron_parse_workers, - nemotron_parse_batch_size=nemotron_parse_batch_size, - nemotron_parse_gpus_per_actor=nemotron_parse_gpus_per_actor, + if run_mode == "service": + _reject_root_service_incompatible_flags(ctx) + service_request = resolve_service_ingest_request( + ServiceIngestPlanRequest( + source=ServiceIngestSourceOptions(documents=documents, profile=profile), + connection=ServiceIngestConnectionOptions( + service_url=service_url, + service_concurrency=service_concurrency, + service_api_token=service_api_token, + ), + extract=ServiceIngestExtractOptions( + method=method, + dpi=dpi, + extract_text=extract_text, + extract_images=extract_images, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, + extract_page_as_image=extract_page_as_image, + use_page_elements=use_page_elements, + use_graphic_elements=use_graphic_elements, + use_table_structure=use_table_structure, + table_output_format=table_output_format, + ocr_version=ocr_version, + ), + dedup=ServiceIngestDedupOptions( + enabled=dedup, + iou_threshold=dedup_iou_threshold, + ), + caption=ServiceIngestCaptionOptions( + enabled=caption, + context_text_max_chars=caption_context_text_max_chars, + caption_infographics=caption_infographics, ), - ), - media=IngestMediaOptions( - segment_audio=segment_audio, - audio_split_type=audio_split_type, - audio_split_interval=audio_split_interval, - video_extract_audio=video_extract_audio, - video_extract_frames=video_extract_frames, - video_frame_fps=video_frame_fps, - video_frame_dedup=video_frame_dedup, - video_frame_text_dedup=video_frame_text_dedup, - video_frame_text_dedup_max_dropped_frames=video_frame_text_dedup_max_dropped_frames, - video_av_fuse=video_av_fuse, - ), - caption=IngestCaptionOptions( - enabled=caption, - caption_invoke_url=caption_invoke_url, - caption_api_key=api_key, - caption_model_name=caption_model_name, - caption_context_text_max_chars=caption_context_text_max_chars, - caption_infographics=caption_infographics, - ), - dedup=IngestDedupOptions( - enabled=dedup, - iou_threshold=dedup_iou_threshold, - ), - chunk=IngestChunkOptions( - enabled=text_chunk, - text_chunk_max_tokens=text_chunk_max_tokens, - text_chunk_overlap_tokens=text_chunk_overlap_tokens, - ), - embed=IngestEmbedOptions( - embed_invoke_url=embed_invoke_url, - embed_model_name=embed_model_name, - local_ingest_embed_backend=local_ingest_embed_backend, - embed_api_key=api_key, - embed_modality=embed_modality, - embed_granularity=embed_granularity, - text_elements_modality=text_elements_modality, - structured_elements_modality=structured_elements_modality, - batch=IngestEmbedBatchOptions( - embed_workers=embed_workers, - embed_batch_size=embed_batch_size, - embed_cpus_per_actor=embed_cpus_per_actor, - embed_gpus_per_actor=embed_gpus_per_actor, + chunk=ServiceIngestChunkOptions( + enabled=text_chunk, + text_chunk_max_tokens=text_chunk_max_tokens, + text_chunk_overlap_tokens=text_chunk_overlap_tokens, ), - ), - image_store=IngestImageStoreOptions(images_uri=store_images_uri), - storage=IngestStorageOptions( - lancedb_uri=lancedb_uri, - table_name=table_name, - overwrite=overwrite, - ), + embed=ServiceIngestEmbedOptions( + embed_modality=embed_modality, + text_elements_modality=text_elements_modality, + structured_elements_modality=structured_elements_modality, + embed_granularity=embed_granularity, + ), + image_store=ServiceIngestImageStoreOptions(images_uri=store_images_uri), + ) + ) + summary = run_service_ingest_workflow( + service_request, + dry_run=dry_run, + ) + else: + ingest_plan = resolve_ingest_plan( + IngestPlanRequest( + source=IngestSourceOptions(documents=documents, profile=profile), + runtime=IngestRuntimeOptions( + run_mode=run_mode, + ray_address=ray_address, + ray_log_to_driver=ray_log_to_driver, + ), + extract=IngestExtractOptions( + method=method, + dpi=dpi, + extract_text=extract_text, + extract_images=extract_images, + extract_tables=extract_tables, + extract_charts=extract_charts, + extract_infographics=extract_infographics, + extract_page_as_image=extract_page_as_image, + use_page_elements=use_page_elements, + use_graphic_elements=use_graphic_elements, + use_table_structure=use_table_structure, + page_elements_invoke_url=page_elements_invoke_url, + ocr_invoke_url=ocr_invoke_url, + ocr_version=ocr_version, + ocr_lang=ocr_lang, + graphic_elements_invoke_url=graphic_elements_invoke_url, + table_structure_invoke_url=table_structure_invoke_url, + table_output_format=table_output_format, + extract_api_key=api_key, + batch=IngestExtractBatchOptions( + pdf_split_batch_size=pdf_split_batch_size, + pdf_extract_workers=pdf_extract_workers, + pdf_extract_batch_size=pdf_extract_batch_size, + pdf_extract_cpus_per_task=pdf_extract_cpus_per_task, + page_elements_workers=page_elements_workers, + page_elements_batch_size=page_elements_batch_size, + page_elements_cpus_per_actor=page_elements_cpus_per_actor, + page_elements_gpus_per_actor=page_elements_gpus_per_actor, + ocr_workers=ocr_workers, + ocr_batch_size=ocr_batch_size, + ocr_cpus_per_actor=ocr_cpus_per_actor, + ocr_gpus_per_actor=ocr_gpus_per_actor, + table_structure_workers=table_structure_workers, + table_structure_batch_size=table_structure_batch_size, + table_structure_cpus_per_actor=table_structure_cpus_per_actor, + table_structure_gpus_per_actor=table_structure_gpus_per_actor, + nemotron_parse_workers=nemotron_parse_workers, + nemotron_parse_batch_size=nemotron_parse_batch_size, + nemotron_parse_gpus_per_actor=nemotron_parse_gpus_per_actor, + ), + ), + media=IngestMediaOptions( + segment_audio=segment_audio, + audio_split_type=audio_split_type, + audio_split_interval=audio_split_interval, + video_extract_audio=video_extract_audio, + video_extract_frames=video_extract_frames, + video_frame_fps=video_frame_fps, + video_frame_dedup=video_frame_dedup, + video_frame_text_dedup=video_frame_text_dedup, + video_frame_text_dedup_max_dropped_frames=video_frame_text_dedup_max_dropped_frames, + video_av_fuse=video_av_fuse, + ), + caption=IngestCaptionOptions( + enabled=caption, + caption_invoke_url=caption_invoke_url, + caption_api_key=api_key, + caption_model_name=caption_model_name, + caption_context_text_max_chars=caption_context_text_max_chars, + caption_infographics=caption_infographics, + ), + dedup=IngestDedupOptions( + enabled=dedup, + iou_threshold=dedup_iou_threshold, + ), + chunk=IngestChunkOptions( + enabled=text_chunk, + text_chunk_max_tokens=text_chunk_max_tokens, + text_chunk_overlap_tokens=text_chunk_overlap_tokens, + ), + embed=IngestEmbedOptions( + embed_invoke_url=embed_invoke_url, + embed_model_name=embed_model_name, + local_ingest_embed_backend=local_ingest_embed_backend, + embed_api_key=api_key, + embed_modality=embed_modality, + embed_granularity=embed_granularity, + text_elements_modality=text_elements_modality, + structured_elements_modality=structured_elements_modality, + batch=IngestEmbedBatchOptions( + embed_workers=embed_workers, + embed_batch_size=embed_batch_size, + embed_cpus_per_actor=embed_cpus_per_actor, + embed_gpus_per_actor=embed_gpus_per_actor, + ), + ), + image_store=IngestImageStoreOptions(images_uri=store_images_uri), + storage=IngestStorageOptions( + lancedb_uri=lancedb_uri, + table_name=table_name, + overwrite=overwrite, + ), + ) + ) + summary = run_ingest_workflow( + ingest_plan, + dry_run=dry_run, ) - ) - summary = run_ingest_workflow( - ingest_plan, - dry_run=dry_run, - ) except _ROOT_CLI_ERRORS as exc: typer.echo(f"Error: {exc}", err=True) raise typer.Exit(1) from exc @@ -704,6 +864,18 @@ def ingest_command( typer.echo(json.dumps(summary, indent=2, sort_keys=True, default=str)) return + if summary.get("run_mode") == "service": + n_files = len(summary["documents"]) + service_target = summary["service_url"] + n_rows = summary.get("n_rows") + if n_rows is None: + typer.echo( + f"Ingested {n_files} file(s) through retriever service {service_target} " "(row count unavailable)." + ) + else: + typer.echo(f"Ingested {n_files} file(s) → {n_rows} row(s) through retriever service {service_target}.") + return + # Report input-file count alongside the actual landed-row count from the # LanceDB table — they diverge whenever one document explodes into multiple # chunks (PDFs → page elements, video → audio_visual segments) or diff --git a/nemo_retriever/src/nemo_retriever/ingest/service.py b/nemo_retriever/src/nemo_retriever/ingest/service.py new file mode 100644 index 000000000..7eaf642fd --- /dev/null +++ b/nemo_retriever/src/nemo_retriever/ingest/service.py @@ -0,0 +1,564 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024-26, NVIDIA CORPORATION & AFFILIATES. +# All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from __future__ import annotations + +import glob as _glob +import logging +from dataclasses import dataclass, field +from pathlib import Path +from typing import Any, Literal, Sequence, cast + +from nemo_retriever.params import CaptionParams, DedupParams, EmbedParams, ExtractParams, StoreParams, TextChunkParams +from nemo_retriever.utils.input_files import ( + AUTO_INPUT_EXTENSIONS, + INPUT_TYPE_EXTENSIONS, + expand_input_file_patterns, + input_type_for_path, + resolve_input_files, +) + +logger = logging.getLogger(__name__) + +ServiceIngestInputTypeValue = Literal["auto", "pdf", "doc", "txt", "html", "image", "audio", "video"] +ServiceIngestProfileValue = Literal["auto", "fast-text"] +_SUPPORTED_SERVICE_INPUT_TYPES: tuple[ServiceIngestInputTypeValue, ...] = ( + "auto", + "pdf", + "doc", + "txt", + "html", + "image", + "audio", + "video", +) +_SUPPORTED_SERVICE_PROFILES: tuple[ServiceIngestProfileValue, ...] = ("auto", "fast-text") +_DEFAULT_TEXT_CHUNK_MAX_TOKENS = 1024 +_DEFAULT_TEXT_CHUNK_OVERLAP_TOKENS = 150 + + +@dataclass(frozen=True) +class ServiceIngestSourceOptions: + documents: Sequence[str] + profile: ServiceIngestProfileValue = "auto" + input_type: ServiceIngestInputTypeValue = "auto" + + +@dataclass(frozen=True) +class ServiceIngestConnectionOptions: + service_url: str = "http://localhost:7670" + service_concurrency: int = 8 + service_api_token: str | None = None + + +@dataclass(frozen=True) +class ServiceIngestExtractOptions: + method: str | None = None + dpi: int | None = None + extract_text: bool | None = None + extract_images: bool | None = None + extract_tables: bool | None = None + extract_charts: bool | None = None + extract_infographics: bool | None = None + extract_page_as_image: bool | None = None + use_page_elements: bool | None = None + use_graphic_elements: bool | None = None + use_table_structure: bool | None = None + table_output_format: str | None = None + ocr_version: str | None = None + + +@dataclass(frozen=True) +class ServiceIngestDedupOptions: + enabled: bool = False + iou_threshold: float | None = None + + +@dataclass(frozen=True) +class ServiceIngestCaptionOptions: + enabled: bool = False + context_text_max_chars: int | None = None + caption_infographics: bool | None = None + + +@dataclass(frozen=True) +class ServiceIngestChunkOptions: + enabled: bool = False + text_chunk_max_tokens: int | None = None + text_chunk_overlap_tokens: int | None = None + + +@dataclass(frozen=True) +class ServiceIngestEmbedOptions: + embed_modality: str | None = None + text_elements_modality: str | None = None + structured_elements_modality: str | None = None + embed_granularity: str | None = None + + +@dataclass(frozen=True) +class ServiceIngestImageStoreOptions: + images_uri: str | None = None + + +@dataclass(frozen=True) +class ServiceIngestPlanRequest: + source: ServiceIngestSourceOptions + connection: ServiceIngestConnectionOptions = field(default_factory=ServiceIngestConnectionOptions) + extract: ServiceIngestExtractOptions = field(default_factory=ServiceIngestExtractOptions) + dedup: ServiceIngestDedupOptions = field(default_factory=ServiceIngestDedupOptions) + caption: ServiceIngestCaptionOptions = field(default_factory=ServiceIngestCaptionOptions) + chunk: ServiceIngestChunkOptions = field(default_factory=ServiceIngestChunkOptions) + embed: ServiceIngestEmbedOptions = field(default_factory=ServiceIngestEmbedOptions) + image_store: ServiceIngestImageStoreOptions = field(default_factory=ServiceIngestImageStoreOptions) + + +@dataclass(frozen=True) +class ServiceIngestRequest: + documents: Sequence[str] + input_type: ServiceIngestInputTypeValue + extract_params: ExtractParams = field(default_factory=ExtractParams) + embed_params: EmbedParams = field(default_factory=EmbedParams) + text_chunk_params: TextChunkParams = field(default_factory=TextChunkParams) + enable_text_chunk: bool = False + dedup_params: DedupParams | None = None + caption_params: CaptionParams | None = None + store_params: StoreParams | None = None + connection: ServiceIngestConnectionOptions = field(default_factory=ServiceIngestConnectionOptions) + + +@dataclass(frozen=True) +class ServiceIngestExecutionResult: + """Structured result from executing a resolved service ingest request. + + Service mode does not locally verify the remote vector database after + ingest. ``result_n_rows`` counts rows from the service ingest result when + available, and ``n_rows`` mirrors that value so root CLI summaries keep the + same top-level row-count contract as local ingest results. + """ + + request: ServiceIngestRequest + result: object + n_rows: int | None + result_n_rows: int | None + metadata: dict[str, Any] + + @property + def documents(self) -> list[str]: + return list(self.request.documents) + + @property + def service_url(self) -> str: + return self.request.connection.service_url + + def to_summary_dict(self) -> dict[str, Any]: + return { + "run_mode": "service", + "documents": self.documents, + "service_url": self.service_url, + "result": self.result, + "n_rows": self.n_rows, + "result_n_rows": self.result_n_rows, + } + + +def resolve_service_ingest_request(request: ServiceIngestPlanRequest) -> ServiceIngestRequest: + """Resolve first-class root service ingest options into service params.""" + + source = request.source + input_type = _validate_service_input_type(source.input_type) + profile = _validate_service_profile(source.profile) + documents = resolve_service_documents(source.documents, input_type=input_type) + _validate_service_profile_documents(profile, documents) + + extract_kwargs = _service_profile_extract_defaults(profile) + extract_kwargs.update( + { + key: value + for key, value in { + "method": request.extract.method, + "dpi": request.extract.dpi, + "extract_text": request.extract.extract_text, + "extract_images": request.extract.extract_images, + "extract_tables": request.extract.extract_tables, + "extract_charts": request.extract.extract_charts, + "extract_infographics": request.extract.extract_infographics, + "extract_page_as_image": request.extract.extract_page_as_image, + "use_page_elements": request.extract.use_page_elements, + "use_graphic_elements": request.extract.use_graphic_elements, + "use_table_structure": request.extract.use_table_structure, + "table_output_format": request.extract.table_output_format, + "ocr_version": request.extract.ocr_version, + }.items() + if value is not None + } + ) + if request.extract.table_output_format == "markdown": + extract_kwargs["use_table_structure"] = True + + embed_kwargs = { + key: value + for key, value in { + "embed_modality": request.embed.embed_modality, + "text_elements_modality": request.embed.text_elements_modality, + "structured_elements_modality": request.embed.structured_elements_modality, + "embed_granularity": request.embed.embed_granularity, + }.items() + if value is not None + } + enable_text_chunk, text_chunk_params = _build_service_text_chunk_params(request.chunk) + + return ServiceIngestRequest( + documents=documents, + input_type=input_type, + extract_params=ExtractParams(**extract_kwargs), + embed_params=EmbedParams(**embed_kwargs) if embed_kwargs else EmbedParams(), + text_chunk_params=text_chunk_params, + enable_text_chunk=enable_text_chunk, + dedup_params=_build_service_dedup_params(request.dedup), + caption_params=_build_service_caption_params(request.caption), + store_params=_build_service_store_params(request.image_store), + connection=request.connection, + ) + + +def build_service_ingestor(request: ServiceIngestRequest) -> Any: + """Construct a remote-service ingestor with service-compatible stages.""" + + from nemo_retriever.service_ingestor import ServiceIngestor + + resolved_files = expand_service_file_patterns(request.documents) + if not resolved_files: + raise ValueError("No files matched the input patterns for service mode.") + + ingestor = ServiceIngestor( + base_url=request.connection.service_url, + max_concurrency=request.connection.service_concurrency, + api_token=request.connection.service_api_token, + ).files(resolved_files) + + ingestor = _attach_service_extract_stage( + ingestor, + input_type=request.input_type, + documents=resolved_files, + extract_params=request.extract_params, + enable_text_chunk=request.enable_text_chunk, + text_chunk_params=request.text_chunk_params, + ) + + if request.dedup_params is not None: + ingestor = ingestor.dedup(request.dedup_params) + + if request.caption_params is not None: + ingestor = ingestor.caption(_sanitize_service_caption_params(request.caption_params)) + + ingestor = ingestor.embed(request.embed_params) + + if request.store_params is not None: + ingestor = ingestor.store(request.store_params) + + return ingestor + + +def execute_service_ingest_request(request: ServiceIngestRequest) -> ServiceIngestExecutionResult: + """Execute a service ingest request and return its structured result.""" + + result = build_service_ingestor(request).ingest() + result_n_rows = _count_service_result_rows(result) + return ServiceIngestExecutionResult( + request=request, + result=result, + n_rows=result_n_rows, + result_n_rows=result_n_rows, + metadata={ + "service_url": request.connection.service_url, + "input_type": request.input_type, + }, + ) + + +def resolve_service_documents( + documents: Sequence[str], + *, + input_type: ServiceIngestInputTypeValue = "auto", +) -> list[str]: + """Expand root service ingest files/directories/globs without creating an ingestor.""" + + inputs: list[str] = [] + for document in documents: + raw_document = str(document) + path = Path(raw_document).expanduser() + if path.is_dir(): + directory_files = resolve_input_files(path, input_type) + if not directory_files: + raise FileNotFoundError(f"No supported ingest files found under directory: {path}") + inputs.extend(str(file) for file in directory_files) + else: + inputs.append(raw_document) + + document_list = expand_input_file_patterns(inputs) + _validate_service_document_types(document_list, input_type=input_type) + return document_list + + +def expand_service_file_patterns(documents: Sequence[str]) -> list[str]: + """Expand recursive file patterns for service ingest construction.""" + + resolved_files: list[str] = [] + for pattern in documents: + resolved_files.extend(sorted(_glob.glob(str(pattern), recursive=True))) + return resolved_files + + +def service_split_config_for_request(request: ServiceIngestRequest) -> dict[str, Any] | None: + """Build the service split configuration for a resolved ingest request.""" + + chunk_dict = _service_text_chunk_dict(request.text_chunk_params) if request.enable_text_chunk else None + return _split_config_for_input_type(request.input_type, chunk_dict, documents=request.documents) + + +def _validate_service_input_type(input_type: str) -> ServiceIngestInputTypeValue: + if input_type not in _SUPPORTED_SERVICE_INPUT_TYPES: + raise ValueError(f"input_type must be one of {', '.join(_SUPPORTED_SERVICE_INPUT_TYPES)}, got {input_type!r}.") + return cast(ServiceIngestInputTypeValue, input_type) + + +def _validate_service_profile(profile: str) -> ServiceIngestProfileValue: + if profile not in _SUPPORTED_SERVICE_PROFILES: + raise ValueError(f"profile must be one of {', '.join(_SUPPORTED_SERVICE_PROFILES)}, got {profile!r}.") + return cast(ServiceIngestProfileValue, profile) + + +def _validate_service_document_types( + documents: Sequence[str], + *, + input_type: ServiceIngestInputTypeValue, +) -> None: + allowed_extensions = AUTO_INPUT_EXTENSIONS if input_type == "auto" else INPUT_TYPE_EXTENSIONS[input_type] + unsupported = [ + document + for document in documents + if not _glob.has_magic(str(document)) and Path(document).suffix.lower() not in allowed_extensions + ] + if unsupported: + examples = ", ".join(unsupported[:3]) + raise ValueError(f"Unsupported input file type(s) for retriever ingest: {examples}") + + +def _validate_service_profile_documents(profile: ServiceIngestProfileValue, documents: Sequence[str]) -> None: + if profile != "fast-text": + return + disallowed = sorted( + { + family + for document in documents + if not _glob.has_magic(str(document)) + for family in [input_type_for_path(document)] + if family not in {"pdf", "doc"} + } + ) + if disallowed: + observed = ", ".join(disallowed) + raise ValueError(f"--profile {profile} only supports PDF/document inputs; observed {observed}.") + + +def _service_profile_extract_defaults(profile: ServiceIngestProfileValue) -> dict[str, Any]: + if profile == "fast-text": + return { + "method": "pdfium", + "extract_text": True, + "extract_images": False, + "extract_tables": False, + "extract_charts": False, + "extract_infographics": False, + "extract_page_as_image": False, + "use_page_elements": False, + } + return {} + + +def _build_service_text_chunk_params(chunk: ServiceIngestChunkOptions) -> tuple[bool, TextChunkParams]: + enabled = ( + bool(chunk.enabled) or chunk.text_chunk_max_tokens is not None or chunk.text_chunk_overlap_tokens is not None + ) + if not enabled: + return False, TextChunkParams() + return True, TextChunkParams( + max_tokens=( + int(chunk.text_chunk_max_tokens) + if chunk.text_chunk_max_tokens is not None + else _DEFAULT_TEXT_CHUNK_MAX_TOKENS + ), + overlap_tokens=( + int(chunk.text_chunk_overlap_tokens) + if chunk.text_chunk_overlap_tokens is not None + else _DEFAULT_TEXT_CHUNK_OVERLAP_TOKENS + ), + ) + + +def _build_service_dedup_params(dedup: ServiceIngestDedupOptions) -> DedupParams | None: + if not dedup.enabled: + if dedup.iou_threshold is not None: + raise ValueError("Dedup options require --dedup: dedup_iou_threshold.") + return None + dedup_kwargs = {} + if dedup.iou_threshold is not None: + dedup_kwargs["iou_threshold"] = dedup.iou_threshold + return DedupParams(**dedup_kwargs) + + +def _build_service_caption_params(caption: ServiceIngestCaptionOptions) -> CaptionParams | None: + overrides = { + "caption_context_text_max_chars": caption.context_text_max_chars, + "caption_infographics": caption.caption_infographics, + } + if not caption.enabled: + provided = [name for name, value in overrides.items() if value is not None] + if provided: + raise ValueError(f"Caption options require --caption: {', '.join(provided)}.") + return None + if caption.context_text_max_chars is not None and caption.context_text_max_chars < 0: + raise ValueError("caption_context_text_max_chars must be >= 0.") + + caption_kwargs = { + key: value + for key, value in { + "context_text_max_chars": caption.context_text_max_chars, + "caption_infographics": caption.caption_infographics, + }.items() + if value is not None + } + return CaptionParams(**caption_kwargs) + + +def _build_service_store_params(image_store: ServiceIngestImageStoreOptions) -> StoreParams | None: + if image_store.images_uri is None: + return None + return StoreParams(storage_uri=image_store.images_uri) + + +def _service_extraction_mode(input_type: str) -> str: + """Map ingest input type to :class:`PipelineSpec` ``extraction_mode``.""" + + return { + "pdf": "pdf", + "doc": "pdf", + "txt": "text", + "html": "html", + "audio": "audio", + "video": "auto", + }.get(input_type, "auto") + + +def _service_text_chunk_dict(text_chunk_params: TextChunkParams) -> dict[str, Any]: + """Serialize text-chunk knobs allowed by the service split_config policy.""" + + from nemo_retriever.service.policy import _DEFAULT_ALLOWED_SPLIT_KEYS + + raw = text_chunk_params.model_dump(exclude_none=True) + return {key: value for key, value in raw.items() if key in _DEFAULT_ALLOWED_SPLIT_KEYS} + + +def _attach_service_extract_stage( + ingestor: Any, + *, + input_type: str, + documents: Sequence[str], + extract_params: ExtractParams, + enable_text_chunk: bool, + text_chunk_params: TextChunkParams, +) -> Any: + """Wire the extraction stage for the remote service ingestor.""" + + chunk_dict = _service_text_chunk_dict(text_chunk_params) if enable_text_chunk else None + if input_type == "image": + return ingestor.extract_image_files( + extract_params, + split_config={"image": chunk_dict} if chunk_dict else None, + ) + return ingestor.extract( + extract_params, + split_config=_split_config_for_input_type(input_type, chunk_dict, documents=documents), + extraction_mode=_service_extraction_mode(input_type), + ) + + +def _split_config_for_input_type( + input_type: str, + chunk_dict: dict[str, Any] | None, + *, + documents: Sequence[str] = (), +) -> dict[str, Any] | None: + if chunk_dict is None: + return None + if input_type == "auto": + return _split_config_for_auto_documents(documents, chunk_dict) + if input_type in {"pdf", "doc"}: + return {"pdf": chunk_dict} + if input_type == "txt": + return {"text": chunk_dict} + if input_type == "html": + return {"html": chunk_dict} + if input_type == "image": + return {"image": chunk_dict} + if input_type == "audio": + return {"audio": chunk_dict} + if input_type == "video": + return {"video": chunk_dict, "audio": chunk_dict} + return None + + +def _split_config_for_auto_documents( + documents: Sequence[str], + chunk_dict: dict[str, Any], +) -> dict[str, Any] | None: + input_types = {input_type_for_path(document) for document in documents if not _glob.has_magic(str(document))} + split_config: dict[str, Any] = {} + if input_types & {"pdf", "doc"}: + split_config["pdf"] = dict(chunk_dict) + if "txt" in input_types: + split_config["text"] = dict(chunk_dict) + if "html" in input_types: + split_config["html"] = dict(chunk_dict) + if "image" in input_types: + split_config["image"] = dict(chunk_dict) + if "audio" in input_types: + split_config["audio"] = dict(chunk_dict) + if "video" in input_types: + split_config["video"] = dict(chunk_dict) + split_config["audio"] = dict(chunk_dict) + return split_config or None + + +def _sanitize_service_caption_params(caption_params: CaptionParams) -> CaptionParams: + params_dict = caption_params.model_dump(exclude_none=True, exclude_unset=True) + return CaptionParams( + **{ + key: value + for key, value in params_dict.items() + if key + in { + "prompt", + "system_prompt", + "batch_size", + "context_text_max_chars", + "caption_infographics", + "temperature", + "max_tokens", + "top_p", + "top_k", + } + } + ) + + +def _count_service_result_rows(result: object) -> int | None: + dataframe = getattr(result, "dataframe", None) + if dataframe is None: + return None + try: + return len(dataframe) + except TypeError: + return None diff --git a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py index f0e4c8113..f6f0b9f5b 100644 --- a/nemo_retriever/src/nemo_retriever/pipeline/__main__.py +++ b/nemo_retriever/src/nemo_retriever/pipeline/__main__.py @@ -58,6 +58,7 @@ import pandas as pd import typer +from nemo_retriever.ingest import service as ingest_service from nemo_retriever.ingest.execution import execute_ingest_plan from nemo_retriever.ingest.plan import ( IngestCaptionOptions, @@ -121,7 +122,7 @@ # bound to local execution (Ray actors, GPU placement), or never wired # through the service ingestor (VDB upload is handled server-side; audio # and video extract paths still run locally). Flags wired into the -# service ``PipelineSpec`` by ``_build_service_ingestor`` — extract knobs, embed +# service ``PipelineSpec`` by ``ingest.service.build_service_ingestor`` — extract knobs, embed # granularity / modality, dedup threshold, caption behaviour, text chunk # config, ``--store-images-uri`` — are intentionally NOT in this list and # pass through to ``ServiceIngestor``; the server's @@ -191,16 +192,6 @@ def _reject_service_incompatible_flags(ctx: typer.Context) -> None: - """Raise ``typer.BadParameter`` if any ingest-only flag was user-supplied. - - Only flags whose click parameter source is ``COMMANDLINE`` or - ``ENVIRONMENT`` are treated as user-supplied — flags carrying their - declared default do not trigger the error. - """ - # Compare by enum *name*, not identity: depending on the environment, - # typer may return a source from its vendored ``typer._click.core`` enum - # rather than ``click.core.ParameterSource``, and the two enums are - # distinct objects whose members never compare equal via ``in``. user_set: list[str] = [] for cli_flag, param_name in _SERVICE_INCOMPATIBLE_FLAGS: source = ctx.get_parameter_source(param_name) @@ -496,69 +487,6 @@ def _build_embed_params( ) -def _service_extraction_mode(input_type: str) -> str: - """Map CLI ``--input-type`` to :class:`PipelineSpec` ``extraction_mode``.""" - return { - "pdf": "pdf", - "doc": "pdf", - "txt": "text", - "html": "html", - "audio": "audio", - "video": "auto", - }.get(input_type, "auto") - - -def _service_text_chunk_dict(text_chunk_params: TextChunkParams) -> dict[str, Any]: - """Serialize text-chunk knobs allowed by the service split_config policy.""" - from nemo_retriever.service.policy import _DEFAULT_ALLOWED_SPLIT_KEYS - - raw = text_chunk_params.model_dump(exclude_none=True) - return {key: value for key, value in raw.items() if key in _DEFAULT_ALLOWED_SPLIT_KEYS} - - -def _attach_service_extract_stage( - ingestor: Any, - *, - input_type: str, - extract_params: ExtractParams, - enable_text_chunk: bool, - text_chunk_params: TextChunkParams, -) -> Any: - """Wire the extraction stage for the remote service ingestor.""" - chunk_dict = _service_text_chunk_dict(text_chunk_params) if enable_text_chunk else None - if input_type == "image": - return ingestor.extract_image_files( - extract_params, - split_config={"image": chunk_dict} if chunk_dict else None, - ) - return ingestor.extract( - extract_params, - split_config=_split_config_for_input_type(input_type, chunk_dict), - extraction_mode=_service_extraction_mode(input_type), - ) - - -def _split_config_for_input_type( - input_type: str, - chunk_dict: Optional[dict[str, Any]], -) -> Optional[dict[str, Any]]: - if chunk_dict is None: - return None - if input_type in {"pdf", "doc"}: - return {"pdf": chunk_dict} - if input_type == "txt": - return {"text": chunk_dict} - if input_type == "html": - return {"html": chunk_dict} - if input_type == "image": - return {"image": chunk_dict} - if input_type == "audio": - return {"audio": chunk_dict} - if input_type == "video": - return {"video": chunk_dict, "audio": chunk_dict} - return None - - def _parse_vdb_kwargs_json(vdb_kwargs_json: Optional[str]) -> dict[str, Any]: """Parse opaque nv-ingest-client VDB constructor kwargs from CLI JSON.""" if vdb_kwargs_json: @@ -572,76 +500,6 @@ def _parse_vdb_kwargs_json(vdb_kwargs_json: Optional[str]) -> dict[str, Any]: return {} -def _build_service_ingestor( - *, - file_patterns: list[str], - input_type: str, - extract_params: ExtractParams, - embed_params: EmbedParams, - text_chunk_params: TextChunkParams, - enable_text_chunk: bool, - enable_dedup: bool, - enable_caption: bool, - dedup_iou_threshold: float, - caption_invoke_url: Optional[str], - caption_context_text_max_chars: int, - caption_temperature: float, - caption_top_p: Optional[float], - caption_max_tokens: int, - store_images_uri: Optional[str], - service_url: str = "http://localhost:7670", - service_concurrency: int = 8, - service_api_token: Optional[str] = None, -) -> Any: - """Construct a remote-service ingestor with service-compatible stages.""" - from nemo_retriever.service_ingestor import ServiceIngestor - - resolved_files: list[str] = [] - for pattern in file_patterns: - resolved_files.extend(sorted(_glob.glob(pattern, recursive=True))) - if not resolved_files: - raise typer.BadParameter("No files matched the input patterns for service mode.") - - ingestor = ServiceIngestor( - base_url=service_url, - max_concurrency=service_concurrency, - api_token=service_api_token, - ).files(resolved_files) - - ingestor = _attach_service_extract_stage( - ingestor, - input_type=input_type, - extract_params=extract_params, - enable_text_chunk=enable_text_chunk, - text_chunk_params=text_chunk_params, - ) - - if enable_dedup: - ingestor = ingestor.dedup(DedupParams(iou_threshold=dedup_iou_threshold)) - - if enable_caption: - if caption_invoke_url is not None: - logger.warning( - "Ignoring --caption-invoke-url in service mode; the retriever service " - "uses its operator-configured caption endpoint." - ) - ingestor = ingestor.caption( - CaptionParams( - context_text_max_chars=caption_context_text_max_chars, - temperature=caption_temperature, - top_p=caption_top_p, - max_tokens=caption_max_tokens, - ) - ) - - ingestor = ingestor.embed(embed_params) - - if store_images_uri is not None: - ingestor = ingestor.store(StoreParams(storage_uri=store_images_uri)) - - return ingestor - - def _collect_results(run_mode: str, result: Any) -> tuple[list[dict[str, Any]], Any, float, int]: """Materialize the graph result into a list of records + DataFrame. @@ -1427,7 +1285,7 @@ def run( pipeline_vdb_upload = VdbUploadParams(vdb_op=resolved_vdb_op, vdb_kwargs=resolved_vdb_kwargs) logger.info("Building graph pipeline (run_mode=%s) for %s ...", run_mode, input_path) - ingestor = None + service_request = None ingest_plan = None local_execute_kwargs: dict[str, Any] = {} if run_mode == "service": @@ -1481,25 +1339,30 @@ def run( embed_gpus_per_actor=embed_gpus_per_actor, local_ingest_embed_backend=local_ingest_embed_backend, ) - ingestor = _build_service_ingestor( - file_patterns=file_patterns, + service_request = ingest_service.ServiceIngestRequest( + documents=file_patterns, input_type=input_type, extract_params=extract_params, embed_params=embed_params, text_chunk_params=text_chunk_params, enable_text_chunk=enable_text_chunk, - enable_dedup=enable_dedup, - enable_caption=enable_caption, - dedup_iou_threshold=dedup_iou_threshold, - caption_invoke_url=caption_invoke_url, - caption_context_text_max_chars=caption_context_text_max_chars, - caption_temperature=caption_temperature, - caption_top_p=caption_top_p, - caption_max_tokens=caption_max_tokens, - store_images_uri=store_images_uri, - service_url=service_url, - service_concurrency=service_concurrency, - service_api_token=service_api_token, + dedup_params=DedupParams(iou_threshold=dedup_iou_threshold) if enable_dedup else None, + caption_params=( + CaptionParams( + context_text_max_chars=caption_context_text_max_chars, + temperature=caption_temperature, + top_p=caption_top_p, + max_tokens=caption_max_tokens, + ) + if enable_caption + else None + ), + store_params=StoreParams(storage_uri=store_images_uri) if store_images_uri is not None else None, + connection=ingest_service.ServiceIngestConnectionOptions( + service_url=service_url, + service_concurrency=service_concurrency, + service_api_token=service_api_token, + ), ) else: if store_actors and store_images_uri is None: @@ -1638,7 +1501,9 @@ def run( def _run_ingest() -> Any: if run_mode == "service": - return ingestor.ingest() + if service_request is None: + raise RuntimeError("service_request must be resolved before execution in service mode") + return ingest_service.execute_service_ingest_request(service_request).result if ingest_plan is None: raise RuntimeError("ingest_plan must be resolved before execution in non-service mode") return execute_ingest_plan(ingest_plan, **local_execute_kwargs).result diff --git a/nemo_retriever/tests/test_pipeline_helpers.py b/nemo_retriever/tests/test_pipeline_helpers.py index c8a097492..a24d4118e 100644 --- a/nemo_retriever/tests/test_pipeline_helpers.py +++ b/nemo_retriever/tests/test_pipeline_helpers.py @@ -11,10 +11,10 @@ import typer import nemo_retriever.pipeline as pipeline_pkg +from nemo_retriever.ingest.service import ServiceIngestRequest, build_service_ingestor from nemo_retriever.params import EmbedParams, ExtractParams, TextChunkParams from nemo_retriever.pipeline.__main__ import ( _build_embed_params, - _build_service_ingestor, _collect_results, _count_input_units, _count_uploadable_vdb_records, @@ -62,22 +62,15 @@ def test_build_service_ingestor_wires_extract_embed_and_chunking(tmp_path: Path) pdf = tmp_path / "doc.pdf" pdf.write_bytes(b"%PDF-1.4") - ingestor = _build_service_ingestor( - file_patterns=[str(pdf)], - input_type="pdf", - extract_params=ExtractParams(method="ocr", extract_text=False, dpi=300), - embed_params=EmbedParams(embed_granularity="page"), - text_chunk_params=TextChunkParams(max_tokens=64, overlap_tokens=8), - enable_text_chunk=True, - enable_dedup=False, - enable_caption=False, - dedup_iou_threshold=0.8, - caption_invoke_url=None, - caption_context_text_max_chars=0, - caption_temperature=1.0, - caption_top_p=None, - caption_max_tokens=1024, - store_images_uri=None, + ingestor = build_service_ingestor( + ServiceIngestRequest( + documents=[str(pdf)], + input_type="pdf", + extract_params=ExtractParams(method="ocr", extract_text=False, dpi=300), + embed_params=EmbedParams(embed_granularity="page"), + text_chunk_params=TextChunkParams(max_tokens=64, overlap_tokens=8), + enable_text_chunk=True, + ) ) assert isinstance(ingestor, ServiceIngestor) diff --git a/nemo_retriever/tests/test_root_cli_workflow.py b/nemo_retriever/tests/test_root_cli_workflow.py index 4e62be16d..23fac84ab 100644 --- a/nemo_retriever/tests/test_root_cli_workflow.py +++ b/nemo_retriever/tests/test_root_cli_workflow.py @@ -101,6 +101,152 @@ def fake_create_ingestor(**kwargs: Any) -> Any: assert "Ingested 1 file(s) → 7 row(s) in LanceDB lancedb/nemo-retriever." in result.output +def test_root_ingest_service_mode_uses_service_ingest_core(tmp_path, monkeypatch) -> None: + import nemo_retriever.service_ingestor as service_ingestor_module + + document = tmp_path / "service.pdf" + document.write_bytes(b"%PDF-1.4\n") + captured: dict[str, Any] = {} + + class _FakeServiceIngestor(list): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__() + captured["init"] = kwargs + self.dataframe = None + + def files(self, files: list[str]): + captured["files"] = files + return self + + def extract(self, params=None, *, split_config=None, extraction_mode="auto", **_kwargs): + captured["extract_params"] = params + captured["split_config"] = split_config + captured["extraction_mode"] = extraction_mode + return self + + def dedup(self, params=None, **_kwargs): + captured["dedup_params"] = params + return self + + def caption(self, params=None, **_kwargs): + captured["caption_params"] = params + return self + + def embed(self, params=None, **_kwargs): + captured["embed_params"] = params + return self + + def ingest(self, *args: Any, **kwargs: Any): + return self + + monkeypatch.setattr(service_ingestor_module, "ServiceIngestor", _FakeServiceIngestor) + + result = RUNNER.invoke( + cli_main.app, + [ + "ingest", + str(document), + "--run-mode", + "service", + "--service-url", + "http://retriever-service:7670", + "--service-concurrency", + "3", + "--service-api-token", + "service-token", + "--dpi", + "300", + "--extract-images", + "--embed-granularity", + "page", + "--dedup", + "--dedup-iou-threshold", + "0.6", + "--caption", + "--caption-context-text-max-chars", + "12", + "--text-chunk", + "--text-chunk-max-tokens", + "64", + ], + ) + + assert result.exit_code == 0, result.output + assert captured["init"] == { + "base_url": "http://retriever-service:7670", + "max_concurrency": 3, + "api_token": "service-token", + } + assert captured["files"] == [str(document)] + assert captured["extraction_mode"] == "auto" + assert captured["extract_params"].dpi == 300 + assert captured["extract_params"].extract_images is True + assert captured["split_config"]["pdf"]["max_tokens"] == 64 + assert captured["dedup_params"].iou_threshold == 0.6 + assert captured["caption_params"].context_text_max_chars == 12 + assert captured["embed_params"].embed_granularity == "page" + assert "through retriever service http://retriever-service:7670" in result.output + + +def test_root_ingest_service_dry_run_redacts_token(tmp_path, monkeypatch) -> None: + import nemo_retriever.service_ingestor as service_ingestor_module + + document = tmp_path / "service.pdf" + document.write_bytes(b"%PDF-1.4\n") + + def fail_service_ingestor(*_args: Any, **_kwargs: Any) -> None: + raise AssertionError("ServiceIngestor should not be created for --dry-run") + + monkeypatch.setattr(service_ingestor_module, "ServiceIngestor", fail_service_ingestor) + + result = RUNNER.invoke( + cli_main.app, + [ + "ingest", + str(document), + "--run-mode", + "service", + "--service-api-token", + "service-token", + "--dry-run", + ], + ) + + assert result.exit_code == 0, result.output + payload = json.loads(result.output) + assert payload["run_mode"] == "service" + assert payload["documents"] == [str(document)] + assert payload["service"]["service_api_token"] == "" + assert payload["service"]["service_url"] == "http://localhost:7670" + + +def test_root_ingest_service_mode_rejects_local_only_flags(tmp_path) -> None: + document = tmp_path / "service.pdf" + document.write_bytes(b"%PDF-1.4\n") + + result = RUNNER.invoke( + cli_main.app, + [ + "ingest", + str(document), + "--run-mode", + "service", + "--lancedb-uri", + "custom-db", + "--embed-invoke-url", + "http://embed.example/v1", + "--ray-address", + "ray://localhost:10001", + ], + ) + + assert result.exit_code != 0 + assert "--run-mode=service" in result.output + assert "--lancedb-uri" in result.output + assert "--embed-invoke-url" in result.output + assert "--ray-address" in result.output + + def test_root_ingest_passes_vdb_options_and_run_mode(monkeypatch, tmp_path) -> None: fake_ingestor = _make_fake_ingestor() create_calls: list[dict[str, Any]] = []