-
Notifications
You must be signed in to change notification settings - Fork 29
speed up nvte_multi_padding / nvte_multi_unpadding #592
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
matthiasdiener
wants to merge
9
commits into
dev
Choose a base branch
from
mdiener/speedup-pad-unpad
base: dev
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
+115
−0
Open
Changes from all commits
Commits
Show all changes
9 commits
Select commit
Hold shift + click to select a range
ce6e865
speed up nvte_multi_padding / nvte_multi_unpadding
matthiasdiener a470ecb
factor out binary search
matthiasdiener 45b996a
Merge branch 'dev' into mdiener/speedup-pad-unpad
matthiasdiener 5f011ae
guard
matthiasdiener a35459c
Merge remote-tracking branch 'origin/dev' into mdiener/speedup-pad-unpad
matthiasdiener cb9221d
factor out cols
matthiasdiener dc708c6
bump n_warps_per_tile
matthiasdiener 84b7d09
use NT stores
matthiasdiener 710d3c0
Merge branch 'dev' into mdiener/speedup-pad-unpad
matthiasdiener File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Some comments aren't visible on the classic Files Changed page.
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,4 +1,6 @@ | ||
| /************************************************************************* | ||
| * This file was modified for portability to AMDGPU | ||
| * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. | ||
| * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
| * | ||
| * See LICENSE for license information. | ||
|
|
@@ -13,13 +15,42 @@ | |
|
|
||
| #include "../common.h" | ||
| #include "../utils.cuh" | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| #include "rocm_device_utils.cuh" // for rocm_upper_bound(), NTVec | ||
| #endif | ||
|
|
||
| namespace transformer_engine { | ||
|
|
||
| namespace { | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Non-temporal store helper: uses NT store for full aligned vectors, | ||
| // falls back to element-wise for partial/unaligned cases. | ||
| // Note: NT loads were also benchmarked but hurt performance. | ||
| template <uint32_t nvec, typename Type> | ||
| __device__ __forceinline__ void nt_store_to_elts(const Vec<Type, nvec>& v, | ||
| Type* ptr, int count) { | ||
| constexpr size_t BYTES = nvec * sizeof(Type); | ||
| if (count == nvec && reinterpret_cast<uint64_t>(ptr) % BYTES == 0) { | ||
| NTVec<Type, nvec> nt; | ||
| #pragma unroll | ||
| for (int i = 0; i < nvec; i++) nt.val[i] = v.data.elt[i]; | ||
| nt.nt_store(ptr); | ||
| } else { | ||
| #pragma unroll | ||
| for (int i = 0; i < nvec; i++) { | ||
| if (i < count) ptr[i] = v.data.elt[i]; | ||
| } | ||
| } | ||
| } | ||
| #endif | ||
|
|
||
| // Parameters to tune | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| constexpr int n_warps_per_tile = 16; | ||
| #else | ||
| constexpr int n_warps_per_tile = 4; | ||
| #endif | ||
| constexpr int threads_per_block = THREADS_PER_WARP * n_warps_per_tile; | ||
| constexpr int desired_load_store_size = 8; | ||
| constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB | ||
|
|
@@ -65,15 +96,22 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP | |
| constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; | ||
|
|
||
| // Find tensor corresponding to block | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); | ||
| #else | ||
| int tensor_id = 0; | ||
| while (args.block_range[tensor_id + 1] <= bid) { | ||
| ++tensor_id; | ||
| } | ||
| #endif | ||
| const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]); | ||
| Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]); | ||
| const int num_rows = args.num_rows_list[tensor_id]; | ||
| const int padded_num_rows = args.padded_num_rows_list[tensor_id]; | ||
| const int row_length = args.row_length_list[tensor_id]; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| const bool inplace = (input == output); | ||
| #endif | ||
|
|
||
| // Find position of tile within tensor | ||
| const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; | ||
|
|
@@ -83,6 +121,36 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP | |
| const int tile_row = tile_id_m * tile_dim_m; | ||
| const int tile_col = tile_id_n * tile_dim_n; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Process subtiles with vectorized loads/stores | ||
| #pragma unroll | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Did you try #pragma unroll 2 here? If we have the registers available that might help performance. |
||
| for (int iter = 0; iter < n_iterations; ++iter) { | ||
| const int i1 = tidy + iter * bdimy; | ||
| const int j1 = tidx; | ||
| const int col = tile_col + j1 * nvec; | ||
| const int remaining = row_length - col; | ||
| const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0; | ||
| #pragma unroll | ||
| for (int i2 = 0; i2 < nvec; ++i2) { | ||
| const int row = tile_row + i1 * nvec + i2; | ||
| if (row < num_rows) { | ||
| // Valid data row: skip copy when in-place | ||
| if (!inplace) { | ||
| const size_t offset = static_cast<size_t>(row) * row_length + col; | ||
| Vec v; | ||
| v.load_from_elts(input, offset, valid_cols); | ||
| nt_store_to_elts(v, output + offset, valid_cols); | ||
| } | ||
| } else if (row < padded_num_rows) { | ||
| // Padding row: fill with zeros | ||
| const size_t offset = static_cast<size_t>(row) * row_length + col; | ||
| Vec v; | ||
| v.clear(); | ||
| nt_store_to_elts(v, output + offset, valid_cols); | ||
| } | ||
| } | ||
| } | ||
| #else // !__HIP_PLATFORM_AMD__ | ||
| // Load input and store to registers | ||
| // Note: Each thread loads n_iterations subtiles, casts to output | ||
| // type, and transposes in registers. | ||
|
|
@@ -125,6 +193,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_padding_kernel(MultiP | |
| } | ||
| } | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
| } | ||
|
|
||
| template <int nvec, typename Type> | ||
|
|
@@ -150,14 +219,21 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult | |
| constexpr int n_iterations = THREADS_PER_WARP / n_warps_per_tile; | ||
|
|
||
| // Find tensor corresponding to block | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| const int tensor_id = rocm_upper_bound(args.block_range, args.num_tensors, bid); | ||
| #else | ||
| int tensor_id = 0; | ||
| while (args.block_range[tensor_id + 1] <= bid) { | ||
| ++tensor_id; | ||
| } | ||
| #endif | ||
| const Type* input = reinterpret_cast<const Type*>(args.input_list[tensor_id]); | ||
| Type* output = reinterpret_cast<Type*>(args.output_list[tensor_id]); | ||
| const int num_rows = args.num_rows_list[tensor_id]; | ||
| const int row_length = args.row_length_list[tensor_id]; | ||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| const bool inplace = (input == output); | ||
| #endif | ||
|
|
||
| // Find position of tile within tensor | ||
| const int num_tiles_n = (row_length + tile_dim_n - 1) / tile_dim_n; | ||
|
|
@@ -167,6 +243,27 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult | |
| const int tile_row = tile_id_m * tile_dim_m; | ||
| const int tile_col = tile_id_n * tile_dim_n; | ||
|
|
||
| #ifdef __HIP_PLATFORM_AMD__ | ||
| // Process subtiles with vectorized loads/stores | ||
| #pragma unroll | ||
| for (int iter = 0; iter < n_iterations; ++iter) { | ||
| const int i1 = tidy + iter * bdimy; | ||
| const int j1 = tidx; | ||
| const int col = tile_col + j1 * nvec; | ||
| const int remaining = row_length - col; | ||
| const int valid_cols = remaining > 0 ? min(remaining, nvec) : 0; | ||
| #pragma unroll | ||
| for (int i2 = 0; i2 < nvec; ++i2) { | ||
| const int row = tile_row + i1 * nvec + i2; | ||
| if (row < num_rows && !inplace) { | ||
| const size_t offset = static_cast<size_t>(row) * row_length + col; | ||
| Vec v; | ||
| v.load_from_elts(input, offset, valid_cols); | ||
| nt_store_to_elts(v, output + offset, valid_cols); | ||
| } | ||
| } | ||
| } | ||
| #else // !__HIP_PLATFORM_AMD__ | ||
| // Load input and store to registers | ||
| // Note: Each thread loads n_iterations subtiles, casts to output | ||
| // type, and transposes in registers. | ||
|
|
@@ -202,6 +299,7 @@ __global__ void __launch_bounds__(threads_per_block) multi_unpadding_kernel(Mult | |
| } | ||
| } | ||
| } | ||
| #endif // __HIP_PLATFORM_AMD__ | ||
| } | ||
|
|
||
| } // namespace | ||
|
|
||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What is the case where we hit non-aligned vectors? Isn't FP8/MXFP8 always padded to a multiple of 16 by default? Ideally we would template out the NT vs elementwise stores