Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 68 additions & 11 deletions py/torch_tensorrt/_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -1119,14 +1119,18 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
"""Replace execute_engine nodes with no_op_placeholder_for_execute_engine.

ExecuTorch's to_edge_transform_and_lower runs ExportPass subclasses that
dispatch through the C++ schema validator. The validator rejects the
dispatch through the C++ schema validator. The validator rejects the
ScriptObject engine arg (it arrives as a CustomObjArgument placeholder
rather than a real FakeScriptObject). Converting each execute_engine node
rather than a real FakeScriptObject). Converting each execute_engine node
to no_op_placeholder_for_execute_engine (which carries all engine info as
plain strings) avoids the ScriptObject entirely so the passes succeed.
"""
import base64

The TRT engine bytes are stored as a ``torch.uint8`` buffer on the graph
module and referenced from the no_op call via a ``get_attr`` FX node. This
keeps the engine out of the FX-emitted Python source: CPython's tokenizer
cannot parse string literals larger than ~2 GB, so an inline base64 string
breaks ``gm.recompile()`` for any engine whose payload exceeds that limit.
"""
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import (
ENGINE_IDX,
SERIALIZATION_LEN,
Expand All @@ -1144,7 +1148,7 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
if not nodes_to_replace:
return exp_program

for node in nodes_to_replace:
for engine_idx_in_graph, node in enumerate(nodes_to_replace):
inputs_arg = node.args[0]
engine_node = node.args[1]

Expand Down Expand Up @@ -1176,17 +1180,60 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
_validate_executorch_engine_info(engine_info, node_name=node.name)
# Ensure the engine bytes slot is a base64 string (no_op takes str args).
engine_bytes = engine_info[ENGINE_IDX]
if isinstance(engine_bytes, (bytes, bytearray)):
engine_info[ENGINE_IDX] = base64.b64encode(engine_bytes).decode("utf-8")
if isinstance(engine_bytes, str):
# `_get_engine_info_from_state` returns the engine as a
# base64-encoded `str` when the engine arrived through the
# serialized TRT runtime round-trip path. Decode back to raw
# bytes so it can land in a uint8 buffer below.
import base64

engine_bytes = base64.b64decode(engine_bytes)
elif not isinstance(engine_bytes, (bytes, bytearray)):
engine_bytes = bytes(engine_bytes)
# Store engine payload as a uint8 buffer + get_attr ref. FX emits a
# name reference instead of an inline literal, sidestepping the
# tokenizer's >2 GB string-literal limit.
engine_tensor = torch.frombuffer(bytearray(engine_bytes), dtype=torch.uint8)
# Use FX's unique-attr-name helper so re-export passes (which may
# invoke this rewriter multiple times on the same `gm`) don't
# silently overwrite earlier engine buffers.
from torch.fx.experimental.const_fold import (
get_unique_attr_name_in_module,
)

engine_info_strs = [
buffer_name = get_unique_attr_name_in_module(gm, "_trt_engine_0")
gm.register_buffer(buffer_name, engine_tensor, persistent=True)
exp_program.state_dict[buffer_name] = engine_tensor

str_args = [
str(x) if x is not None else "" for x in engine_info[:SERIALIZATION_LEN]
]
# Build a FakeTensor mirror so downstream FX passes (FakeTensorProp,
# ExecuTorch lowering, export-serde) that read `node.meta["val"]`
# on the `get_attr` reference don't `KeyError`. Must reuse the
# graph's existing FakeTensorMode — creating a fresh one would
# fail downstream with "fake mode from input 0 doesn't match
# mode from input 1" the moment any pass mixes the two.
from torch._guards import detect_fake_mode

fake_mode = detect_fake_mode(
[n.meta["val"] for n in gm.graph.nodes if "val" in n.meta]
)
fake_engine = (
fake_mode.from_tensor(engine_tensor)
if fake_mode is not None
else engine_tensor
)
with gm.graph.inserting_before(node):
no_op_node = gm.graph.call_function(
no_op,
(inputs_arg, *engine_info_strs),
engine_attr_node = gm.graph.get_attr(buffer_name)
engine_attr_node.meta["val"] = fake_engine
no_op_args = (
inputs_arg,
*str_args[:ENGINE_IDX],
engine_attr_node,
*str_args[ENGINE_IDX + 1 :],
)
no_op_node = gm.graph.call_function(no_op, no_op_args)
no_op_node.meta["val"] = node.meta.get("val")

node.replace_all_uses_with(no_op_node)
Expand All @@ -1195,6 +1242,16 @@ def _replace_execute_engine_for_executorch(exp_program: Any) -> Any:
# Erase the engine get_attr node if it is now unused.
if engine_node.op == "get_attr" and not engine_node.users:
gm.graph.erase_node(engine_node)
# Also drop the now-orphan attribute from the module so the
# original engine bytes aren't double-serialized into state_dict
# alongside the new uint8 buffer. Use FX's dotted-path helper
# so nested-target attrs (e.g. `submod.engine`) are deleted
# correctly — plain `delattr(gm, "a.b")` only works on
# top-level names.
from torch.fx.graph_module import _del_attr, _has_attr

if _has_attr(gm, engine_node.target):
_del_attr(gm, engine_node.target)

gm.graph.eliminate_dead_code()
gm.graph.lint()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def no_op_placeholder_for_execute_engine(
abi_version: str,
name: str,
serialized_device_info: str,
serialized_engine: str,
serialized_engine: torch.Tensor,
serialized_in_binding_names: str,
serialized_out_binding_names: str,
serialized_hardware_compatible: str,
Expand All @@ -333,7 +333,7 @@ def fake_no_op_placeholder_for_execute_engine(
abi_version: str,
name: str,
serialized_device_info: str,
serialized_engine: str,
serialized_engine: torch.Tensor,
serialized_in_binding_names: str,
serialized_out_binding_names: str,
serialized_hardware_compatible: str,
Expand Down
75 changes: 63 additions & 12 deletions py/torch_tensorrt/executorch/backend.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
# ExecuTorch TensorRT backend: serialize engines to a libtorch-free runtime blob.

import base64
from typing import Any, List, final

import torch
import torch.fx
from executorch.exir.backend.backend_details import (
BackendDetails,
CompileSpec,
Expand Down Expand Up @@ -74,7 +75,56 @@ def _get_engine_info_from_edge_program(edge_program: ExportedProgram) -> List[An
name = _schema_name(node.target)

if name == "tensorrt::no_op_placeholder_for_execute_engine":
return list(node.args[1:])
engine_info = list(node.args[1:])
# ENGINE_IDX slot is either a `get_attr` FX node (when this runs
# before constant-lifting) or a `placeholder` FX node (after
# ExecuTorch's lifter rewrote the get_attr into a graph input
# referencing the buffer). Resolve both shapes to the raw uint8
# tensor so the rest of the backend can stay engine-format
# agnostic.
engine_slot = engine_info[ENGINE_IDX]
if isinstance(engine_slot, torch.fx.Node):
engine_tensor = None
if engine_slot.op == "get_attr":
engine_tensor = getattr(gm, engine_slot.target, None)
elif engine_slot.op == "placeholder":
# The lifter mangles the placeholder name (e.g.
# "b__trt_engine_0" with a "b_" buffer prefix). The
# canonical attribute target lives in
# graph_signature.input_specs[i].target.
target = engine_slot.target
sig = getattr(edge_program, "graph_signature", None)
if sig is not None:
for ispec in sig.input_specs:
arg = getattr(ispec, "arg", None)
if (
arg is not None
and getattr(arg, "name", None) == engine_slot.name
):
target = ispec.target or target
break
state_dict = getattr(edge_program, "state_dict", {}) or {}
constants = getattr(edge_program, "constants", {}) or {}
# Explicit None-check: `state_dict.get(target) or ...`
# would call `bool(tensor)`, which raises
# "Boolean value of Tensor with more than one element
# is ambiguous" for any multi-element engine tensor.
engine_tensor = state_dict.get(target)
if engine_tensor is None:
engine_tensor = constants.get(target)
else:
raise RuntimeError(
f"no_op_placeholder node '{node.name}': unexpected engine "
f"slot op '{engine_slot.op}' (target={engine_slot.target})"
)
if engine_tensor is None:
raise RuntimeError(
f"no_op_placeholder node '{node.name}': engine slot "
f"'{engine_slot.target}' (op={engine_slot.op}) did not "
f"resolve to a tensor in gm, state_dict, or constants"
)
engine_info[ENGINE_IDX] = engine_tensor
return engine_info

engine_node = node.args[1]
if engine_node.op == "get_attr":
Expand Down Expand Up @@ -163,16 +213,17 @@ def preprocess(
engine_info = list(engine_info)
_validate_engine_info(engine_info)
serialized_engine = engine_info[ENGINE_IDX]
if isinstance(serialized_engine, str):
try:
engine_info[ENGINE_IDX] = base64.b64decode(
serialized_engine.encode("utf-8")
)
except Exception as exc:
raise RuntimeError(
"TensorRT ExecuTorch backend failed to decode the serialized "
"engine payload."
) from exc
if isinstance(serialized_engine, torch.Tensor):
# Single copy out of the underlying storage. The prior
# `.numpy().tobytes()` path allocated a fresh bytes buffer
# on top of the numpy view, which for a >2 GB engine
# roughly doubled peak memory at this step. `.cpu()` and
# `.contiguous()` are no-ops when already host-side and
# contiguous (the common case for the uint8 buffer this
# backend produces).
engine_info[ENGINE_IDX] = bytes(
serialized_engine.cpu().contiguous().untyped_storage()
)
elif not isinstance(serialized_engine, (bytes, bytearray)):
engine_info[ENGINE_IDX] = bytes(serialized_engine)
input_names = _split_binding_names(
Expand Down
11 changes: 8 additions & 3 deletions tests/py/dynamo/executorch/test_backend.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import base64
from types import SimpleNamespace

import pytest

executorch = pytest.importorskip("executorch.exir")

import torch # noqa: E402
from torch_tensorrt.dynamo.runtime._TorchTensorRTModule import ( # noqa: E402
DEVICE_IDX,
ENGINE_IDX,
Expand Down Expand Up @@ -44,9 +44,14 @@ def _make_edge_program(*nodes):
)


def _engine_tensor(payload: bytes) -> torch.Tensor:
return torch.frombuffer(bytearray(payload), dtype=torch.uint8)


@pytest.mark.unit
def test_get_engine_info_rejects_multiple_engine_nodes():
engine_info = [""] * SERIALIZATION_LEN
engine_info[ENGINE_IDX] = _engine_tensor(b"engine")
edge_program = _make_edge_program(
_make_placeholder_node(*engine_info),
_make_placeholder_node(*engine_info),
Expand All @@ -59,7 +64,7 @@ def test_get_engine_info_rejects_multiple_engine_nodes():
@pytest.mark.unit
def test_preprocess_rejects_output_allocator():
engine_info = [""] * SERIALIZATION_LEN
engine_info[ENGINE_IDX] = base64.b64encode(b"engine").decode("utf-8")
engine_info[ENGINE_IDX] = _engine_tensor(b"engine")
engine_info[REQUIRES_OUTPUT_ALLOCATOR_IDX] = "1"
edge_program = _make_edge_program(_make_placeholder_node(*engine_info))

Expand All @@ -70,7 +75,7 @@ def test_preprocess_rejects_output_allocator():
@pytest.mark.unit
def test_preprocess_serializes_engine_blob():
engine_info = [""] * SERIALIZATION_LEN
engine_info[ENGINE_IDX] = base64.b64encode(b"engine-bytes").decode("utf-8")
engine_info[ENGINE_IDX] = _engine_tensor(b"engine-bytes")
engine_info[DEVICE_IDX] = "2%8%0%0%GPU"
engine_info[INPUT_BINDING_NAMES_IDX] = "x"
engine_info[OUTPUT_BINDING_NAMES_IDX] = "y"
Expand Down
Loading