Skip to content

🐛 [Bug] Exporting mixed precision copy and linear layer errors #4265

@mattangus

Description

@mattangus

Bug Description

There seem to be some edge cases when converting mixed-precision (fp16/fp32) graphs to TensorRT.

To Reproduce

The following code reproduces the bug. Both eager mode and the exported graph (`torch.export`) can run a forward pass, but the TensorRT compilation fails with dtype-mismatch errors:

code
import torch
import torch.nn as nn
import traceback
import torch_tensorrt


class ScatterCopyDtypeRepro(nn.Module):
    """Repro: fp16 values written into an fp32 tensor via indexed assignment.

    Eager PyTorch casts the fp16 right-hand side to the fp32 destination.
    After ``torch.export`` the assignment shows up as an ``aten.copy``
    (fp32 destination, fp16 source) feeding an ``aten.scatter``. TensorRT has
    been observed to reject that scatter because its ``input`` is Float while
    its ``updates`` are Half.
    """

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        """Rebuild a 4x4 block from slices of *inp*, computing one part in fp16.

        Writes ``B^T`` into the top-left 3x3, ``-(B^T @ A)`` (computed in fp16)
        into the top three rows of the last column, and 1.0 into the
        bottom-right corner; every other entry stays 0. NOTE(review): this
        matches the closed-form inverse of a rigid transform ``[R | t]`` only
        when the 3x3 block is orthonormal, which is not checked here.

        Args:
            inp: fp32 input, assumed ``[B, N, 4, 4]``.

        Returns:
            fp32 tensor of shape ``inp.shape[:-2] + (4, 4)``.
        """
        # inp: [B, N, 4, 4], fp32 input

        # A: first three rows of the last column; B: top-left 3x3 block.
        A = inp[..., :3, 3]
        B = inp[..., :3, :3]

        # Force the matmul branch to fp16, mimicking autocast-created fp16 matmul.
        B_h = B.transpose(-2, -1).to(torch.float16)
        T = A.unsqueeze(-1).to(torch.float16)

        T_inv = -(B_h @ T).squeeze(-1)  # fp16
        B_inv = B.transpose(-2, -1)  # fp32

        out = torch.zeros(
            B.shape[:-2] + (4, 4),
            device=inp.device,
            dtype=torch.float32,
        )

        out[..., 3, 3] = 1.0
        # Same-dtype (fp32 -> fp32) assignment; this path converts without issue.
        out[..., :3, :3] = B_inv

        # This is the important line.
        # PyTorch semantics: assignment/copy into an fp32 destination casts to fp32.
        # Problem seen in TRT lowering: the scatter's input is fp32 but its
        # updates arrive as fp16, i.e. the copy's implicit cast is not applied.
        out[..., :3, 3] = T_inv

        return out


class LinearAfterFp32ScalarElementwiseRepro(nn.Module):
    """Repro: fp16 gated MLP whose gate mixes in 0-dim fp32 parameters.

    ``alpha`` and ``x2_bias`` are 0-dim fp32 parameters combined with fp16
    activations. In eager PyTorch, and in the exported graph's dtype metadata,
    those results stay fp16. TensorRT has been observed to see the input of the
    second linear as Float while ``fc2_weight`` is Half, and the matrix
    multiply layer then rejects the mismatched input types.
    """

    def __init__(self, hidden: int = 768, mlp_hidden: int = 2048) -> None:
        """Create the LayerNorm, two fp16 projection weights and two fp32 scalars.

        Args:
            hidden: Size of the last dimension of the input and of the output.
            mlp_hidden: Width of each of the two halves produced by the first
                projection.
        """
        super().__init__()

        # Keep LayerNorm weights fp32, as in many mixed-precision transformer blocks.
        self.ln = nn.LayerNorm(hidden, elementwise_affine=True)

        # Use fp16 linear weights to mimic autocast/frozen fp16 weights in TRT.
        # fc1 maps hidden -> 2 * mlp_hidden (split in half in forward);
        # fc2 maps mlp_hidden -> hidden. Both are random with std 0.01.
        self.fc1_weight = nn.Parameter(torch.randn(2 * mlp_hidden, hidden, dtype=torch.float16) * 0.01)
        self.fc2_weight = nn.Parameter(torch.randn(hidden, mlp_hidden, dtype=torch.float16) * 0.01)

        # Scalar (0-dim) fp32 parameters. These are the important mixed-dtype params.
        self.alpha = nn.Parameter(torch.tensor(1.0, dtype=torch.float32))
        self.x2_bias = nn.Parameter(torch.tensor(0.0, dtype=torch.float32))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """LayerNorm in fp32, a clamped SwiGLU-like MLP in fp16, fp32 residual add.

        Args:
            x: Input assumed to be ``[B, S, hidden]`` and fp32.

        Returns:
            ``residual + y``. With an fp32 ``x`` and an fp16 ``y``, eager
            PyTorch promotes the sum to fp32.
        """
        # x: [B, S, H], fp32 input

        residual = x
        x = self.ln(x)                 # fp32
        x = x.to(torch.float16)        # explicit fp16 MLP compute

        x = torch.nn.functional.linear(x, self.fc1_weight)  # fp16, last dim 2 * mlp_hidden
        x1, x2 = torch.split(x, x.shape[-1] // 2, dim=-1)

        x1 = torch.clamp(x1, max=7.0)
        x2 = torch.clamp(x2, min=-7.0, max=7.0)

        # Important mixed dtype scalar ops:
        # alpha and x2_bias are 0-dim fp32 tensors, x1/x2 are fp16 tensors.
        # A 0-dim operand does not up-promote a same-category dimensioned tensor,
        # so PyTorch keeps these results fp16. A converter that promotes them to
        # fp32 instead produces the Float input seen by the linear below.
        gated = x1 * torch.sigmoid(self.alpha * x1)
        gated = gated * (x2 + self.x2_bias)

        # This is the line that can fail during TRT build:
        # TensorRT may see gated as Float while fc2_weight is Half.
        y = torch.nn.functional.linear(gated, self.fc2_weight)

        # The following residual add is also mixed fp32/fp16, but the observed error
        # can occur earlier inside the linear/matmul above.
        return residual + y


def compile_with_torch_tensorrt(module: nn.Module, example_inputs, **compile_kwargs):
    """Export *module* with ``torch.export`` and compile it with Torch-TensorRT.

    The exported program is run once before compiling. If it succeeds and the
    compile fails, the failure lies in the TensorRT conversion/build rather
    than in ``torch.export``.

    Args:
        module: Module to compile. It is switched to eval mode and moved to
            CUDA in place.
        example_inputs: Tuple of example inputs. They must live on the CUDA
            device, because they are fed to the CUDA module as-is.
        **compile_kwargs: Extra keyword arguments forwarded to
            ``torch_tensorrt.dynamo.compile`` (for example
            ``enabled_precisions`` or ``use_explicit_typing``). ``debug``
            defaults to ``True`` unless given here.

    Returns:
        Whatever ``torch_tensorrt.dynamo.compile`` returns.
    """
    module = module.eval().cuda()

    with torch.no_grad():
        exported = torch.export.export(module, example_inputs)
        graph_output = exported.module()(*example_inputs)
        print("EP output dtype:", graph_output.dtype, "shape:", tuple(graph_output.shape))

    # The repro models cast to fp16 explicitly, so the network is already mixed
    # precision. With the default use_explicit_typing=True (see the logged
    # CompilationSettings), TensorRT is expected to keep the graph's dtypes, so
    # enabled_precisions is not what introduces fp16 here. Pass it through
    # compile_kwargs to experiment.
    compile_kwargs.setdefault("debug", True)
    return torch_tensorrt.dynamo.compile(
        exported,
        inputs=example_inputs,
        **compile_kwargs,
    )


def run_scatter_repro():
    """Run ScatterCopyDtypeRepro eagerly, then through Torch-TensorRT.

    The input is a (1, 5, 4, 4) fp32 CUDA tensor holding five 4x4 identity
    matrices. The eager pass is expected to succeed; the TensorRT compile is
    the step under test.
    """
    print("\n=== Scatter/copy dtype repro ===")

    repro_model = ScatterCopyDtypeRepro()

    identity = torch.eye(4, dtype=torch.float32, device="cuda")
    batch = identity.reshape(1, 1, 4, 4).expand(1, 5, 4, 4).contiguous()

    # Eager reference pass.
    with torch.no_grad():
        reference = repro_model.cuda()(batch)
    print("Eager output dtype:", reference.dtype, "shape:", tuple(reference.shape))

    # Observed failure: IScatterLayer rejects a Float `input` with Half `updates`.
    compile_with_torch_tensorrt(repro_model, (batch,))


def run_linear_repro():
    """Run LinearAfterFp32ScalarElementwiseRepro eagerly, then through Torch-TensorRT.

    The input is a random fp32 CUDA tensor of shape (30, 1680, 768). The eager
    pass is expected to succeed; the TensorRT compile is the step under test.
    """
    print("\n=== Linear after fp32 scalar elementwise repro ===")

    hidden = 768
    repro_model = LinearAfterFp32ScalarElementwiseRepro(hidden=hidden, mlp_hidden=2048)

    batch_size, seq_len = 30, 1680
    tokens = torch.randn(batch_size, seq_len, hidden, dtype=torch.float32, device="cuda")

    # Eager reference pass.
    with torch.no_grad():
        reference = repro_model.cuda()(tokens)
    print("Eager output dtype:", reference.dtype, "shape:", tuple(reference.shape))

    # Observed failure: IMatrixMultiplyLayer gets A as Float and B as Half.
    compile_with_torch_tensorrt(repro_model, (tokens,))


torch.manual_seed(0)

# Run each repro in turn. A failure is reported with its traceback instead of
# being raised, so a failure in the first repro does not stop the second.
for _failure_banner, _repro in (
    ("\nScatter repro exception", run_scatter_repro),
    ("\nLinear repro exception", run_linear_repro),
):
    try:
        _repro()
    except Exception:
        print(_failure_banner)
        traceback.print_exc()
logs
11:55:18 - INFO - Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=True, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False, tiling_optimization_level='none', l2_limit_for_tiling=-1, use_distributed_mode_trace=False, offload_module_to_cpu=False, enable_autocast=False, autocast_low_precision_type=None, autocast_excluded_nodes=set(), autocast_excluded_ops=set(), autocast_max_output_threshold=512, autocast_max_depth_of_reduction=None, autocast_calibration_dataloader=None, enable_resource_partitioning=False, cpu_memory_budget=None, dynamically_allocate_resources=False, decompose_attention=False, attn_bias_is_causal=True)


=== Scatter/copy dtype repro ===
Eager output dtype: torch.float32 shape: (1, 5, 4, 4)
EP output dtype: torch.float32 shape: (1, 5, 4, 4)
11:55:19 - INFO - Partitioning the graph via the fast partitioner
11:55:19 - WARNING - WARNING The logger passed into createInferBuilder differs from one already registered for an existing builder, runtime, or refitter. So the current new logger is ignored, and TensorRT will use the existing one which is returned by nvinfer1::getLogger() instead.
11:55:19 - INFO - Converted node inp [inp] (Inputs: () | Outputs: (inp: (1, 5, 4, 4)@torch.float32))
11:55:19 - INFO - Converted node [/slice_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/slice_1) [aten.slice.Tensor] (Inputs: (inp: (1, 5, 4, 4)@torch.float32, 2, 0, 3) | Outputs: (slice_1: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/select](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/select) [aten.select.int] (Inputs: (slice_1: (1, 5, 3, 4)@torch.float32, 3, 3) | Outputs: (select: (1, 5, 3)@torch.float32))
11:55:19 - INFO - Converted node [/slice_2](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/slice_2) [aten.slice.Tensor] (Inputs: (inp: (1, 5, 4, 4)@torch.float32, 2, 0, 3) | Outputs: (slice_2: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/slice_3](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/slice_3) [aten.slice.Tensor] (Inputs: (slice_2: (1, 5, 3, 4)@torch.float32, 3, 0, 3) | Outputs: (slice_3: (1, 5, 3, 3)@torch.float32))
11:55:19 - INFO - Converted node [/permute](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/permute) [aten.permute.default] (Inputs: (slice_3: (1, 5, 3, 3)@torch.float32, [0, 1, 3, 2]) | Outputs: (permute: (1, 5, 3, 3)@torch.float32))
11:55:19 - INFO - Converted node [/_to_copy](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_to_copy) [aten._to_copy.default] (Inputs: (permute: (1, 5, 3, 3)@torch.float32) | Outputs: (_to_copy: (1, 5, 3, 3)@torch.float16))
11:55:19 - INFO - Converted node [/unsqueeze](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/unsqueeze) [aten.unsqueeze.default] (Inputs: (select: (1, 5, 3)@torch.float32, -1) | Outputs: (unsqueeze: (1, 5, 3, 1)@torch.float32))
11:55:19 - INFO - Converted node [/_to_copy_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_to_copy_1) [aten._to_copy.default] (Inputs: (unsqueeze: (1, 5, 3, 1)@torch.float32) | Outputs: (_to_copy_1: (1, 5, 3, 1)@torch.float16))
11:55:19 - INFO - Converted node [/matmul](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/matmul) [aten.matmul.default] (Inputs: (_to_copy: (1, 5, 3, 3)@torch.float16, _to_copy_1: (1, 5, 3, 1)@torch.float16) | Outputs: (matmul: (1, 5, 3, 1)@torch.float16))
11:55:19 - INFO - Converted node [/squeeze](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/squeeze) [aten.squeeze.dim] (Inputs: (matmul: (1, 5, 3, 1)@torch.float16, -1) | Outputs: (squeeze: (1, 5, 3)@torch.float16))
11:55:19 - INFO - Converted node [/neg](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/neg) [aten.neg.default] (Inputs: (squeeze: (1, 5, 3)@torch.float16) | Outputs: (neg: (1, 5, 3)@torch.float16))
11:55:19 - INFO - Converted node [/permute_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/permute_1) [aten.permute.default] (Inputs: (slice_3: (1, 5, 3, 3)@torch.float32, [0, 1, 3, 2]) | Outputs: (permute_1: (1, 5, 3, 3)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param1) [_frozen_param1] (Inputs: () | Outputs: (_frozen_param1: (1, 5, 3, 3)@torch.float32))
11:55:19 - INFO - Converted node [/copy_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/copy_1) [aten.copy.default] (Inputs: (_frozen_param1: (1, 5, 3, 3)@torch.float32, permute_1: (1, 5, 3, 3)@torch.float32) | Outputs: (copy_1: (1, 5, 3, 3)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param2](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param2) [_frozen_param2] (Inputs: () | Outputs: (_frozen_param2: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param3](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param3) [_frozen_param3] (Inputs: () | Outputs: (_frozen_param3: (1, 5, 3, 3)@torch.int64))
11:55:19 - INFO - Converted node [/scatter_2](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/scatter_2) [aten.scatter.src] (Inputs: (_frozen_param2: (1, 5, 3, 4)@torch.float32, 3, _frozen_param3: (1, 5, 3, 3)@torch.int64, copy_1: (1, 5, 3, 3)@torch.float32) | Outputs: (scatter_2: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param0](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param0) [_frozen_param0] (Inputs: () | Outputs: (_frozen_param0: (1, 5, 4, 4)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param4](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param4) [_frozen_param4] (Inputs: () | Outputs: (_frozen_param4: (1, 5, 3, 4)@torch.int64))
11:55:19 - INFO - Converted node [/scatter_3](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/scatter_3) [aten.scatter.src] (Inputs: (_frozen_param0: (1, 5, 4, 4)@torch.float32, 2, _frozen_param4: (1, 5, 3, 4)@torch.int64, scatter_2: (1, 5, 3, 4)@torch.float32) | Outputs: (scatter_3: (1, 5, 4, 4)@torch.float32))
11:55:19 - INFO - Converted node [/slice_14](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/slice_14) [aten.slice.Tensor] (Inputs: (scatter_3: (1, 5, 4, 4)@torch.float32, 2, 0, 3) | Outputs: (slice_14: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/select_7](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/select_7) [aten.select.int] (Inputs: (slice_14: (1, 5, 3, 4)@torch.float32, 3, 3) | Outputs: (select_7: (1, 5, 3)@torch.float32))
11:55:19 - INFO - Converted node /copy_2 [aten.copy.default] (Inputs: (select_7: (1, 5, 3)@torch.float32, neg: (1, 5, 3)@torch.float16) | Outputs: (copy_2: (1, 5, 3)@torch.float32))
11:55:19 - INFO - Converted node [/slice_15](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/slice_15) [aten.slice.Tensor] (Inputs: (scatter_3: (1, 5, 4, 4)@torch.float32, 2, 0, 3) | Outputs: (slice_15: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/unsqueeze_3](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/unsqueeze_3) [aten.unsqueeze.default] (Inputs: (copy_2: (1, 5, 3)@torch.float32, 3) | Outputs: (unsqueeze_3: (1, 5, 3, 1)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param5](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param5) [_frozen_param5] (Inputs: () | Outputs: (_frozen_param5: (1, 5, 3, 1)@torch.int64))
11:55:19 - INFO - Converted node /scatter_4 [aten.scatter.src] (Inputs: (slice_15: (1, 5, 3, 4)@torch.float32, 3, _frozen_param5: (1, 5, 3, 1)@torch.int64, unsqueeze_3: (1, 5, 3, 1)@torch.float32) | Outputs: (scatter_4: (1, 5, 3, 4)@torch.float32))
11:55:19 - INFO - Converted node [/_frozen_param6](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_frozen_param6) [_frozen_param6] (Inputs: () | Outputs: (_frozen_param6: (1, 5, 3, 4)@torch.int64))
11:55:19 - INFO - Converted node [/scatter_5](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/scatter_5) [aten.scatter.src] (Inputs: (scatter_3: (1, 5, 4, 4)@torch.float32, 2, _frozen_param6: (1, 5, 3, 4)@torch.int64, scatter_4: (1, 5, 3, 4)@torch.float32) | Outputs: (scatter_5: (1, 5, 4, 4)@torch.float32))
11:55:19 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error ([SCATTER]-[aten_ops.scatter.src]-[/scatter_4_scatter_layer]: IScatterLayer `input` and `updates` must have identical types. `input` type is Float and `updates` type is Half. In validateTypes at /_src/optimizer/common/nodes/scatterNode.cpp:100)
11:55:19 - INFO - Converted node output [output] (Inputs: (scatter_5: (1, 5, 4, 4)@torch.float32) | Outputs: (output: ))
11:55:19 - INFO - TRT INetwork construction elapsed time: 0:00:00.025301
11:55:19 - INFO - Not found cached TRT engines. Start building engine.
11:55:19 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error (Output shape can not be computed for node [SCATTER]-[aten_ops.scatter.src]-[/scatter_4_scatter_layer]. In needTypeAndDimensions at /_src/optimizer/shapeof/graphShapeAnalyzer.cpp:2993)
11:55:19 - ERROR - IBuilder::buildEngineWithConfig: Error Code 4: API Usage Error ([SCATTER]-[aten_ops.scatter.src]-[/scatter_4_scatter_layer]: IScatterLayer `input` and `updates` must have identical types. `input` type is Float and `updates` type is Half. In validateTypes at /_src/optimizer/common/nodes/scatterNode.cpp:100)
Traceback ...
    assert cuda_engine
AssertionError
11:55:19 - INFO - Compilation Settings: CompilationSettings(enabled_precisions={<dtype.f32: 7>}, workspace_size=0, min_block_size=5, torch_executed_ops=set(), pass_through_build_failures=False, max_aux_streams=None, version_compatible=False, optimization_level=None, use_python_runtime=False, truncate_double=False, use_fast_partitioner=True, enable_experimental_decompositions=False, device=Device(type=DeviceType.GPU, gpu_id=0), require_full_compilation=False, disable_tf32=False, assume_dynamic_shape_support=False, sparse_weights=False, engine_capability=<EngineCapability.STANDARD: 1>, num_avg_timing_iters=1, dla_sram_size=1048576, dla_local_dram_size=1073741824, dla_global_dram_size=536870912, dryrun=False, hardware_compatible=False, timing_cache_path='/tmp/torch_tensorrt_engine_cache/timing_cache.bin', lazy_engine_init=False, cache_built_engines=False, reuse_cached_engines=False, use_explicit_typing=True, use_fp32_acc=False, refit_identical_engine_weights=False, strip_engine_weights=False, immutable_weights=True, enable_weight_streaming=False, enable_cross_compile_for_windows=False, tiling_optimization_level='none', l2_limit_for_tiling=-1, use_distributed_mode_trace=False, offload_module_to_cpu=False, enable_autocast=False, autocast_low_precision_type=None, autocast_excluded_nodes=set(), autocast_excluded_ops=set(), autocast_max_output_threshold=512, autocast_max_depth_of_reduction=None, autocast_calibration_dataloader=None, enable_resource_partitioning=False, cpu_memory_budget=None, dynamically_allocate_resources=False, decompose_attention=False, attn_bias_is_causal=True)


Scatter repro exception

=== Linear after fp32 scalar elementwise repro ===
Eager output dtype: torch.float32 shape: (30, 1680, 768)
EP output dtype: torch.float32 shape: (30, 1680, 768)
11:55:19 - INFO - Partitioning the graph via the fast partitioner
11:55:19 - WARNING - WARNING The logger passed into createInferBuilder differs from one already registered for an existing builder, runtime, or refitter. So the current new logger is ignored, and TensorRT will use the existing one which is returned by nvinfer1::getLogger() instead.
11:55:19 - INFO - Converted node x [x] (Inputs: () | Outputs: (x: (30, 1680, 768)@torch.float32))
11:55:19 - INFO - Converted node ln_weight [ln.weight] (Inputs: () | Outputs: (ln_weight: (768,)@torch.float32))
11:55:19 - INFO - Converted node ln_bias [ln.bias] (Inputs: () | Outputs: (ln_bias: (768,)@torch.float32))
11:55:19 - INFO - Converted node ln/native_layer_norm [aten.native_layer_norm.default] (Inputs: (x: (30, 1680, 768)@torch.float32, [768], ln_weight: (768,)@torch.float32, ln_bias: (768,)@torch.float32, 1e-05) | Outputs: (native_layer_norm: ((30, 1680, 768)@torch.float32, (30, 1680, 1)@torch.float32, (30, 1680, 1)@torch.float32)))
11:55:19 - INFO - Converted node ln/getitem [<built-in function getitem>] (Inputs: (native_layer_norm: ((30, 1680, 768)@torch.float32, (30, 1680, 1)@torch.float32, (30, 1680, 1)@torch.float32), 0) | Outputs: (getitem: (30, 1680, 768)@torch.float32))
11:55:19 - INFO - Converted node [/_to_copy](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/_to_copy) [aten._to_copy.default] (Inputs: (getitem: (30, 1680, 768)@torch.float32) | Outputs: (_to_copy: (30, 1680, 768)@torch.float16))
11:55:19 - INFO - Converted node fc1_weight [fc1_weight] (Inputs: () | Outputs: (fc1_weight: (4096, 768)@torch.float16))
11:55:19 - INFO - Converted node [/linear](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/linear) [aten.linear.default] (Inputs: (_to_copy: (30, 1680, 768)@torch.float16, fc1_weight: (4096, 768)@torch.float16) | Outputs: (linear: (30, 1680, 4096)@torch.float16))
11:55:19 - INFO - Converted node [/split](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/split) [aten.split.Tensor] (Inputs: (linear: (30, 1680, 4096)@torch.float16, 2048, -1) | Outputs: (split: ((30, 1680, 2048)@torch.float16, (30, 1680, 2048)@torch.float16)))
11:55:19 - INFO - Converted node [/getitem_3](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/getitem_3) [<built-in function getitem>] (Inputs: (split: ((30, 1680, 2048)@torch.float16, (30, 1680, 2048)@torch.float16), 0) | Outputs: (getitem_3: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node [/getitem_4](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/getitem_4) [<built-in function getitem>] (Inputs: (split: ((30, 1680, 2048)@torch.float16, (30, 1680, 2048)@torch.float16), 1) | Outputs: (getitem_4: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node [/clamp](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/clamp) [aten.clamp.default] (Inputs: (getitem_3: (30, 1680, 2048)@torch.float16, None, 7.0) | Outputs: (clamp: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node [/clamp_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/clamp_1) [aten.clamp.default] (Inputs: (getitem_4: (30, 1680, 2048)@torch.float16, -7.0, 7.0) | Outputs: (clamp_1: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node alpha [alpha] (Inputs: () | Outputs: (alpha: ()@torch.float32))
11:55:19 - INFO - Converted node /mul [aten.mul.Tensor] (Inputs: (alpha: ()@torch.float32, clamp: (30, 1680, 2048)@torch.float16) | Outputs: (mul: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node [/sigmoid](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/sigmoid) [aten.sigmoid.default] (Inputs: (mul: (30, 1680, 2048)@torch.float16) | Outputs: (sigmoid: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - skip broadcast for [/mul_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/mul_1)
11:55:19 - INFO - Converted node [/mul_1](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/mul_1) [aten.mul.Tensor] (Inputs: (clamp: (30, 1680, 2048)@torch.float16, sigmoid: (30, 1680, 2048)@torch.float16) | Outputs: (mul_1: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node x2_bias [x2_bias] (Inputs: () | Outputs: (x2_bias: ()@torch.float32))
11:55:19 - INFO - Converted node /add [aten.add.Tensor] (Inputs: (clamp_1: (30, 1680, 2048)@torch.float16, x2_bias: ()@torch.float32) | Outputs: (add: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - skip broadcast for [/mul_2](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/mul_2)
11:55:19 - INFO - Converted node [/mul_2](https://vscode-remote+ssh-002dremote-002bcoder-002dvscode-002ecoder-002eaks-002dprod-002dcoder-002dswedencentral-002eazr-002ewayve-002eai-002d-002dmattangus-002d-002dmanx-002eworkspace.vscode-resource.vscode-cdn.net/mul_2) [aten.mul.Tensor] (Inputs: (mul_1: (30, 1680, 2048)@torch.float16, add: (30, 1680, 2048)@torch.float16) | Outputs: (mul_2: (30, 1680, 2048)@torch.float16))
11:55:19 - INFO - Converted node fc2_weight [fc2_weight] (Inputs: () | Outputs: (fc2_weight: (768, 2048)@torch.float16))
11:55:19 - INFO - Converted node /linear_1 [aten.linear.default] (Inputs: (mul_2: (30, 1680, 2048)@torch.float16, fc2_weight: (768, 2048)@torch.float16) | Outputs: (linear_1: (30, 1680, 768)@torch.float16))
11:55:19 - ERROR - ITensor::getDimensions: Error Code 4: API Usage Error ([MATRIX_MULTIPLY]-[aten_ops.linear.default]-[/linear_1_matrix_multiply]: IMatrixMultiplyLayer must have same input types. `A` is of type Float and `B` is of type Half. In validateTypes at /_src/optimizer/common/nodes/matrixMultiplyNode.cpp:37)

Linear repro exception
Traceback ...
ValueError: __len__() should return >= 0

While executing %add_1 : [num_users=1] = call_function[target=torch.ops.aten.add.Tensor](args = (%x, %linear_1), kwargs = {})
Original traceback:
File "/tmp/ipykernel_143457/3067791630.py", line 79, in forward
    return residual + y
Use tlparse to see full graph. (https://github.com/pytorch/tlparse?tab=readme-ov-file#tlparse-parse-structured-pt2-logs)

Expected behavior

Since both the eager PyTorch model and the exported graph run correctly, I would expect the TensorRT-compiled version to work too.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 2.11.0+cu128
  • PyTorch Version (e.g. 1.0): 2.11.0
  • CPU Architecture: x86
  • OS (e.g., Linux): Ubuntu 24.04
  • How you installed PyTorch (conda, pip, libtorch, source): pip from wheel
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives: https://download.pytorch.org/whl/cu128/torch_tensorrt-2.11.0%2Bcu128-cp310-cp310-manylinux_2_28_x86_64.whl
  • Python version: 3.10
  • CUDA version: 12.8
  • GPU models and configuration: A100
  • Any other relevant information:

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working)

    Type

    No type
    No fields configured for issues without a type.

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions