Enable blockwise FP8 quantization on ROCm #609
| # TODO replace with call to fp8.py when recipe added. | ||
| recipe_available = not IS_HIP_EXTENSION and (get_device_compute_capability() >= (9, 0) and float(torch.version.cuda) >= 12.8) | ||
| if IS_HIP_EXTENSION: | ||
| recipe_available = get_device_compute_capability() >= (9, 0) |
Wouldn't this always be True on ROCm TE?
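To make the question concrete, here is a minimal sketch of the availability check written as a pure function. `blockwise_recipe_available`, `parse_version`, and their parameters are hypothetical stand-ins for `IS_HIP_EXTENSION`, `get_device_compute_capability()`, and `torch.version.cuda`; this is not the code in the PR. It shows that the HIP branch depends only on the reported capability, so it is True for every device that reports at least (9, 0).

```python
from typing import Optional, Tuple


def parse_version(version: str) -> Tuple[int, ...]:
    """Turn a version string such as "12.8" into a comparable tuple (12, 8)."""
    return tuple(int(part) for part in version.split("."))


def blockwise_recipe_available(
    is_hip: bool,
    compute_capability: Tuple[int, int],
    cuda_version: Optional[str] = None,
) -> bool:
    """Hypothetical helper mirroring the check in the diff above."""
    if is_hip:
        # ROCm branch: only the reported capability is consulted.
        return compute_capability >= (9, 0)
    # CUDA branch: needs capability >= (9, 0) and a CUDA version >= 12.8.
    if cuda_version is None:
        return False
    return compute_capability >= (9, 0) and parse_version(cuda_version) >= (12, 8)
```

The sketch compares version tuples rather than calling `float()` on the version string, because a string such as "12.10" would parse to 12.1 as a float and fail a `>= 12.8` comparison.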
| @@ -1 +1 @@ | |||
| /************************************************************************* | |||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cudaTypedefs.h> | ||
| #endif | ||
| #include <cuda_bf16.h> | ||
| #include <cuda_runtime.h> | ||
| #include <cfloat> | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include <cuda/barrier> | ||
| #endif | ||
| #include "common/common.h" | ||
| #include "common/recipe/recipe_common.cuh" | ||
| #include "common/util/cuda_runtime.h" | ||
| #ifndef __HIP_PLATFORM_AMD__ | ||
| #include "common/util/ptx.cuh" | ||
| #endif |
These #includes should already be disabled by hipify, so the #ifndef guards are probably unnecessary here.
| static constexpr float max = 448.0f; | ||
| static constexpr float max_inverse = 1.0 / max; |
Is this change necessary? fp8e4m3 max depends on the device type on AMD.
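For context on this comment: the two common E4M3 encodings have different maxima. The OCP `e4m3fn` variant tops out at 448, while the `e4m3fnuz` variant tops out at 240. The sketch below derives both values from the bit layout. The helper name `e4m3_max` is hypothetical, and which encoding a given AMD device uses is not stated in this thread.

```python
def e4m3_max(fnuz: bool) -> float:
    """Largest finite value of an 8-bit float with 4 exponent and 3 mantissa bits.

    fnuz=False: OCP e4m3fn, bias 7. The all-ones pattern is NaN, so the
                largest finite mantissa at the top exponent is 0b110.
    fnuz=True:  e4m3fnuz, bias 8. NaN is encoded as negative zero, so the
                top exponent keeps the full mantissa 0b111.
    """
    exponent_field = 0b1111
    bias = 8 if fnuz else 7
    mantissa_field = 0b111 if fnuz else 0b110
    # value = 2^(exponent - bias) * (1 + mantissa / 2^3)
    return 2.0 ** (exponent_field - bias) * (1.0 + mantissa_field / 8.0)
```

Under these assumptions, a hard-coded `max = 448.0f` matches only the `e4m3fn` encoding, and `max_inverse` derived from it would be off for `e4m3fnuz`.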
Could you describe what you want to achieve with this PR? My understanding is that blockwise FP8 quantization relies on some upstream kernels that will need to be adapted for AMD. If you're just trying to enable the interface, I would argue that we should do this last, after we have a working quantization and GEMM path (with the C++/Python tests enabled and passing).
@alextmagro I tested with |
OK, in that case we need to add the C++ blockwise tests to the CMake file and the PyTorch test file to ci/pytorch.sh.
Description
Enable blockwise FP8 quantization on ROCm.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Remove the HIP guard in quantization.py.
Guard the kernels that use TMA in quantization.
Add a branch to handle the different number of threads per wave on ROCm.
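The last item refers to the wave width: NVIDIA warps have 32 threads, whereas wavefronts on AMD CDNA GPUs have 64. Below is a minimal sketch of why indexing that assumes 32 threads per wave needs a separate branch. The function names are hypothetical and are not taken from the PR.

```python
from typing import Tuple


def waves_per_block(threads_per_block: int, threads_per_wave: int) -> int:
    """Number of waves (warps) needed to cover a thread block, rounded up."""
    if threads_per_block <= 0 or threads_per_wave <= 0:
        raise ValueError("thread counts must be positive")
    return -(-threads_per_block // threads_per_wave)  # ceiling division


def lane_and_wave(thread_idx: int, threads_per_wave: int) -> Tuple[int, int]:
    """Split a flat thread index into (lane within its wave, wave index)."""
    return thread_idx % threads_per_wave, thread_idx // threads_per_wave
```

With 256 threads per block, a wave width of 32 gives 8 waves and a width of 64 gives 4, so any per-wave tiling or reduction has to be sized from the actual wave width instead of a fixed 32.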
Checklist: