diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 21da99447f..eb2b879daa 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -713,9 +713,11 @@ def save( - If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes parameter takes precedence. - kwargs: Additional format-specific kwargs. ``partitioners=`` is only used - with ``output_format="executorch"``; otherwise it is ignored with - a warning. + kwargs: Additional format-specific kwargs. ``partitioners=`` and + ``compile_specs=`` are only used with ``output_format="executorch"``; + otherwise they are ignored with a warning. Pass + ``compile_specs=[CompileSpec("target_device", b"cuda:")]`` to + override the default target device (``cuda:0``). """ if isinstance(module, CudaGraphsTorchTensorRTModule): module = module.compiled_module @@ -741,6 +743,7 @@ def save( raise ValueError("kwargs should not include None.") executorch_partitioners = kwargs.pop("partitioners", None) + executorch_compile_specs = kwargs.pop("compile_specs", None) def _all_are_input_objects(obj: Any) -> bool: """Recursively check if all elements in nested collections are Input objects.""" @@ -850,6 +853,11 @@ def _extract_tensor(obj: Any) -> Any: "partitioners= is only used with output_format='executorch' and will be " f"ignored for output_format='{output_format}'." ) + if executorch_compile_specs and output_format != "executorch": + logger.warning( + "compile_specs= is only used with output_format='executorch' and will " + f"be ignored for output_format='{output_format}'." + ) if output_format == "aot_inductor" and platform.system() != "Linux": raise ValueError( f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format" @@ -916,6 +924,7 @@ def _extract_tensor(obj: Any) -> Any: module, file_path, partitioners=executorch_partitioners, + compile_specs=executorch_compile_specs, ) else: raise RuntimeError( @@ -979,6 +988,7 @@ def _extract_tensor(obj: Any) -> Any: exp_program, file_path, partitioners=executorch_partitioners, + compile_specs=executorch_compile_specs, ) else: raise RuntimeError( @@ -1063,6 +1073,7 @@ def _extract_tensor(obj: Any) -> Any: exp_program, file_path, partitioners=executorch_partitioners, + compile_specs=executorch_compile_specs, ) else: raise RuntimeError( @@ -1236,7 +1247,19 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None "partitioners must be a list or tuple when using " "output_format='executorch'" ) - partitioners = [TensorRTPartitioner()] + list(extra_partitioners) + # Forward any caller-provided compile_specs to TensorRTPartitioner so users + # can override the default target_device ("cuda:0") by passing e.g. + # `compile_specs=[CompileSpec("target_device", b"cuda:1")]` to save(). + # When omitted, TensorRTPartitioner auto-appends the cuda:0 default. + executorch_compile_specs = kwargs.get("compile_specs") or [] + if not isinstance(executorch_compile_specs, (list, tuple)): + raise TypeError( + "compile_specs must be a list or tuple when using " + "output_format='executorch'" + ) + partitioners = [ + TensorRTPartitioner(compile_specs=list(executorch_compile_specs)) + ] + list(extra_partitioners) engine_count = _count_executorch_engine_nodes(exp_program) if engine_count > 1: diff --git a/py/torch_tensorrt/executorch/partitioner.py b/py/torch_tensorrt/executorch/partitioner.py index 9fcab9f709..7fde508450 100644 --- a/py/torch_tensorrt/executorch/partitioner.py +++ b/py/torch_tensorrt/executorch/partitioner.py @@ -15,6 +15,20 @@ from torch_tensorrt.executorch.backend import TensorRTBackend from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport +# Key recognized by ExecuTorch's PropagateDevicePass that tags delegate I/O +# TensorSpecs with the target device, which is then serialized into the +# .pte's extra_tensor_info.device_type field. +# +# Prefer the canonical constant when ExecuTorch exposes it (will fail loudly +# at import time if the key is renamed upstream) and fall back to the inlined +# string for older ExecuTorch revisions that don't yet ship the constant. +try: + from executorch.exir.passes.propagate_device_pass import ( + TARGET_DEVICE_COMPILE_SPEC_KEY as _TARGET_DEVICE_COMPILE_SPEC_KEY, + ) +except ImportError: + _TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device" + class TensorRTPartitioner(Partitioner): # type: ignore[misc] """Partitions the graph for TensorRT delegation. @@ -22,6 +36,12 @@ class TensorRTPartitioner(Partitioner): # type: ignore[misc] Only nodes that are torch.ops.tensorrt.execute_engine are supported; each such node becomes its own partition so the backend can serialize the engine to the same format as the TRT runtime. + + If `compile_specs` does not already contain a ``target_device`` entry, + one defaulting to ``cuda:0`` is auto-appended (mirroring CudaPartitioner). + Callers targeting a non-default GPU should pre-populate + ``compile_specs`` with the desired ``CompileSpec("target_device", + b"cuda:")`` to override the default. """ def __init__( @@ -29,7 +49,17 @@ def __init__( compile_specs: Optional[List[CompileSpec]] = None, ) -> None: super().__init__() - self.compile_specs = compile_specs or [] + self.compile_specs = list(compile_specs) if compile_specs else [] + # Mirror CudaPartitioner: emit a target_device CompileSpec so that + # ExecuTorch's PropagateDevicePass tags delegate I/O TensorSpecs with + # the correct device, which is then serialized into the .pte's + # extra_tensor_info.device_type field. + if not any( + s.key == _TARGET_DEVICE_COMPILE_SPEC_KEY for s in self.compile_specs + ): + self.compile_specs.append( + CompileSpec(_TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0") + ) self.delegation_spec = DelegationSpec( backend_id=TensorRTBackend.__name__, compile_specs=self.compile_specs,