On master it is not possible to configure the CUDA MMQ kernel as a function of batch size and data type. This PR fixes that with a general refactor of the MMQ kernel so that it more closely resembles the mma FA kernel, using a table of parameters rather than a collection of macros and functions that return hard-coded values per architecture. I also moved a lot of the code out of the main `mmq.cuh` file, which is pretty overloaded on master at 4k LoC. I replaced the variable names `mmq_x` and `mmq_y` with `J` and `I` (same as the FA kernels) to avoid confusion with the `x` and `y` data pointers. There are no (intentional) functional changes from this PR other than:
- On master, every tile-size template specialization is compiled twice: once for the high-performance version without out-of-bounds checks in the `src0->ne[1]` direction and once for the fallback version with those checks (roughly a 5% end-to-end performance difference between the two). For the fallback case, however, it should be fine to compile fewer template specializations; many of them exist only to make prompt processing (pp) snappier for short prompts where the number of tokens is not necessarily cleanly divisible by e.g. 64 or 128. So for the fallback case I reduced the template specializations to powers of 2 only. Longer-term we can consider adding a compilation option like `GGML_CUDA_FULL` as an opt-in for template specializations that are rarely useful but blow up the compilation time.
- On master `__launch_bounds__` is optional; with this PR it becomes mandatory in the configuration. This should only affect RDNA1, where a targeted occupancy of 2 is now given.
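To make the first point concrete, here is a minimal host-side sketch of the two ideas: restricting the fallback path to power-of-two tile sizes, and selecting the bounds-checked path at compile time. All names (`kTileMin`, `kTileMax`, `kTileStep`, `fallback_tile`, `rows_in_tile`) are hypothetical, the tile range of 8 to 128 in steps of 8 is an assumption, and the round-up rule is a simplification rather than the selection logic used in the PR.

```cpp
// Assumed tile-size range for illustration only.
constexpr int kTileMin  = 8;
constexpr int kTileMax  = 128;
constexpr int kTileStep = 8;

constexpr bool is_pow2(int v) { return v > 0 && (v & (v - 1)) == 0; }

// Number of tile-size specializations that would be compiled for one path.
constexpr int count_specializations(bool pow2_only) {
    int n = 0;
    for (int j = kTileMin; j <= kTileMax; j += kTileStep) {
        if (!pow2_only || is_pow2(j)) {
            ++n;
        }
    }
    return n;
}

// Simplified fallback choice: smallest power-of-two tile that covers ncols,
// clamped to [kTileMin, kTileMax].
constexpr int fallback_tile(int ncols) {
    int j = kTileMin;
    while (j < ncols && j < kTileMax) {
        j *= 2;
    }
    return j;
}

// The fast path assumes a full tile; the fallback path clamps to the rows
// that are actually left, which is where the out-of-bounds check comes in.
template <bool need_check>
constexpr int rows_in_tile(int tile_rows, int rows_left) {
    if constexpr (need_check) {
        return rows_left < tile_rows ? rows_left : tile_rows;
    } else {
        return tile_rows;
    }
}

static_assert(count_specializations(false) == 16, "all multiples of 8 up to 128");
static_assert(count_specializations(true)  ==  5, "8, 16, 32, 64, 128");
```

Under these assumptions the fallback path would need 5 instead of 16 tile-size specializations per data type, which is where the compile-time saving comes from.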
Going forward this PR will enable:
- Better kernel tuning, particularly without side effects where the performance for some data types increases at the cost of a regression for other types. On master the kernels are tuned exclusively for large batch sizes; better tuning for small batch sizes should help with speculative decoding performance.
- Removal of the legacy SRAM data layout for `__dp4a` with 4 byte loads, to be replaced with 16 byte loads (~10% end-to-end speedup for e.g. P40) that can to a large degree re-use the SRAM layout for tensor cores. Originally I was going to do this transition first so that code would be removed before the refactor, but this triggered performance regressions for some combinations of GPUs and data types. So I'm taking a more granular approach where I will do the transition piece-by-piece; the refactor in this PR still has some WIPs and inconsistencies that I will gradually phase out.
- Support for converting quantized data and activations to FP16 in SRAM. This will be useful for general Volta performance as well as FP16/BF16/FP32 MoE performance on all GPUs.
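For readers unfamiliar with `__dp4a`: the intrinsic treats two 32-bit integers as four packed int8 values each and adds their dot product to a 32-bit accumulator. The host-side sketch below emulates that and contrasts feeding it from four 4 byte loads with feeding it from a single 16 byte load. `dp4a_ref`, `Int4` (a stand-in for CUDA's `int4`), `dot16_load4` and `dot16_load16` are hypothetical names; the sketch only shows that both access patterns produce the same result and says nothing about the actual SRAM layout or performance.

```cpp
#include <cstdint>
#include <cstring>

// Host reference for the arithmetic of __dp4a(a, b, c):
// c + sum over the four packed int8 lanes of a and b.
inline int dp4a_ref(int a, int b, int c) {
    std::int8_t va[4];
    std::int8_t vb[4];
    std::memcpy(va, &a, sizeof(va));
    std::memcpy(vb, &b, sizeof(vb));
    for (int k = 0; k < 4; ++k) {
        c += int(va[k]) * int(vb[k]);
    }
    return c;
}

// Stand-in for CUDA's 16 byte int4 vector type.
struct alignas(16) Int4 { int x, y, z, w; };

// 16 int8 values fetched with four separate 4 byte loads.
inline int dot16_load4(const std::int8_t * a, const std::int8_t * b) {
    int sum = 0;
    for (int k = 0; k < 4; ++k) {
        int ai;
        int bi;
        std::memcpy(&ai, a + 4*k, sizeof(int));
        std::memcpy(&bi, b + 4*k, sizeof(int));
        sum = dp4a_ref(ai, bi, sum);
    }
    return sum;
}

// The same 16 int8 values fetched with one 16 byte load per operand.
inline int dot16_load16(const std::int8_t * a, const std::int8_t * b) {
    Int4 va;
    Int4 vb;
    std::memcpy(&va, a, sizeof(Int4));
    std::memcpy(&vb, b, sizeof(Int4));
    int sum = 0;
    sum = dp4a_ref(va.x, vb.x, sum);
    sum = dp4a_ref(va.y, vb.y, sum);
    sum = dp4a_ref(va.z, vb.z, sum);
    sum = dp4a_ref(va.w, vb.w, sum);
    return sum;
}
```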
On my NVIDIA hardware I am seeing no changes to performance beyond statistical fluctuations. On my AMD hardware, however, the performance is changing and quite frankly I don't understand why; the only thing that should really have changed is that there are now some `if constexpr` branches rather than macros. On average the impact is at least slightly positive. Truth be told, these unforeseen and hard-to-explain changes in AMD performance from seemingly innocuous code changes are a major reason why I want a more granular way to configure the kernel in the first place.
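As a rough illustration of what a table-driven configuration keyed on data type and batch size could look like, consider the sketch below. It is not the structure used in the PR: `QType` (a stand-in for `ggml_type`), `MmqConfig`, `kConfigs` and `select_config` are hypothetical names, and every number in the table is a placeholder rather than a tuned value. The `occupancy` field reflects the point that the `__launch_bounds__` occupancy is now a mandatory part of the configuration.

```cpp
#include <cstdint>

// Hypothetical stand-in for ggml_type.
enum class QType : std::uint8_t { Q4_0, Q8_0 };

// One row of the parameter table.
struct MmqConfig {
    QType type;       // data type this row applies to
    int   max_batch;  // row applies to batch sizes <= max_batch
    int   I;          // tile size that replaces mmq_y
    int   J;          // tile size that replaces mmq_x
    int   nwarps;     // warps per CUDA block
    int   occupancy;  // target occupancy passed to __launch_bounds__
};

// Placeholder values; rows of one type are sorted by ascending max_batch.
constexpr MmqConfig kConfigs[] = {
    {QType::Q4_0,  16,  64,  16, 4, 2},
    {QType::Q4_0, 128, 128, 128, 8, 1},
    {QType::Q8_0,  16,  64,  16, 4, 2},
    {QType::Q8_0, 128, 128,  64, 8, 1},
};

// Returns the first row of the given type that covers the batch size,
// or the last row of that type if the batch is larger than every max_batch.
// If the type is not in the table at all, the first row is returned.
constexpr MmqConfig select_config(QType type, int batch) {
    MmqConfig last = kConfigs[0];
    for (const MmqConfig & c : kConfigs) {
        if (c.type != type) {
            continue;
        }
        last = c;
        if (batch <= c.max_batch) {
            return c;
        }
    }
    return last;
}

static_assert(select_config(QType::Q4_0, 8).J == 16, "small batch uses the small tile");
```

With a table like this, changing a row for one data type and batch range cannot affect any other combination, which is the kind of side-effect-free tuning that hard-coded per-architecture values do not allow.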
Should be fixed now. The Blackwell config was wrong, but in a way that did not consistently result in incorrect outputs.