From d0081f0b57e63a296da3169e28938d066186a4cd Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Wed, 8 Apr 2026 03:21:19 +0000 Subject: [PATCH 1/4] Skip quantization kernels when tensor size is zero Signed-off-by: Tim Moon --- tests/cpp/operator/test_act.cu | 7 +++++-- tests/cpp/operator/test_cast.cu | 4 +++- tests/cpp/operator/test_cast_gated_swiglu.cu | 4 +++- tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu | 2 ++ transformer_engine/common/cast/fp8/gated_fp8.cuh | 3 +++ transformer_engine/common/cast/fp8/quantize_fp8.cuh | 13 +++++++++++++ .../common/cast/mxfp8/dequantize_mxfp8.cuh | 2 ++ .../common/cast/mxfp8/gated_mxfp8.cuh | 11 +++++++---- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 9 +++++++++ .../common/cast/mxfp8/quantize_mxfp8.cuh | 8 ++++++++ .../common/cast/nvfp4/dequantize_nvfp4.cuh | 1 + .../cast/nvfp4/group_quantize_transpose_nvfp4.cuh | 1 + .../common/cast/nvfp4/quantize_nvfp4.cuh | 1 + .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 1 + .../quantize_transpose_nvfp4_tuned_1D.cuh | 1 + 15 files changed, 60 insertions(+), 8 deletions(-) diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu index b4280818a8..602d50ac4d 100644 --- a/tests/cpp/operator/test_act.cu +++ b/tests/cpp/operator/test_act.cu @@ -193,7 +193,8 @@ void performTestGLU(const size_t N, const size_t H) { auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) { + if ((otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) + && N * H > 0) { auto [atol, rtol] = getTolerances(DType::kFloat32); compareResults("amax", output.amax(), ref_amax, atol, rtol); if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) { @@ -392,7 +393,9 @@ std::vector> act_test_cases = {{2048, 12288}, {65536, 128}, {256, 256}, {257, 259}, - {128, 128+1}}; + {128, 128+1}, + {0, 128}, + {128, 0}}; } // namespace diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu index 35d9dd2efd..8f566181a1 100644 --- a/tests/cpp/operator/test_cast.cu +++ b/tests/cpp/operator/test_cast.cu @@ -64,7 +64,7 @@ void performTest(const std::vector& shape) { cudaDeviceSynchronize(); auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - if (isFp8Type(otype)) { + if (isFp8Type(otype) && full_size > 0) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); float ref_scale_inv = 1.f / output_c.scale(); @@ -91,6 +91,8 @@ std::vector> test_cases = { {5, 160}, {5, 4, 3, 160}, {217, 256}, + {0, 128}, + {128, 0}, }; } // namespace diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu index 298b978f2a..6894ad9b1d 100644 --- a/tests/cpp/operator/test_cast_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_gated_swiglu.cu @@ -97,7 +97,7 @@ void performTest(const std::vector& shape) { rows, cols); - if (isFp8Type(otype)) { + if (isFp8Type(otype) && input_size > 0) { auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax); float ref_scale_inv = 1.f / output_c.scale(); @@ -118,6 +118,8 @@ std::vector> test_cases = { {217, 256}, {1296}, {5, 4, 3, 160}, + {0, 128}, + {128, 0}, }; } // namespace diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 3ff0e8ae99..aad32ba427 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -415,6 +415,8 @@ std::vector> matrix_sizes = { {768, 1024}, {8192, 128}, {577, 1632}, + {0, 128}, + {128, 0}, }; std::vector> block_sizes = { diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index 6123d7130b..83025a75ba 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -282,9 +282,12 @@ void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *outpu checkCuDriverContext(stream); NVTE_CHECK(!output->has_columnwise_data(), "Only rowwise cast supported in this function."); + + // Tensor dimensions const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + if (rows == 0 || cols == 0) { return; } const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index 96a42b494d..57e4a2d874 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -355,7 +355,10 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) template void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { using namespace quantize_1D_kernel; + + // Tensor size const size_t N = product(input.data.shape); + if (N == 0) { return; } const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); NVTE_CHECK(isFullTile, "Only full tiles are supported."); @@ -391,8 +394,18 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T using namespace quantize_2D_kernel; checkCuDriverContext(stream); + // Tensor dimensions const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + + // Skip kernel if tensor size is zero + if (rows == 0 || cols == 0) { + if constexpr (IS_DBIAS) { + NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); + } + return; + } + const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); const size_t blocks_Y = chunks_Y; diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index f8fecaa4e1..bacc9de485 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -249,6 +249,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { return; } + const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 49169a4e14..16159262ed 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -707,6 +707,11 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu using namespace gated_kernel; checkCuDriverContext(stream); + const size_t rows = gated_input.flat_first_dim(); + const size_t cols = gated_input.flat_last_dim() / 2; + const size_t output_cols = (IS_BWD ? 2 : 1) * cols; + if (rows == 0 || cols == 0) { return; } + const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data(); const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -725,12 +730,10 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu scaling_type = ScalingType::COLWISE; } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { scaling_type = ScalingType::BIDIMENSIONAL; + } else { + NVTE_ERROR("Missing both row-wise and column-wise data."); } - const size_t rows = gated_input.flat_first_dim(); - const size_t cols = gated_input.flat_last_dim() / 2; - const size_t output_cols = (IS_BWD ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index ce6917aa42..18fbfadd1a 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -870,6 +870,15 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations } } + // Skip kernel if tensor size is zero + if (elts_total == 0) { + if constexpr (IS_DBIAS) { + NVTE_ERROR("Invalid grouped tensor shape for DBias computation (first_logical_dim=", + first_logical_dim, ", last_logical_dim=", last_logical_dim, ")"); + } + return; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index f36b071081..f344ea7b79 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -579,6 +579,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + // Skip kernel if tensor size is zero + if (rows == 0 || cols == 0) { + if constexpr (IS_DBIAS) { + NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); + } + return; + } + // Tensor chunk handled by each CUDA block constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index ccdc4c93e3..a59dd657f4 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -87,6 +87,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); const size_t M = input.flat_last_dim(); + if (N == 0 || M == 0) { return; } NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index a2f3dac15a..65c0a35189 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -785,6 +785,7 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { return; } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index ec80924df5..ffe1f832b3 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -560,6 +560,7 @@ inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cu const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { return; } constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index f164636e38..b1801b0069 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1197,6 +1197,7 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { return; } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index fc337f6078..f6ebf7489c 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -704,6 +704,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { return; } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA From a61dfe2e3eb71ebccc1cede10e5186234cddeab1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 8 Apr 2026 03:48:22 +0000 Subject: [PATCH 2/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/fp8/gated_fp8.cuh | 4 +++- transformer_engine/common/cast/fp8/quantize_fp8.cuh | 4 +++- transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh | 4 +++- transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh | 4 +++- transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh | 4 +++- .../common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh | 4 +++- transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh | 4 +++- .../common/cast/nvfp4/quantize_transpose_nvfp4.cuh | 4 +++- .../nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh | 4 +++- 9 files changed, 27 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/cast/fp8/gated_fp8.cuh b/transformer_engine/common/cast/fp8/gated_fp8.cuh index 83025a75ba..116c6b33f5 100644 --- a/transformer_engine/common/cast/fp8/gated_fp8.cuh +++ b/transformer_engine/common/cast/fp8/gated_fp8.cuh @@ -287,7 +287,9 @@ void cast_gated_tma(const Tensor &gated_input, const Tensor &grad, Tensor *outpu const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_BWD ? 2 : 1) * cols; - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index 57e4a2d874..f7c8b4a874 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -358,7 +358,9 @@ void quantize_1D(const Tensor &input, Tensor *output, cudaStream_t stream) { // Tensor size const size_t N = product(input.data.shape); - if (N == 0) { return; } + if (N == 0) { + return; + } const bool isFullTile = (N % ELEMS_PER_BLOCK == 0); NVTE_CHECK(isFullTile, "Only full tiles are supported."); diff --git a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh index bacc9de485..3d5a0ad252 100644 --- a/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/dequantize_mxfp8.cuh @@ -249,7 +249,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } const size_t chunks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, CHUNK_DIM_X); diff --git a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh index 16159262ed..d25cce8318 100644 --- a/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/gated_mxfp8.cuh @@ -710,7 +710,9 @@ void quantize_gated(const Tensor &gated_input, const Tensor &grad, Tensor *outpu const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_BWD ? 2 : 1) * cols; - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data(); diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index a59dd657f4..19b22297c8 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -87,7 +87,9 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); const size_t M = input.flat_last_dim(); - if (N == 0 || M == 0) { return; } + if (N == 0 || M == 0) { + return; + } NVTE_CHECK(M % FP4_BLOCK_SIZE == 0, "Last dimension of FP4 tensors needs to be divisible by ", FP4_BLOCK_SIZE, ", but got ", input.data.shape, "."); diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index 65c0a35189..e289423913 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -785,7 +785,9 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh index ffe1f832b3..4d1f74af78 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4.cuh @@ -560,7 +560,9 @@ inline void quantize(const Tensor &input, const Tensor *noop, Tensor *output, cu const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index b1801b0069..c91c995a6b 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -1197,7 +1197,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index f6ebf7489c..91fdc12e52 100644 --- a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -704,7 +704,9 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { return; } + if (rows == 0 || cols == 0) { + return; + } NVTE_CHECK(rows % 32 == 0, "Number of tensor rows must be a multiple of 32"); // 16B alignment for TMA From ea6f858c70ff07881b558d9e28398edf692b53b6 Mon Sep 17 00:00:00 2001 From: Tim Moon Date: Thu, 9 Apr 2026 00:28:39 +0000 Subject: [PATCH 3/4] Use consistent early-termination logic in dbias kernels Signed-off-by: Tim Moon --- .../common/cast/fp8/quantize_fp8.cuh | 28 +++++++++++-------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 20 ++++++++----- .../common/cast/mxfp8/quantize_mxfp8.cuh | 24 +++++++++------- 3 files changed, 43 insertions(+), 29 deletions(-) diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index f7c8b4a874..e42cad05bb 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -400,29 +400,23 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - // Skip kernel if tensor size is zero - if (rows == 0 || cols == 0) { - if constexpr (IS_DBIAS) { - NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); - } - return; - } - const size_t chunks_Y = DIVUP(rows, FP8_CHUNK_DIM_Y); const size_t chunks_X = DIVUP(cols, FP8_CHUNK_DIM_X); const size_t blocks_Y = chunks_Y; const size_t blocks_X = chunks_X; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; - NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + // Workspace for dbias + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.data.dtype, "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + NVTE_CHECK(dbias_rows > 0 && dbias_cols > 0, + "Invalid workspace shape for DBias computation (input shape=", + input.shape(), ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_rows, dbias_cols}; @@ -430,10 +424,20 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T return; } } + + // Skip kernel if tensor size is zero + if (rows == 0 || cols == 0) { + if constexpr (IS_DBIAS) { + NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); + } + return; + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); float *const scale_inv_ptr = reinterpret_cast(output->scale_inv.dptr); float *const scale_ptr = reinterpret_cast(output->scale.dptr); + NVTE_CHECK(scale_inv_ptr != nullptr, "Scaling tensor must be allocated"); const dim3 block(FP8_THREADS_PER_CHUNK); const dim3 grid(blocks_X, blocks_Y); diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 18fbfadd1a..7cfd6da70e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -844,13 +844,7 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations e8m0_t *const scales_rowwise_ptr = reinterpret_cast(output->scale_inv.dptr); e8m0_t *const scales_colwise_ptr = reinterpret_cast(output->columnwise_scale_inv.dptr); - if (use_rowwise_scaling) { - NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); - } - if (use_colwise_scaling) { - NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); - } - + // Workspace for dbias if constexpr (IS_DBIAS) { NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); @@ -863,6 +857,10 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); const size_t dbias_workspace_rows = DIVUP(first_logical_dim, static_cast(CHUNK_DIM_Y)); const size_t dbias_workspace_cols = last_logical_dim; + NVTE_CHECK(dbias_workspace_rows > 0 && dbias_workspace_cols > 0, + "Invalid workspace shape for DBias computation (input first_logical_dim=", + first_logical_dim, ", input last_logical_dim=", last_logical_dim, + ", workspace shape=(", dbias_workspace_rows, ",", dbias_workspace_cols, "))."); if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; workspace->data.dtype = DType::kFloat32; @@ -879,6 +877,14 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations return; } + // Check pointers + if (use_rowwise_scaling) { + NVTE_CHECK(scales_rowwise_ptr != nullptr, "Scaling tensor must be allocated"); + } + if (use_colwise_scaling) { + NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); + } + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( input->dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index f344ea7b79..371299481e 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -579,14 +579,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - // Skip kernel if tensor size is zero - if (rows == 0 || cols == 0) { - if constexpr (IS_DBIAS) { - NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); - } - return; - } - // Tensor chunk handled by each CUDA block constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; @@ -614,8 +606,6 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, use_rowwise_scaling ? reinterpret_cast(output->scale_inv.dptr) : nullptr; e8m0_t *const scales_colwise_ptr = use_colwise_scaling ? reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const size_t dbias_rows = blocks_Y; - const size_t dbias_cols = cols; ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { @@ -626,10 +616,16 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, scaling_type = ScalingType::BIDIMENSIONAL; } + // Workspace for dbias + const size_t dbias_rows = blocks_Y; + const size_t dbias_cols = cols; if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + NVTE_CHECK(dbias_rows > 0 && dbias_cols > 0, + "Invalid workspace shape for DBias computation (input shape=", + input.shape(), ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_rows, dbias_cols}; @@ -638,6 +634,14 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, } } + // Skip kernel if tensor size is zero + if (rows == 0 || cols == 0) { + if constexpr (IS_DBIAS) { + NVTE_ERROR("Invalid tensor shape for DBias computation (shape=", input.shape(), ")."); + } + return; + } + float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); const float *noop_ptr = reinterpret_cast(noop->data.dptr); From e05f3531ddf2b6b2233703df8b8d1b2f99143462 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:30:41 +0000 Subject: [PATCH 4/4] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/common/cast/fp8/quantize_fp8.cuh | 4 ++-- transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/common/cast/fp8/quantize_fp8.cuh b/transformer_engine/common/cast/fp8/quantize_fp8.cuh index e42cad05bb..60699199b6 100644 --- a/transformer_engine/common/cast/fp8/quantize_fp8.cuh +++ b/transformer_engine/common/cast/fp8/quantize_fp8.cuh @@ -415,8 +415,8 @@ void quantize_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); NVTE_CHECK(dbias_rows > 0 && dbias_cols > 0, - "Invalid workspace shape for DBias computation (input shape=", - input.shape(), ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); + "Invalid workspace shape for DBias computation (input shape=", input.shape(), + ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_rows, dbias_cols}; diff --git a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh index 371299481e..81a7bc9dbb 100644 --- a/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/quantize_mxfp8.cuh @@ -624,8 +624,8 @@ void quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, NVTE_CHECK(dbias->data.shape == std::vector{cols}, "Wrong shape of DBias."); NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); NVTE_CHECK(dbias_rows > 0 && dbias_cols > 0, - "Invalid workspace shape for DBias computation (input shape=", - input.shape(), ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); + "Invalid workspace shape for DBias computation (input shape=", input.shape(), + ", workspace shape=(", dbias_rows, ",", dbias_cols, "))."); if (workspace->data.dptr == nullptr) { workspace->data.shape = {dbias_rows, dbias_cols};