"""
.. _constant_duplication_example:

Inspecting ``constant_duplication`` with the TensorRT engine inspector
======================================================================

This example demonstrates the ``constant_duplication`` lowering pass and shows
how to check what TensorRT actually does with the duplicated constants by
dumping the per-layer engine info via the :class:`Debugger` context.

The pass clones constant subgraphs that have multiple users so subsequent
constant folding can fold each clone into its dedicated consumer, rather than
leaving a single shared constant feeding several ops. The motivating pattern
shows up in LLMs like Llama: a weight tensor is reused in multiple matmuls
with intermediate transposes/reshapes between the weight and its consumers.

The tradeoff in the lowered Python module is straightforward — each consumer
gets its own copy of the constant. Whether that translates to a TensorRT
engine difference depends on TensorRT itself, and the engine inspector is how
we find out: if TensorRT can already absorb the shared constant into
per-consumer kernels (typical for matmul), the engines come out identical; if
not, duplication forces TensorRT to materialize one private constant per
consumer.
"""

import copy
import json
import os
import shutil
import tempfile

import torch
import torch.nn as nn
import torch_tensorrt
from torch_tensorrt.dynamo import Debugger
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.lowering import post_lowering

# %% Model
#
# A small stand-in for the Llama tied-weight / shared-projection pattern. The
# intermediate ``w_t = self.weight.t().contiguous()`` is a *shared constant*:
# both matmuls consume the same FX node. This is the case ``constant_duplication``
# is designed for — without the flag, the standard folder leaves a single
# ``_frozen_param`` feeding both matmuls; with the flag, each matmul gets a
# private clone.


class SharedTransposedWeight(nn.Module):
    """Two matmuls that both consume the same transposed weight."""

    def __init__(self, vocab: int = 32000, dim: int = 4096):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(vocab, dim, dtype=torch.float16) * 0.02)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> torch.Tensor:
        w_t = self.weight.t().contiguous()  # shared intermediate constant
        return q @ w_t + k @ w_t


VOCAB, DIM, BATCH = 32000, 4096, 4
model = SharedTransposedWeight(VOCAB, DIM).cuda().half().eval()
inputs = (
    torch.randn(BATCH, DIM, device="cuda", dtype=torch.float16),
    torch.randn(BATCH, DIM, device="cuda", dtype=torch.float16),
)
exported = torch.export.export(model, inputs)


# %% FX graph and lowered parameter bytes
#
# Run the lowering passes manually (without engine build) with the flag off
# and on, so we can see exactly what the pass does to the graph.


def lowered_gm(flag: bool) -> torch.fx.GraphModule:
    """Export the model afresh and run ``post_lowering`` with the flag set to ``flag``."""
    gm = torch.export.export(model, inputs).module()
    return post_lowering(gm, CompilationSettings(constant_duplication=flag))


def print_graph(label: str, gm: torch.fx.GraphModule) -> None:
    """Print every node of ``gm`` except ``call_module`` nodes, under a ``label`` header."""
    print(f"\n--- {label} ---")
    for node in gm.graph.nodes:
        if node.op == "call_module":
            continue
        print(node.format_node())


def param_bytes(gm: torch.fx.GraphModule) -> int:
    """Return the total size in bytes of the parameters registered on ``gm``."""
    return sum(p.numel() * p.element_size() for p in gm.parameters())


gm_off = lowered_gm(False)
gm_on = lowered_gm(True)
print_graph("constant_duplication = False", gm_off)
print_graph("constant_duplication = True", gm_on)
print(
    f"\nLowered GraphModule parameter bytes:"
    f"\n off: {param_bytes(gm_off) / 1e6:>8.2f} MB"
    f"\n on : {param_bytes(gm_on) / 1e6:>8.2f} MB"
)


# %% Compile and inspect the TensorRT engine
#
# Wrap each compile in :class:`torch_tensorrt.dynamo.Debugger` with
# ``save_layer_info=True``. The debugger raises TRT's profiling verbosity to
# ``DETAILED`` and writes the per-layer info to
# ``<logging_dir>/engine_layer_info.json`` after the engine has been built.
# We can then compare exactly what TensorRT did with each version.


def engine_size(mod: torch.nn.Module) -> int:
    """Sum the byte length of every ``serialized_engine`` found on ``mod``'s submodules.

    Submodules without the attribute (or with a falsy one) contribute zero.
    """
    return sum(
        len(getattr(sub, "serialized_engine", b"") or b"") for sub in mod.modules()
    )


def compile_and_inspect(label: str, *, constant_duplication: bool) -> None:
    """Compile the exported model, time it, and print the dumped engine layer info.

    Args:
        label: Header printed above the report for this configuration.
        constant_duplication: Value forwarded to the compile setting of the
            same name.

    Raises:
        AssertionError: If the compiled output deviates from eager PyTorch
            beyond the ``5e-2`` tolerances.
    """
    workdir = tempfile.mkdtemp(prefix="trt_const_dup_")
    try:
        with Debugger(
            log_level="warning",
            logging_dir=workdir,
            save_layer_info=True,
            engine_builder_monitor=False,
        ):
            mod = torch_tensorrt.dynamo.compile(
                copy.deepcopy(exported),
                inputs,
                min_block_size=1,
                use_python_runtime=True,
                constant_duplication=constant_duplication,
            )
            # The layer info is written on first forward.
            _ = mod(*inputs)

        # Latency: warm up first so one-off setup cost is not measured.
        for _ in range(20):
            _ = mod(*inputs)
        torch.cuda.synchronize()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        iters = 500
        start.record()
        for _ in range(iters):
            out = mod(*inputs)
        end.record()
        torch.cuda.synchronize()
        # ``elapsed_time`` reports milliseconds; convert to microseconds per call.
        us_per_iter = start.elapsed_time(end) / iters * 1000
        torch.testing.assert_close(out, model(*inputs), rtol=5e-2, atol=5e-2)

        # Engine layer info dumped by the debugger
        info_path = os.path.join(workdir, "engine_layer_info.json")
        with open(info_path, encoding="utf-8") as f:
            data = json.load(f)
        layers = data.get("Layers", [])

        print(f"\n=== {label} ===")
        print(
            f"latency : {us_per_iter:7.1f} us/iter, engine: "
            f"{engine_size(mod) / 1e6:.2f} MB, {len(layers)} layers"
        )
        for layer in layers:
            inputs_in = [i.get("Name") for i in layer.get("Inputs", [])]
            outputs_out = [o.get("Name") for o in layer.get("Outputs", [])]
            print(
                f" {layer.get('LayerType', '?'):8s} "
                f"in={inputs_in} out={outputs_out}\n"
                f" tactic={layer.get('TacticName', '?')}"
            )
    finally:
        # Always drop the scratch directory, even if the build or the checks fail.
        shutil.rmtree(workdir, ignore_errors=True)


compile_and_inspect("constant_duplication=False", constant_duplication=False)
compile_and_inspect("constant_duplication=True ", constant_duplication=True)


# %% Reading the numbers
#
# Typical output on this fixture (32000 x 4096 fp16 weight, two matmul
# consumers — scale down ``VOCAB``/``DIM`` to fit your GPU):
#
# .. code-block:: text
#
#    Lowered GraphModule parameter bytes:
#     off: 524.29 MB
#     on : 786.43 MB
#
#    === constant_duplication=False ===
#    latency : 2130.4 us/iter, engine: 262.16 MB, 2 layers
#     gemm in=['k'] out=['output0']
#     tactic=sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize64x96x32_...
#     gemm in=['q'] out=['output0']
#     tactic=sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize64x96x32_...
#
#    === constant_duplication=True ===
#    latency : 2044.7 us/iter, engine: 262.16 MB, 2 layers
#     gemm in=['k'] out=['output0']
#     tactic=sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize64x96x32_...
#     gemm in=['q'] out=['output0']
#     tactic=sm80_xmma_gemm_f16f16_f16f32_f32_tn_n_tilesize64x96x32_...
#
# Observations:
#
# * **FX graph**: the pass clearly replaces the single shared ``_frozen_param``
#   with two private ``_frozen_param`` / ``_frozen_param_dup0`` get_attrs.
# * **Lowered GraphModule parameter bytes** grow ~1.5x with the flag on —
#   each cloned ``get_attr`` is backed by its own parameter copy. This is the
#   "model size" cost, and it is real before engine build and in any artifact
#   that serializes the GraphModule.
# * **TensorRT engine layers**: for a shared-constant-into-matmul pattern,
#   TensorRT already absorbs the constant into each gemm kernel — both
#   versions produce the *same* 2-gemm engine, the *same* tactic per gemm,
#   and the *same* engine bytes. The "size" the user paid for at the
#   GraphModule level was reclaimed by TRT's constant deduplication.
#
# When does duplication actually change the TRT engine? When TensorRT can't
# fold the shared constant into a per-consumer kernel — for example when the
# constant feeds an op that doesn't admit weight-absorption (some custom
# plugins, certain reduction patterns), or when downstream quantization/refit
# needs each consumer to own a private constant. In those cases the
# ``engine_layer_info.json`` dump will show extra ``Constant`` layers and a
# different per-gemm tactic between the off and on configurations. For the
# vanilla shared-matmul-weight pattern shown here, leaving the flag off (the
# default) gives the smallest lowered module with no loss of engine quality.
diff --git a/py/torch_tensorrt/dynamo/_defaults.py b/py/torch_tensorrt/dynamo/_defaults.py index 007b07db31..c880c1c633 100644 --- a/py/torch_tensorrt/dynamo/_defaults.py +++ b/py/torch_tensorrt/dynamo/_defaults.py @@ -70,6 +70,7 @@ DYNAMICALLY_ALLOCATE_RESOURCES = False DECOMPOSE_ATTENTION = False ATTN_BIAS_IS_CAUSAL = True +CONSTANT_DUPLICATION = False DYNAMIC_SHAPES_KERNEL_SPECIALIZATION_STRATEGY = "lazy" if platform.system() == "Linux": diff --git a/py/torch_tensorrt/dynamo/_settings.py b/py/torch_tensorrt/dynamo/_settings.py index c7ef3eed9b..8954dddad4 100644 --- a/py/torch_tensorrt/dynamo/_settings.py +++ b/py/torch_tensorrt/dynamo/_settings.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from typing import Any, Collection, Optional, Set, Tuple, Union -import tensorrt as trt import torch from torch.fx.node import Target from torch_tensorrt._Device import Device @@ -16,6 +15,7 @@ AUTOCAST_MAX_DEPTH_OF_REDUCTION, AUTOCAST_MAX_OUTPUT_THRESHOLD, CACHE_BUILT_ENGINES, + CONSTANT_DUPLICATION, CPU_MEMORY_BUDGET, DECOMPOSE_ATTENTION, DISABLE_TF32, @@ -59,6 +59,8 @@ default_device, ) +import tensorrt as trt + @dataclass class CompilationSettings: @@ -121,6 +123,7 @@ class CompilationSettings: dynamically_allocate_resources (bool): Dynamically allocate resources for TensorRT engines decompose_attention (bool): Whether to decompose attention layers. We have converters for handling attention ops, but if you want to decompose them into smaller ops, you can set this to True. attn_bias_is_causal (bool): Whether the attn_bias in efficient SDPA is causal. Default is True. This can accelerate models from HF because attn_bias is always a causal mask in HF. If you want to use non-causal attn_bias, you can set this to False. + constant_duplication (bool): Whether to enable the constant duplication lowering pass. When True, constant subgraphs with multiple users are cloned per-user and constant folding is re-run, allowing each consumer to fold its own private copy. 
Useful when a shared constant chain (e.g. ``reshape(weight)``) prevents downstream folding into each consumer. Default: False. """ workspace_size: int = WORKSPACE_SIZE @@ -184,6 +187,7 @@ class CompilationSettings: dynamically_allocate_resources: bool = DYNAMICALLY_ALLOCATE_RESOURCES decompose_attention: bool = DECOMPOSE_ATTENTION attn_bias_is_causal: bool = ATTN_BIAS_IS_CAUSAL + constant_duplication: bool = CONSTANT_DUPLICATION def __getstate__(self) -> dict[str, Any]: from torch_tensorrt.dynamo.conversion._ConverterRegistry import ( diff --git a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py index b25219bc82..a1d1804427 100644 --- a/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py +++ b/py/torch_tensorrt/dynamo/lowering/passes/_aten_lowering_pass.py @@ -11,6 +11,7 @@ ) from .complex_graph_rewrite import complex_graph_detection +from .constant_duplication import constant_duplication from .constant_folding import constant_fold from .force_causal_efficient_attention import force_causal_efficient_attention from .fuse_prims_broadcast import fuse_prims_broadcast @@ -34,6 +35,7 @@ replace_fused_rms_norm, remove_input_alias_fixing_clones, constant_fold, + constant_duplication, repair_input_as_output, fuse_prims_broadcast, replace_max_pool_with_indices, diff --git a/py/torch_tensorrt/dynamo/lowering/passes/constant_duplication.py b/py/torch_tensorrt/dynamo/lowering/passes/constant_duplication.py new file mode 100644 index 0000000000..ea3cbd0d3e --- /dev/null +++ b/py/torch_tensorrt/dynamo/lowering/passes/constant_duplication.py @@ -0,0 +1,173 @@ +import logging +from typing import Any, Dict, List, Set + +import torch +from torch_tensorrt.dynamo._settings import CompilationSettings +from torch_tensorrt.dynamo.lowering.passes.constant_folding import constant_fold +from torch_tensorrt.dynamo.lowering.passes.pass_utils import ( + clean_up_graph_after_modifications, +) + +logger = 
logging.getLogger(__name__) + + +def _get_impure_targets() -> Set[torch._ops.OpOverload]: + """Targets that must not be duplicated or treated as constant. + + Kept in sync with ``_TorchTensorRTConstantFolder.quantization_ops``. + """ + impure: Set[torch._ops.OpOverload] = set() + try: + import modelopt.torch.quantization as mtq # noqa: F401 + + impure.add(torch.ops.tensorrt.quantize_op.default) + impure.add(torch.ops.tensorrt.dynamic_block_quantize_op.default) + except Exception: + pass + return impure + + +def _compute_constant_nodes( + gm: torch.fx.GraphModule, impure_targets: Set[torch._ops.OpOverload] +) -> Set[torch.fx.Node]: + """Set of nodes whose value is fully determined by ``get_attr`` ancestors. + + A node is constant if it is a ``get_attr`` or a pure ``call_function`` whose + every input node is itself constant. Graph iteration is topological, so + inputs are classified before their users. + """ + constant_nodes: Set[torch.fx.Node] = set() + for node in gm.graph.nodes: + if node.op == "get_attr": + constant_nodes.add(node) + continue + if node.op != "call_function": + continue + if node.target in impure_targets: + continue + if all(inp in constant_nodes for inp in node.all_input_nodes): + constant_nodes.add(node) + return constant_nodes + + +def _register_attr_copy(gm: torch.fx.GraphModule, src_target: str) -> str: + """Register a fresh copy of an attribute (parameter or buffer) on ``gm``. + + Returns the qualified name of the new attribute. The new attribute holds an + independent tensor that is a ``clone`` of the source, so each duplicate + can be specialized (or folded into a different downstream constant) without + aliasing back to the original. 
+ """ + src = getattr(gm, src_target) + idx = 0 + while True: + new_target = f"{src_target}_dup{idx}" + if not hasattr(gm, new_target): + break + idx += 1 + if isinstance(src, torch.nn.Parameter): + copy = torch.nn.Parameter(src.detach().clone(), requires_grad=src.requires_grad) + gm.register_parameter(new_target, copy) + else: + gm.register_buffer(new_target, src.detach().clone()) + return new_target + + +def _clone_constant_subgraph( + gm: torch.fx.GraphModule, + root: torch.fx.Node, + insert_before: torch.fx.Node, + constant_nodes: Set[torch.fx.Node], + memo: Dict[torch.fx.Node, torch.fx.Node], +) -> torch.fx.Node: + """Recursively clone the constant subgraph rooted at ``root``. + + All clones are inserted immediately before ``insert_before``. ``memo`` keeps + diamond-shaped constant subgraphs coherent within a single duplication + (e.g. ``mul(p, p)`` where the same constant feeds both args). + """ + if root in memo: + return memo[root] + + def _map(arg: Any) -> Any: + if isinstance(arg, torch.fx.Node) and arg in constant_nodes: + return _clone_constant_subgraph( + gm, arg, insert_before, constant_nodes, memo + ) + if isinstance(arg, (list, tuple)): + return type(arg)(_map(a) for a in arg) + if isinstance(arg, dict): + return {k: _map(v) for k, v in arg.items()} + return arg + + with gm.graph.inserting_before(insert_before): + if root.op == "get_attr": + new_target = _register_attr_copy(gm, root.target) + new_node = gm.graph.get_attr(new_target) + elif root.op == "call_function": + new_args = tuple(_map(a) for a in root.args) + new_kwargs = {k: _map(v) for k, v in root.kwargs.items()} + new_node = gm.graph.call_function(root.target, new_args, new_kwargs) + else: + return root + + # Carry over every meta entry — ``val`` (FakeTensor with shape / dtype / + # SymInts bound to the existing ShapeEnv), ``tensor_meta``, + # ``unbacked_bindings``, ``stack_trace``, ``nn_module_stack`` etc. 
The + # clone has identical semantics to ``root`` so the same metadata applies; + # ``FakeTensorUpdater`` re-fakes call_function clones at end-of-lowering + # which keeps shape-env bindings consistent for nodes whose recomputed + # value would differ. get_attr clones are not re-faked, but the copy + # describes the cloned parameter exactly (we ``detach().clone()`` it). + new_node.meta.update(root.meta) + memo[root] = new_node + # The clone is constant by construction; register it so later candidate + # iterations (whose consumers may have been rewired onto this clone) + # recognise it as constant and clone it again instead of resharing. + constant_nodes.add(new_node) + return new_node + + +def constant_duplication( + gm: torch.fx.GraphModule, settings: CompilationSettings +) -> torch.fx.GraphModule: + """Duplicate constant subgraphs with multiple users so that subsequent + constant folding can fold each copy into its dedicated consumer. + + Given a constant node ``A`` with users ``B`` and ``C``, this pass produces + ``A_b`` (used only by ``B``) and ``A_c`` (used only by ``C``) and re-runs + ``constant_fold`` so each clone can be folded into its consumer's + subgraph. + + No-op unless ``settings.constant_duplication`` is True. + """ + if not getattr(settings, "constant_duplication", False): + return gm + + impure_targets = _get_impure_targets() + constant_nodes = _compute_constant_nodes(gm, impure_targets) + + candidates: List[torch.fx.Node] = [ + n for n in list(gm.graph.nodes) if n in constant_nodes and len(n.users) > 1 + ] + + duplications = 0 + for node in candidates: + users = list(node.users.keys()) + # Leave the first user attached to the original chain; clone the + # subgraph once per additional user. 
+ for user in users[1:]: + memo: Dict[torch.fx.Node, torch.fx.Node] = {} + new_root = _clone_constant_subgraph(gm, node, user, constant_nodes, memo) + user.replace_input_with(node, new_root) + duplications += 1 + + if duplications == 0: + return gm + + logger.debug(f"constant_duplication cloned {duplications} constant subgraph use(s)") + + gm = clean_up_graph_after_modifications(gm) + gm = constant_fold(gm, settings) + logger.debug(f"Graph after constant_duplication:\n{gm.graph}") + return gm diff --git a/tests/py/dynamo/lowering/test_aten_lowering_passes.py b/tests/py/dynamo/lowering/test_aten_lowering_passes.py index 424cf145fc..cd1408332f 100644 --- a/tests/py/dynamo/lowering/test_aten_lowering_passes.py +++ b/tests/py/dynamo/lowering/test_aten_lowering_passes.py @@ -382,5 +382,241 @@ def forward( torch.testing.assert_close(pytorch_out, trt_out, rtol=1e-2, atol=1e-2) +class TestConstantDuplication(TestCase): + def _make_shared_constant_module(self): + """Module where ``reshape(weight) -> permute`` feeds two distinct matmuls. + + The intermediate ``permute`` is a constant subgraph with two users, the + case ``constant_duplication`` is designed for. 
class TestConstantDuplication(TestCase):
    """Checks for the ``constant_duplication`` lowering pass."""

    def _make_shared_constant_module(self):
        """Build a module whose ``reshape(weight) -> permute`` chain feeds two matmuls.

        The trailing ``permute`` is a constant subgraph with two users, which
        is exactly the situation ``constant_duplication`` targets.
        """

        class SharedConstantSubgraph(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_parameter("weight", torch.nn.Parameter(torch.randn(8, 4)))

            def forward(self, x, y):
                w = self.weight.reshape(4, 8).permute(1, 0)
                return x @ w + y @ w

        return SharedConstantSubgraph()

    def test_constant_duplication_pass_clones_shared_subgraph(self):
        from torch_tensorrt.dynamo._settings import CompilationSettings
        from torch_tensorrt.dynamo.lowering.passes.constant_duplication import (
            _compute_constant_nodes,
            _get_impure_targets,
            constant_duplication,
        )

        def multi_user_constants(graph_module):
            classified = _compute_constant_nodes(graph_module, _get_impure_targets())
            return [n for n in classified if len(n.users) > 1]

        module = self._make_shared_constant_module().cuda().eval()
        args = tuple(torch.randn(3, 8).cuda() for _ in range(2))
        gm = torch.export.export(module, args).module()

        # Sanity: the test fixture must actually have a shared constant.
        self.assertGreater(
            len(multi_user_constants(gm)),
            0,
            msg="Test fixture has no shared constant subgraph to duplicate.",
        )

        gm = constant_duplication(gm, CompilationSettings(constant_duplication=True))

        still_shared = multi_user_constants(gm)
        self.assertEqual(
            len(still_shared),
            0,
            msg=(
                "After constant_duplication, no constant node should still have "
                f"multiple users; found: {still_shared}"
            ),
        )

    def test_constant_duplication_end_to_end(self):
        module = self._make_shared_constant_module().cuda().eval()
        args = tuple(torch.randn(3, 8).cuda() for _ in range(2))
        expected = module(*args)
        exported_program = torch.export.export(module, args)
        compiled = torch_tensorrt.dynamo.compile(
            exported_program,
            args,
            min_block_size=1,
            constant_duplication=True,
        )
        actual = compiled(*args)
        torch.testing.assert_close(expected, actual, rtol=1e-3, atol=1e-3)

    def test_constant_duplication_nested_constants_no_resharing(self):
        """Regression for nested multi-user constants.

        When a constant subgraph contains *several* multi-user constant nodes
        (a shared weight W feeding two shared intermediates r2 and r1_t), the
        clones made for the outer candidate must themselves count as
        constants, otherwise the inner candidate's duplication re-shares them.

        Only the duplication step runs here (no trailing ``constant_fold``),
        so the invariant is checked on the post-duplication graph directly:
        every constant node has exactly one user.
        """
        from torch_tensorrt.dynamo.lowering.passes.constant_duplication import (
            _clone_constant_subgraph,
            _compute_constant_nodes,
            _get_impure_targets,
        )

        class ChainedSharedConstants(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.w = torch.nn.Parameter(torch.randn(16, 16))

            def forward(self, x, y, p, q):
                # ``self.w`` has two reshape users → multi-user.
                # ``r1_t`` has two matmul users → multi-user.
                # ``r2`` has two matmul users → multi-user.
                r1 = self.w.reshape(8, 32)
                r1_t = r1.t()  # (32, 8)
                r2 = self.w.reshape(32, 8)  # (32, 8)
                return x @ r1_t + y @ r1_t + p @ r2 + q @ r2

        module = ChainedSharedConstants().cuda().eval()
        args = tuple(torch.randn(2, 32).cuda() for _ in range(4))
        gm = torch.export.export(module, args).module()

        known_constants = _compute_constant_nodes(gm, _get_impure_targets())
        candidates = []
        for n in list(gm.graph.nodes):
            if n in known_constants and len(n.users) > 1:
                candidates.append(n)
        # Sanity: the fixture must contain a nested chain of multi-user
        # constants for this regression to be meaningful.
        self.assertGreaterEqual(
            len(candidates),
            2,
            msg=f"Test fixture has no nested multi-user constants: {candidates}",
        )

        for shared_node in candidates:
            extra_users = list(shared_node.users.keys())[1:]
            for consumer in extra_users:
                private_root = _clone_constant_subgraph(
                    gm, shared_node, consumer, known_constants, {}
                )
                consumer.replace_input_with(shared_node, private_root)

        # Re-classify and verify no constant node ended up multi-user. Without
        # the fix, an outer-candidate clone (e.g. ``w_dup0``) is reused as-is
        # by an inner candidate's duplication and ends up with 2 users.
        reclassified = _compute_constant_nodes(gm, _get_impure_targets())
        leftovers = [n for n in reclassified if len(n.users) > 1]
        self.assertEqual(
            len(leftovers),
            0,
            msg=(
                "Duplication step left these constants multi-user "
                f"(should be impossible): {[(n.name, len(n.users)) for n in leftovers]}"
            ),
        )

    def test_constant_duplication_many_consumers(self):
        """A constant with N > 2 consumers yields one private constant per consumer.

        After the pass every matmul must read its own frozen constant, each
        carrying the original's shape / dtype metadata.
        """
        from torch_tensorrt.dynamo._settings import CompilationSettings
        from torch_tensorrt.dynamo.lowering.passes.constant_duplication import (
            _compute_constant_nodes,
            _get_impure_targets,
            constant_duplication,
        )

        N_CONSUMERS = 5

        class ManyConsumer(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.weight = torch.nn.Parameter(torch.randn(8, 4))

            def forward(self, *xs):
                w = self.weight.reshape(4, 8).permute(1, 0)
                return sum(x @ w for x in xs)

        module = ManyConsumer().cuda().eval()
        args = tuple(torch.randn(3, 8).cuda() for _ in range(N_CONSUMERS))
        gm = torch.export.export(module, args).module()

        # Record the multi-user constant's shape and dtype before the pass so
        # the clones can be compared against it afterwards.
        classified = _compute_constant_nodes(gm, _get_impure_targets())
        shared = [n for n in classified if len(n.users) == N_CONSUMERS]
        self.assertEqual(
            len(shared),
            1,
            msg=f"Expected exactly one {N_CONSUMERS}-user constant, got {shared}",
        )
        reference_val = shared[0].meta["val"]
        original_shape = tuple(reference_val.shape)
        original_dtype = reference_val.dtype

        gm = constant_duplication(gm, CompilationSettings(constant_duplication=True))

        # After the pass + internal constant_fold, each consumer's matmul
        # should consume an independent frozen constant of the right shape.
        matmul_nodes = []
        for n in gm.graph.nodes:
            if n.op == "call_function" and n.target is torch.ops.aten.matmul.default:
                matmul_nodes.append(n)
        self.assertEqual(
            len(matmul_nodes),
            N_CONSUMERS,
            msg=f"Expected {N_CONSUMERS} matmuls, got {len(matmul_nodes)}",
        )

        seen_constants = set()
        for mm in matmul_nodes:
            const_input = mm.args[1]
            self.assertEqual(
                const_input.op,
                "get_attr",
                msg=(
                    f"Matmul {mm.name} should consume a get_attr after folding, "
                    f"got {const_input.op}={const_input.target}"
                ),
            )
            folded_val = const_input.meta["val"]
            self.assertEqual(tuple(folded_val.shape), original_shape)
            self.assertEqual(folded_val.dtype, original_dtype)
            self.assertNotIn(
                const_input.target,
                seen_constants,
                msg="Each matmul should consume its own private frozen constant",
            )
            seen_constants.add(const_input.target)

    def test_constant_duplication_disabled_is_noop(self):
        from torch_tensorrt.dynamo._settings import CompilationSettings
        from torch_tensorrt.dynamo.lowering.passes.constant_duplication import (
            constant_duplication,
        )

        module = self._make_shared_constant_module().cuda().eval()
        args = tuple(torch.randn(3, 8).cuda() for _ in range(2))
        gm = torch.export.export(module, args).module()

        # With the flag off the pass must hand the graph back untouched.
        node_count_before = len(list(gm.graph.nodes))
        gm = constant_duplication(gm, CompilationSettings(constant_duplication=False))
        node_count_after = len(list(gm.graph.nodes))
        self.assertEqual(node_count_before, node_count_after)


if __name__ == "__main__":
    run_tests()