Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 27 additions & 4 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -713,9 +713,11 @@ def save(

- If both dynamic_shapes and Input objects are provided, the explicit dynamic_shapes
parameter takes precedence.
kwargs: Additional format-specific kwargs. ``partitioners=`` is only used
with ``output_format="executorch"``; otherwise it is ignored with
a warning.
kwargs: Additional format-specific kwargs. ``partitioners=`` and
``compile_specs=`` are only used with ``output_format="executorch"``;
otherwise they are ignored with a warning. Pass
``compile_specs=[CompileSpec("target_device", b"cuda:<i>")]`` to
override the default target device (``cuda:0``).
"""
if isinstance(module, CudaGraphsTorchTensorRTModule):
module = module.compiled_module
Expand All @@ -741,6 +743,7 @@ def save(
raise ValueError("kwargs should not include None.")

executorch_partitioners = kwargs.pop("partitioners", None)
executorch_compile_specs = kwargs.pop("compile_specs", None)

def _all_are_input_objects(obj: Any) -> bool:
"""Recursively check if all elements in nested collections are Input objects."""
Expand Down Expand Up @@ -850,6 +853,11 @@ def _extract_tensor(obj: Any) -> Any:
"partitioners= is only used with output_format='executorch' and will be "
f"ignored for output_format='{output_format}'."
)
if executorch_compile_specs and output_format != "executorch":
logger.warning(
"compile_specs= is only used with output_format='executorch' and will "
f"be ignored for output_format='{output_format}'."
)
if output_format == "aot_inductor" and platform.system() != "Linux":
raise ValueError(
f"The AOT Inductor format is only supported on Linux, {platform.system()} is not a supported platform for this format"
Expand Down Expand Up @@ -916,6 +924,7 @@ def _extract_tensor(obj: Any) -> Any:
module,
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -979,6 +988,7 @@ def _extract_tensor(obj: Any) -> Any:
exp_program,
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -1063,6 +1073,7 @@ def _extract_tensor(obj: Any) -> Any:
exp_program,
file_path,
partitioners=executorch_partitioners,
compile_specs=executorch_compile_specs,
)
else:
raise RuntimeError(
Expand Down Expand Up @@ -1236,7 +1247,19 @@ def _save_as_executorch(exp_program: Any, file_path: str, **kwargs: Any) -> None
"partitioners must be a list or tuple when using "
"output_format='executorch'"
)
partitioners = [TensorRTPartitioner()] + list(extra_partitioners)
# Forward any caller-provided compile_specs to TensorRTPartitioner so users
# can override the default target_device ("cuda:0") by passing e.g.
# `compile_specs=[CompileSpec("target_device", b"cuda:1")]` to save().
# When omitted, TensorRTPartitioner auto-appends the cuda:0 default.
executorch_compile_specs = kwargs.get("compile_specs") or []
if not isinstance(executorch_compile_specs, (list, tuple)):
raise TypeError(
"compile_specs must be a list or tuple when using "
"output_format='executorch'"
)
partitioners = [
TensorRTPartitioner(compile_specs=list(executorch_compile_specs))
] + list(extra_partitioners)

engine_count = _count_executorch_engine_nodes(exp_program)
if engine_count > 1:
Expand Down
32 changes: 31 additions & 1 deletion py/torch_tensorrt/executorch/partitioner.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,51 @@
from torch_tensorrt.executorch.backend import TensorRTBackend
from torch_tensorrt.executorch.operator_support import TensorRTOperatorSupport

# Key recognized by ExecuTorch's PropagateDevicePass that tags delegate I/O
# TensorSpecs with the target device, which is then serialized into the
# .pte's extra_tensor_info.device_type field.
#
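# Prefer the canonical constant when ExecuTorch exposes it, so that an upstream
# change to the key's value is picked up automatically, and fall back to the
# inlined string for older ExecuTorch revisions that don't yet ship the constant.
# NOTE: because ImportError is caught below, a rename or removal of the constant
# upstream also falls back silently to the inlined string; it does not fail at
# import time.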
try:
from executorch.exir.passes.propagate_device_pass import (
TARGET_DEVICE_COMPILE_SPEC_KEY as _TARGET_DEVICE_COMPILE_SPEC_KEY,
)
except ImportError:
_TARGET_DEVICE_COMPILE_SPEC_KEY = "target_device"


class TensorRTPartitioner(Partitioner): # type: ignore[misc]
"""Partitions the graph for TensorRT delegation.

Only nodes that are torch.ops.tensorrt.execute_engine are supported;
each such node becomes its own partition so the backend can serialize
the engine to the same format as the TRT runtime.

If `compile_specs` does not already contain a ``target_device`` entry,
one defaulting to ``cuda:0`` is auto-appended (mirroring CudaPartitioner).
Callers targeting a non-default GPU should pre-populate
``compile_specs`` with the desired ``CompileSpec("target_device",
b"cuda:<index>")`` to override the default.
"""

def __init__(
self,
compile_specs: Optional[List[CompileSpec]] = None,
) -> None:
super().__init__()
self.compile_specs = compile_specs or []
self.compile_specs = list(compile_specs) if compile_specs else []
# Mirror CudaPartitioner: emit a target_device CompileSpec so that
# ExecuTorch's PropagateDevicePass tags delegate I/O TensorSpecs with
# the correct device, which is then serialized into the .pte's
# extra_tensor_info.device_type field.
if not any(
s.key == _TARGET_DEVICE_COMPILE_SPEC_KEY for s in self.compile_specs
):
self.compile_specs.append(
CompileSpec(_TARGET_DEVICE_COMPILE_SPEC_KEY, b"cuda:0")
)
self.delegation_spec = DelegationSpec(
backend_id=TensorRTBackend.__name__,
compile_specs=self.compile_specs,
Expand Down
Loading