Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 20 additions & 13 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import numpy as np
import torch
from tensorrt import ITensor as TRTTensor
from torch.fx.node import Argument, Node, Target
from torch_tensorrt import ENABLED_FEATURES
from torch_tensorrt._features import needs_not_tensorrt_rtx
Expand All @@ -28,6 +27,8 @@
)
from torch_tensorrt.dynamo.utils import DYNAMIC_DIM

from tensorrt import ITensor as TRTTensor

_LOGGER: logging.Logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -597,8 +598,10 @@ def index_has_bool_indices(
# case and is checked first via HIGH priority.
@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor,
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
and not index_has_bool_indices(node, settings),
capability_validator=lambda node, settings: (
index_dtype_validator(node, settings)
and not index_has_bool_indices(node, settings)
),
priority=ConverterPriority.HIGH,
supports_dynamic_shapes=True,
requires_output_allocator=False,
Expand Down Expand Up @@ -629,9 +632,11 @@ def aten_ops_index(
# output shapes, so an output allocator is required.
@dynamo_tensorrt_converter(
torch.ops.aten.index.Tensor,
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings)
and index_has_bool_indices(node, settings),
capability_validator=lambda node, settings: (
index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings)
and index_has_bool_indices(node, settings)
),
supports_dynamic_shapes=True,
requires_output_allocator=True,
)
Expand Down Expand Up @@ -1154,9 +1159,11 @@ def aten_ops_index_put_accumulate(

@dynamo_tensorrt_converter(
torch.ops.aten.index_put.default,
capability_validator=lambda node, settings: index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings)
and not args_bounds_check(node.args, 3, False),
capability_validator=lambda node, settings: (
index_dtype_validator(node, settings)
and index_nonbool_validator(node, settings)
and not args_bounds_check(node.args, 3, False)
),
supports_dynamic_shapes=True,
)
@enforce_tensor_types(
Expand Down Expand Up @@ -1368,8 +1375,8 @@ def validator(

@dynamo_tensorrt_converter(
torch.ops.aten.clone.default,
capability_validator=lambda node, settings: not is_only_operator_on_placeholder(
node, settings
capability_validator=lambda node, settings: (
not is_only_operator_on_placeholder(node, settings)
),
supports_dynamic_shapes=True,
)
Expand Down Expand Up @@ -3578,14 +3585,14 @@ def aten_ops_copy(
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
src = args[1]
dest, src = args[0], args[1]
return impl.cast.to_copy(
ctx,
target,
SourceIR.ATEN,
name,
src,
src.dtype,
dest.dtype,
force_layer=True,
)

Expand Down
15 changes: 11 additions & 4 deletions py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import warnings
from typing import Any, Callable, Optional, Union

import tensorrt as trt
import torch
from torch.fx.node import Target
from torch_tensorrt import _enums
Expand All @@ -20,6 +19,8 @@
)
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor

import tensorrt as trt

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -133,10 +134,16 @@ def convert_binary_elementwise(
lhs_val = get_trt_tensor(ctx, lhs_val, f"{name}_lhs", lhs_dtype)
rhs_val = get_trt_tensor(ctx, rhs_val, f"{name}_rhs", rhs_dtype)

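# Match PyTorch's actual elementwise dtype rules, which treat 0-dim tensors
# as weak operands: a 0-dim tensor does not raise the result dtype above that
# of an Nd operand in the same dtype category
# (0-dim fp32 * Nd fp16 -> fp16, not fp32).
# torch.promote_types ignores dimensionality (the strong rule), so it gives
# the wrong answer whenever a 0-dim operand has a higher-precision dtype than
# an Nd operand of the same category.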
lhs_torch_dtype = _enums.dtype._from(lhs_val.dtype).to(torch.dtype)
rhs_torch_dtype = _enums.dtype._from(rhs_val.dtype).to(torch.dtype)
promoted_type = _enums.dtype._from(
torch.promote_types(
_enums.dtype._from(lhs_val.dtype).to(torch.dtype),
_enums.dtype._from(rhs_val.dtype).to(torch.dtype),
torch.result_type(
torch.empty([1] * len(lhs_val.shape), dtype=lhs_torch_dtype),
torch.empty([1] * len(rhs_val.shape), dtype=rhs_torch_dtype),
)
)
trt_promoted_type = promoted_type.to(trt.DataType)
Expand Down
19 changes: 19 additions & 0 deletions tests/py/dynamo/conversion/test_binary_ops_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,25 @@ def forward(self, x):
inputs = [torch.randn(2, 2, dtype=torch.bfloat16)]
self.run_test(m, inputs)

def test_elementwise_mul_zerodim_fp32_against_fp16_tensor(self):
    # Regression for pytorch/TensorRT#4265.
    #
    # Multiplying a 0-dim fp32 scalar (e.g. an nn.Parameter holding
    # `torch.tensor(1.0)`) with an Nd fp16 tensor must yield fp16, per
    # PyTorch's weak-ZeroDim rule, rather than the fp32 that
    # torch.promote_types reports. The earlier strong-promotion path pushed
    # fp32 into downstream ops and broke the next type-strict layer (MatMul
    # in this test) with `A=Float, B=Half`.
    class ZeroDimMulThenMatmul(nn.Module):
        def forward(self, alpha_0d_fp32, x_fp16, weight_fp16):
            # Scale first, then feed the product into the type-strict matmul.
            return torch.matmul(alpha_0d_fp32 * x_fp16, weight_fp16)

    # Tensors are created in the same order as the positional inputs.
    alpha = torch.tensor(1.0, dtype=torch.float32)
    activations = torch.randn(2, 4, 8, dtype=torch.float16)
    weight = torch.randn(8, 8, dtype=torch.float16)

    module = ZeroDimMulThenMatmul()
    self.run_test(
        module,
        [alpha, activations, weight],
        use_dynamo_tracer=True,
    )


if __name__ == "__main__":
run_tests()
24 changes: 24 additions & 0 deletions tests/py/dynamo/conversion/test_copy_aten.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,30 @@ def forward(self, input, src):
input_specs,
)

def test_copy_mixed_dtype_dest_fp32_src_fp16(self):
    # Regression for pytorch/TensorRT#4265.
    #
    # aten.copy(self, src) returns a tensor carrying self's dtype; src is
    # cast implicitly. The earlier converter cast to src.dtype instead, so an
    # fp16 update reached the downstream fp32 scatter (produced by lowering
    # slice-assignment) and TRT IScatterLayer rejected the build with
    # `input Float / updates Half`.
    class IndexAssignMixedDtype(nn.Module):
        def forward(self, dest_fp32, src_fp16):
            # Slice-assign the fp16 source into a copy of the fp32 destination.
            result = dest_fp32.clone()
            result[..., :2] = src_fp16
            return result

    # Tensors are created in the same order as the positional inputs.
    dest = torch.randn(2, 3, 5, dtype=torch.float32)
    src = torch.randn(2, 3, 2, dtype=torch.float16)

    run_options = {"use_dynamo_tracer": True, "enable_passes": True}
    self.run_test(IndexAssignMixedDtype(), [dest, src], **run_options)


if __name__ == "__main__":
run_tests()
Loading
Loading