From 9f4e7b04c49b73eb61e4a2adeeb44a605f9ab9a5 Mon Sep 17 00:00:00 2001 From: Anthony Shoumikhin Date: Mon, 18 May 2026 15:35:14 -0700 Subject: [PATCH] fix: ExecuTorch export of TRT engines >2 GB by storing engine as Tensor attribute MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `torch_tensorrt.save(..., output_format="executorch", retrace=False)` fails on models whose serialized TensorRT engine is larger than ~2 GB. The error: ``` SyntaxError: unterminated string literal (detected at line 6) (<eval_with_key>.N, line 6) ``` The threshold isn't TensorRT's: it's CPython's. The Python tokenizer cannot parse string literals larger than `INT32_MAX` (~2 GiB) — see `Parser/tokenizer.c`. When FX compiles a graph whose source contains a giant string literal, `compile()` raises this `SyntaxError`. Smaller engines (a few hundred MB) export fine; the issue is invisible until you cross the limit. Minimal repro: ```python import torch import torch_tensorrt class Big(torch.nn.Module): def __init__(self): super().__init__() # Big enough that the FP32 TRT engine exceeds ~2 GB. self.linear = torch.nn.Linear(32768, 32768) def forward(self, x): return self.linear(x) model = Big().cuda().eval() example = (torch.randn(1, 32768, device="cuda"),) trt_program = torch_tensorrt.dynamo.compile( torch.export.export(model, example), inputs=list(example), enabled_precisions={torch.float32}, ) torch_tensorrt.save( trt_program, "model.pte", output_format="executorch", retrace=False, arg_inputs=list(example), ) ``` Root cause: `py/torch_tensorrt/_compile.py::_replace_execute_engine_for_executorch` base64-encodes the engine bytes into a Python `str` and passes that `str` as a **positional argument** to `torch.ops.tensorrt.no_op_placeholder_for_execute_engine`: ```python engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") ...
no_op_node = gm.graph.call_function(no_op, (inputs_arg, *engine_info_strs)) ``` When `gm.recompile()` then re-emits the FX graph as Python source, that literal string lands directly in the source as `'ZnRydA…'`. The bigger the engine, the bigger the literal. Past ~2 GB the source no longer tokenizes and `exec(compile(src, …))` raises `SyntaxError`. The op schema was also declared with `serialized_engine: str` (`py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py`), which encouraged the inline-literal codegen path. Fix: treat the engine the same way FX treats any other large binary constant — store it as a tensor attribute on the graph module and reference it from the call site via a `get_attr` node. The literal that lands in source is then `self._trt_engine_0`, not 2 GB of base64. Three source files change: 1. **`py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py`** Change `serialized_engine: str` → `serialized_engine: torch.Tensor` in both the `@torch.library.custom_op` and `@torch.library.register_fake` declarations of `tensorrt::no_op_placeholder_for_execute_engine`. 2. **`py/torch_tensorrt/_compile.py::_replace_execute_engine_for_executorch`** Instead of base64-encoding the engine bytes into a string positional arg, wrap them in a `torch.uint8` 1-D tensor, register them as a persistent buffer on the graph module (`_trt_engine_<N>`), create a `get_attr` FX node pointing at that buffer, and place the `get_attr` node at the `ENGINE_IDX` slot of the no_op call. Every other slot stays `str` — they are small metadata. The now-orphaned original engine attribute is removed from the module so it doesn't double-serialize into `state_dict` alongside the new buffer. 3.
**`py/torch_tensorrt/executorch/backend.py`** When reading the `tensorrt::no_op_placeholder_for_execute_engine` node inside the ExecuTorch backend, resolve the `get_attr` FX node at `ENGINE_IDX` back to the underlying `torch.Tensor` and convert it to raw `bytes` (replacing the previous base64 decode). --- py/torch_tensorrt/_compile.py | 79 ++++++++++++++++--- .../runtime/meta_ops/register_meta_ops.py | 4 +- py/torch_tensorrt/executorch/backend.py | 75 +++++++++++++++--- tests/py/dynamo/executorch/test_backend.py | 11 ++- 4 files changed, 141 insertions(+), 28 deletions(-) diff --git a/py/torch_tensorrt/_compile.py b/py/torch_tensorrt/_compile.py index 21da99447f..1afccff535 100644 --- a/py/torch_tensorrt/_compile.py +++ b/py/torch_tensorrt/_compile.py @@ -1119,14 +1119,18 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: """Replace execute_engine nodes with no_op_placeholder_for_execute_engine. ExecuTorch's to_edge_transform_and_lower runs ExportPass subclasses that - dispatch through the C++ schema validator. The validator rejects the + dispatch through the C++ schema validator. The validator rejects the ScriptObject engine arg (it arrives as a CustomObjArgument placeholder - rather than a real FakeScriptObject). Converting each execute_engine node + rather than a real FakeScriptObject). Converting each execute_engine node to no_op_placeholder_for_execute_engine (which carries all engine info as plain strings) avoids the ScriptObject entirely so the passes succeed. - """ - import base64 + The TRT engine bytes are stored as a ``torch.uint8`` buffer on the graph + module and referenced from the no_op call via a ``get_attr`` FX node. This + keeps the engine out of the FX-emitted Python source: CPython's tokenizer + cannot parse string literals larger than ~2 GB, so an inline base64 string + breaks ``gm.recompile()`` for any engine whose payload exceeds that limit.
+ """ from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( ENGINE_IDX, SERIALIZATION_LEN, @@ -1144,7 +1148,7 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: if not nodes_to_replace: return exp_program - for node in nodes_to_replace: + for engine_idx_in_graph, node in enumerate(nodes_to_replace): inputs_arg = node.args[0] engine_node = node.args[1] @@ -1176,17 +1180,60 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: _validate_executorch_engine_info(engine_info, node_name=node.name) # Ensure the engine bytes slot is a base64 string (no_op takes str args). engine_bytes = engine_info[ENGINE_IDX] - if isinstance(engine_bytes, (bytes, bytearray)): - engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8") + if isinstance(engine_bytes, str): + # `_get_engine_info_from_state` returns the engine as a + # base64-encoded `str` when the engine arrived through the + # serialized TRT runtime round-trip path. Decode back to raw + # bytes so it can land in a uint8 buffer below. + import base64 + + engine_bytes = base64.b64decode(engine_bytes) + elif not isinstance(engine_bytes, (bytes, bytearray)): + engine_bytes = bytes(engine_bytes) + # Store engine payload as a uint8 buffer + get_attr ref. FX emits a + # name reference instead of an inline literal, sidestepping the + # tokenizer's >2 GB string-literal limit. + engine_tensor = torch.frombuffer(bytearray(engine_bytes), dtype=torch.uint8) + # Use FX's unique-attr-name helper so re-export passes (which may + # invoke this rewriter multiple times on the same `gm`) don't + # silently overwrite earlier engine buffers. 
+ from torch.fx.experimental.const_fold import ( + get_unique_attr_name_in_module, + ) - engine_info_strs = [ + buffer_name = get_unique_attr_name_in_module(gm, "_trt_engine_0") + gm.register_buffer(buffer_name, engine_tensor, persistent=True) + exp_program.state_dict[buffer_name] = engine_tensor + + str_args = [ str(x) if x is not None else "" for x in engine_info[:SERIALIZATION_LEN] ] + # Build a FakeTensor mirror so downstream FX passes (FakeTensorProp, + # ExecuTorch lowering, export-serde) that read `node.meta["val"]` + # on the `get_attr` reference don't `KeyError`. Must reuse the + # graph's existing FakeTensorMode — creating a fresh one would + # fail downstream with "fake mode from input 0 doesn't match + # mode from input 1" the moment any pass mixes the two. + from torch._guards import detect_fake_mode + + fake_mode = detect_fake_mode( + [n.meta["val"] for n in gm.graph.nodes if "val" in n.meta] + ) + fake_engine = ( + fake_mode.from_tensor(engine_tensor) + if fake_mode is not None + else engine_tensor + ) with gm.graph.inserting_before(node): - no_op_node = gm.graph.call_function( - no_op, - (inputs_arg, *engine_info_strs), + engine_attr_node = gm.graph.get_attr(buffer_name) + engine_attr_node.meta["val"] = fake_engine + no_op_args = ( + inputs_arg, + *str_args[:ENGINE_IDX], + engine_attr_node, + *str_args[ENGINE_IDX + 1 :], ) + no_op_node = gm.graph.call_function(no_op, no_op_args) no_op_node.meta["val"] = node.meta.get("val") node.replace_all_uses_with(no_op_node) @@ -1195,6 +1242,16 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any: # Erase the engine get_attr node if it is now unused. if engine_node.op == "get_attr" and not engine_node.users: gm.graph.erase_node(engine_node) + # Also drop the now-orphan attribute from the module so the + # original engine bytes aren't double-serialized into state_dict + # alongside the new uint8 buffer. Use FX's dotted-path helper + # so nested-target attrs (e.g. 
`submod.engine`) are deleted + # correctly — plain `delattr(gm, "a.b")` only works on + # top-level names. + from torch.fx.graph_module import _del_attr, _has_attr + + if _has_attr(gm, engine_node.target): + _del_attr(gm, engine_node.target) gm.graph.eliminate_dead_code() gm.graph.lint() diff --git a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py index 66ba5dfbc4..c8fa088ef1 100644 --- a/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py +++ b/py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py @@ -312,7 +312,7 @@ def no_op_placeholder_for_execute_engine( abi_version: str, name: str, serialized_device_info: str, - serialized_engine: str, + serialized_engine: torch.Tensor, serialized_in_binding_names: str, serialized_out_binding_names: str, serialized_hardware_compatible: str, @@ -333,7 +333,7 @@ def fake_no_op_placeholder_for_execute_engine( abi_version: str, name: str, serialized_device_info: str, - serialized_engine: str, + serialized_engine: torch.Tensor, serialized_in_binding_names: str, serialized_out_binding_names: str, serialized_hardware_compatible: str, diff --git a/py/torch_tensorrt/executorch/backend.py b/py/torch_tensorrt/executorch/backend.py index 5cb34e2009..03c7236afa 100644 --- a/py/torch_tensorrt/executorch/backend.py +++ b/py/torch_tensorrt/executorch/backend.py @@ -1,8 +1,9 @@ # ExecuTorch TensorRT backend: serialize engines to a libtorch-free runtime blob. 
-import base64 from typing import Any, List, final +import torch +import torch.fx from executorch.exir.backend.backend_details import ( BackendDetails, CompileSpec, @@ -74,7 +75,56 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An name = _schema_name(node.target) if name == "tensorrt::no_op_placeholder_for_execute_engine": - return list(node.args[1:]) + engine_info = list(node.args[1:]) + # ENGINE_IDX slot is either a `get_attr` FX node (when this runs + # before constant-lifting) or a `placeholder` FX node (after + # ExecuTorch's lifter rewrote the get_attr into a graph input + # referencing the buffer). Resolve both shapes to the raw uint8 + # tensor so the rest of the backend can stay engine-format + # agnostic. + engine_slot = engine_info[ENGINE_IDX] + if isinstance(engine_slot, torch.fx.Node): + engine_tensor = None + if engine_slot.op == "get_attr": + engine_tensor = getattr(gm, engine_slot.target, None) + elif engine_slot.op == "placeholder": + # The lifter mangles the placeholder name (e.g. + # "b__trt_engine_0" with a "b_" buffer prefix). The + # canonical attribute target lives in + # graph_signature.input_specs[i].target. + target = engine_slot.target + sig = getattr(edge_program, "graph_signature", None) + if sig is not None: + for ispec in sig.input_specs: + arg = getattr(ispec, "arg", None) + if ( + arg is not None + and getattr(arg, "name", None) == engine_slot.name + ): + target = ispec.target or target + break + state_dict = getattr(edge_program, "state_dict", {}) or {} + constants = getattr(edge_program, "constants", {}) or {} + # Explicit None-check: `state_dict.get(target) or ...` + # would call `bool(tensor)`, which raises + # "Boolean value of Tensor with more than one element + # is ambiguous" for any multi-element engine tensor. 
+ engine_tensor = state_dict.get(target) + if engine_tensor is None: + engine_tensor = constants.get(target) + else: + raise RuntimeError( + f"no_op_placeholder node '{node.name}': unexpected engine " + f"slot op '{engine_slot.op}' (target={engine_slot.target})" + ) + if engine_tensor is None: + raise RuntimeError( + f"no_op_placeholder node '{node.name}': engine slot " + f"'{engine_slot.target}' (op={engine_slot.op}) did not " + f"resolve to a tensor in gm, state_dict, or constants" + ) + engine_info[ENGINE_IDX] = engine_tensor + return engine_info engine_node = node.args[1] if engine_node.op == "get_attr": @@ -163,16 +213,17 @@ def preprocess( engine_info = list(engine_info) _validate_engine_info(engine_info) serialized_engine = engine_info[ENGINE_IDX] - if isinstance(serialized_engine, str): - try: - engine_info[ENGINE_IDX] = base64.b64decode( - serialized_engine.encode("utf-8") - ) - except Exception as exc: - raise RuntimeError( - "TensorRT ExecuTorch backend failed to decode the serialized " - "engine payload." - ) from exc + if isinstance(serialized_engine, torch.Tensor): + # Single copy out of the underlying storage. The prior + # `.numpy().tobytes()` path allocated a fresh bytes buffer + # on top of the numpy view, which for a >2 GB engine + # roughly doubled peak memory at this step. `.cpu()` and + # `.contiguous()` are no-ops when already host-side and + # contiguous (the common case for the uint8 buffer this + # backend produces). 
+ engine_info[ENGINE_IDX] = bytes( + serialized_engine.cpu().contiguous().untyped_storage() + ) elif not isinstance(serialized_engine, (bytes, bytearray)): engine_info[ENGINE_IDX] = bytes(serialized_engine) input_names = _split_binding_names( diff --git a/tests/py/dynamo/executorch/test_backend.py b/tests/py/dynamo/executorch/test_backend.py index 92aaa9e361..89ab51d0f5 100644 --- a/tests/py/dynamo/executorch/test_backend.py +++ b/tests/py/dynamo/executorch/test_backend.py @@ -1,10 +1,10 @@ -import base64 from types import SimpleNamespace import pytest executorch = pytest.importorskip("executorch.exir") +import torch # noqa: E402 from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: E402 DEVICE_IDX, ENGINE_IDX, @@ -44,9 +44,14 @@ def _make_edge_program(*nodes): ) +def _engine_tensor(payload: bytes) -> torch.Tensor: + return torch.frombuffer(bytearray(payload), dtype=torch.uint8) + + @pytest.mark.unit def test_get_engine_info_rejects_multiple_engine_nodes(): engine_info = [""] * SERIALIZATION_LEN + engine_info[ENGINE_IDX] = _engine_tensor(b"engine") edge_program = _make_edge_program( _make_placeholder_node(*engine_info), _make_placeholder_node(*engine_info), @@ -59,7 +64,7 @@ def test_get_engine_info_rejects_multiple_engine_nodes(): @pytest.mark.unit def test_preprocess_rejects_output_allocator(): engine_info = [""] * SERIALIZATION_LEN - engine_info[ENGINE_IDX] = base64.b64encode(b"engine").decode("utf-8") + engine_info[ENGINE_IDX] = _engine_tensor(b"engine") engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = "1" edge_program = _make_edge_program(_make_placeholder_node(*engine_info)) @@ -70,7 +75,7 @@ def test_preprocess_rejects_output_allocator(): @pytest.mark.unit def test_preprocess_serializes_engine_blob(): engine_info = [""] * SERIALIZATION_LEN - engine_info[ENGINE_IDX] = base64.b64encode(b"engine-bytes").decode("utf-8") + engine_info[ENGINE_IDX] = _engine_tensor(b"engine-bytes") engine_info[DEVICE_IDX] = "2%8%0%0%GPU" 
engine_info[INPUT_BINDING_NAMES_IDX] = "x" engine_info[OUTPUT_BINDING_NAMES_IDX] = "y"