Skip to content

fix: ExecuTorch export of TRT engines >2 GB by storing engine as Tensor attribute#4277

Open
shoumikhin wants to merge 1 commit into
pytorch:mainfrom
shoumikhin:fix/executorch-engine-as-tensor
Open

fix: ExecuTorch export of TRT engines >2 GB by storing engine as Tensor attribute#4277
shoumikhin wants to merge 1 commit into
pytorch:mainfrom
shoumikhin:fix/executorch-engine-as-tensor

Conversation

@shoumikhin
Copy link
Copy Markdown

@shoumikhin shoumikhin commented May 18, 2026

Problem

torch_tensorrt.save(..., output_format="executorch", retrace=False) fails on
models whose serialized TensorRT engine is larger than ~2 GB. The error:

SyntaxError: unterminated string literal (detected at line 6) (<eval_with_key>.N, line 6)

The threshold isn't TensorRT's: it's CPython's. The Python tokenizer cannot
parse string literals larger than INT32_MAX (~2 GiB) — see
Parser/tokenizer.c. When FX compiles a graph whose source contains a
giant string literal, compile() raises this SyntaxError. Smaller
engines (a few hundred MB) export fine; the issue is invisible until you
cross the limit.

Reproduction

import torch
import torch_tensorrt

class Big(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Big enough that the FP32 TRT engine exceeds ~2 GB.
        self.linear = torch.nn.Linear(32768, 32768)
    def forward(self, x):
        return self.linear(x)

model = Big().cuda().eval()
example = (torch.randn(1, 32768, device="cuda"),)

trt_program = torch_tensorrt.dynamo.compile(
    torch.export.export(model, example),
    inputs=list(example),
    enabled_precisions={torch.float32},
)

# Boom: SyntaxError: unterminated string literal
torch_tensorrt.save(
    trt_program,
    "model.pte",
    output_format="executorch",
    retrace=False,
    arg_inputs=list(example),
)

Root cause

py/torch_tensorrt/_compile.py::_replace_execute_engine_for_executorch
base64-encodes the engine bytes into a Python str and passes that str
as a positional argument to
torch.ops.tensorrt.no_op_placeholder_for_execute_engine:

engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
...
no_op_node = gm.graph.call_function(no_op, (inputs_arg, *engine_info_strs))

When gm.recompile() then re-emits the FX graph as Python source, that
literal string lands directly in the source as 'ZnRydA…'. The bigger
the engine, the bigger the literal. Past ~2 GB the source no longer
tokenizes and exec(compile(src, …)) raises SyntaxError.

The op schema was also declared with serialized_engine: str
(py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py), which
encouraged the inline-literal codegen path.

Fix

Treat the engine the same way FX treats any other large binary constant:
store it as a tensor attribute on the graph module and reference it from
the call site via a get_attr node. The literal that lands in source is
then self._trt_engine_0, not 2 GB of base64.

The op schema also changes from serialized_engine: str to
serialized_engine: torch.Tensor so the rewriter and the ExecuTorch
backend agree on what flows through the ENGINE_IDX slot. Every other
slot stays str — they are small metadata.

Why this approach (and tradeoffs considered)

The original serialized_engine: str schema is not a careless choice —
there are real reasons one might prefer all-primitive op signatures over
Tensor args at constant positions. Notable tradeoffs:

  • All-primitive schemas are simpler for downstream tooling. A
    torch.library.custom_op with only str / int / float args is
    trivially serializable by any consumer of the schema, and avoids any
    ambiguity about whether a Tensor arg represents a runtime input or a
    graph-baked constant. Switching one slot to torch.Tensor means the
    ExecuTorch backend (and any other consumer) has to special-case the
    resolution from get_attr FX node → underlying buffer.
  • String args are easy to hash / dedupe. If two call sites reference
    the same engine, str-equality is O(n) but trivial. Tensor-by-content
    equality is more involved.
  • No risk of accidental device placement. A str is host-only by
    definition. A uint8 constant tensor could in principle be picked up by
    a memory-planning pass and staged onto an accelerator, which would
    waste memory for an inert payload.

The fix keeps the tensor on CPU explicitly (it's created via
torch.frombuffer on bytes) and the backend always resolves it back
to host bytes before serialization, so the placement risk above is
contained.

The portability advantage of all-primitive schemas is real, but it does
not buy anything for an op whose intended consumer is the in-tree
TensorRTBackend. And it does not survive contact with engines >2 GB,
which the inline-literal codegen path fundamentally cannot encode no
matter what we do at the producer side. The Tensor-attribute approach
is the same pattern FX, AOTI, and ExecuTorch already use for every
other large binary constant in a graph; this PR brings the engine
carrier in line with that convention.

Testing

End-to-end verified on a large vision model whose FP32 TRT engine is
~2.2 GB: export to .pte now succeeds with no SyntaxError, and the
ExecuTorch runtime loads and runs the resulting .pte against
reference outputs without divergence. The previously-failing repro
(torch_tensorrt.save(..., output_format="executorch", retrace=False)
from above) now completes cleanly. ExecuTorch backend unit tests
updated to use the new Tensor-based schema.

Backwards compatibility

This breaks .pte files produced by the old exporter. The op
schema for tensorrt::no_op_placeholder_for_execute_engine changed
from serialized_engine: str to serialized_engine: torch.Tensor, so
the new runtime cannot load .pte files that were serialized with the
previous string-based schema.

Action: regenerate any cached .pte files with the new exporter
before loading them with the new runtime. The TRT engine itself has
not changed format — only how it's carried through the FX graph and
the ExecuTorch delegate boundary.

I did not version-tag the op signature to keep both schemas alive in
parallel: it would add a permanent compatibility branch to the runtime
for a payload format nobody should keep around (the old path could not
ship engines >2 GB at all, so the volume of stale .pte files using
the old schema should be small).

@meta-cla meta-cla Bot added the cla signed label May 18, 2026
@github-actions github-actions Bot added component: tests Issues re: Tests component: core Issues re: The core compiler component: api [Python] Issues re: Python API component: runtime component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths labels May 18, 2026
@github-actions github-actions Bot requested a review from cehongwang May 18, 2026 22:36
@shoumikhin shoumikhin marked this pull request as ready for review May 18, 2026 22:53
@shoumikhin shoumikhin force-pushed the fix/executorch-engine-as-tensor branch 9 times, most recently from bde257f to 0e4ea40 Compare May 19, 2026 01:26
…or attribute

`torch_tensorrt.save(..., output_format="executorch", retrace=False)` fails on
models whose serialized TensorRT engine is larger than ~2 GB. The error:

```
SyntaxError: unterminated string literal (detected at line 6) (<eval_with_key>.N, line 6)
```

The threshold isn't TensorRT's: it's CPython's. The Python tokenizer cannot
parse string literals larger than `INT32_MAX` (~2 GiB) — see
`Parser/tokenizer.c`. When FX compiles a graph whose source contains a
giant string literal, `compile()` raises this `SyntaxError`. Smaller
engines (a few hundred MB) export fine; the issue is invisible until you
cross the limit.

```python
import torch
import torch_tensorrt

class Big(torch.nn.Module):
    def __init__(self):
        super().__init__()
        # Big enough that the FP32 TRT engine exceeds ~2 GB.
        self.linear = torch.nn.Linear(32768, 32768)
    def forward(self, x):
        return self.linear(x)

model = Big().cuda().eval()
example = (torch.randn(1, 32768, device="cuda"),)

trt_program = torch_tensorrt.dynamo.compile(
    torch.export.export(model, example),
    inputs=list(example),
    enabled_precisions={torch.float32},
)

torch_tensorrt.save(
    trt_program,
    "model.pte",
    output_format="executorch",
    retrace=False,
    arg_inputs=list(example),
)
```

`py/torch_tensorrt/_compile.py::_replace_execute_engine_for_executorch`
base64-encodes the engine bytes into a Python `str` and passes that `str`
as a **positional argument** to
`torch.ops.tensorrt.no_op_placeholder_for_execute_engine`:

```python
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
...
no_op_node = gm.graph.call_function(no_op, (inputs_arg, *engine_info_strs))
```

When `gm.recompile()` then re-emits the FX graph as Python source, that
literal string lands directly in the source as `'ZnRydA…'`. The bigger
the engine, the bigger the literal. Past ~2 GB the source no longer
tokenizes and `exec(compile(src, …))` raises `SyntaxError`.

The op schema was also declared with `serialized_engine: str`
(`py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py`), which
encouraged the inline-literal codegen path.

Treat the engine the same way FX treats any other large binary constant:
store it as a tensor attribute on the graph module and reference it from
the call site via a `get_attr` node. The literal that lands in source is
then `self._trt_engine_0`, not 2 GB of base64.

Three source files change:

1. **`py/torch_tensorrt/dynamo/runtime/meta_ops/register_meta_ops.py`**
   Change `serialized_engine: str` → `serialized_engine: torch.Tensor` in
   both the `@torch.library.custom_op` and `@torch.library.register_fake`
   declarations of `tensorrt::no_op_placeholder_for_execute_engine`.

2. **`py/torch_tensorrt/_compile.py::_replace_execute_engine_for_executorch`**
   Instead of base64-encoding the engine bytes into a string positional
   arg, wrap them in a `torch.uint8` 1-D tensor, register them as a
   persistent buffer on the graph module (`_trt_engine_<i>`), create a
   `get_attr` FX node pointing at that buffer, and place the `get_attr`
   node at the `ENGINE_IDX` slot of the no_op call. Every other slot
   stays `str` — they are small metadata. The now-orphan original engine
   attribute is `delattr`'d from the module so it doesn't double-serialize
   into `state_dict` alongside the new buffer.

3. **`py/torch_tensorrt/executorch/backend.py`**
   When reading the `tensorrt::no_op_placeholder_for_execute_engine` node
   inside the ExecuTorch backend, resolve the `get_attr` FX node at
   `ENGINE_IDX` back to the underlying `torch.Tensor` and convert it to
@shoumikhin shoumikhin force-pushed the fix/executorch-engine-as-tensor branch from 0e4ea40 to 9f4e7b0 Compare May 19, 2026 22:14
@shoumikhin shoumikhin changed the base branch from release/2.12 to main May 19, 2026 22:15
@shoumikhin
Copy link
Copy Markdown
Author

Thanks @lanluo-nvidia for triaging this!

Rebased onto main and switched the base branch. The merge resolved cleanly into the new TR01 blob format on main: the engine now flows in as a torch.Tensor (preserving the >2 GB fix), and the round-trip test verifies the resulting blob carries the right magic, device id, and binding metadata via deserialize_engine. Re-requesting review from @cehongwang.

@shoumikhin shoumikhin changed the title [2.12] Fix ExecuTorch export of TRT engines >2 GB by storing engine as Tensor attribute instead of inline base64 string fix: ExecuTorch export of TRT engines >2 GB by storing engine as Tensor attribute May 19, 2026
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

cla signed component: api [Python] Issues re: Python API component: core Issues re: The core compiler component: dynamo Issues relating to the `torch.compile` or `torch._dynamo.export` paths component: runtime component: tests Issues re: Tests

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant