Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions tests/cpp/operator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ add_executable(test_operator
test_multi_padding.cu
test_multi_unpadding.cu
test_causal_softmax.cu
test_swizzle.cu #CUDA-only test
test_swizzle.cu
test_swap_first_dims.cu
test_grouped_gemm.cu #CUDA-only test
../test_common.cu)
Expand All @@ -42,7 +42,6 @@ if(USE_ROCM)
# Remove CUDA-only tests and add ROCm specific ones
list(REMOVE_ITEM test_cuda_sources
test_cast_float8blockwise.cu
test_swizzle.cu
test_grouped_gemm.cu)
list(APPEND test_cuda_sources
test_dequantize_nvfp4.cu
Expand Down
112 changes: 95 additions & 17 deletions tests/cpp/operator/test_cublaslt_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include <gtest/gtest.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/gemm.h>
#include <transformer_engine/swizzle.h>
#include <transformer_engine/transformer_engine.h>
#include "../test_common.h"

Expand All @@ -30,7 +31,15 @@ std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes = {

std::vector<std::tuple<size_t, size_t, size_t>> test_case_sizes_mxfp8 = {
{32, 128, 16},
{64, 128, 32},
{128, 128, 64},
{64, 256, 32},
{128, 384, 64},
{256, 512, 128},
{512, 1024, 256},
{768, 3072, 4096},
{1024, 2048, 128},
{4096, 8192, 64},
};

// A, B, Bias, Gelu, D
Expand Down Expand Up @@ -303,6 +312,40 @@ void cpu_rowwise_to_columnwise(
}
}

// Swizzle MXFP8 scale_inv of a test::Tensor in-place for gfx1250.
static void swizzle_mxfp8_scales(test::Tensor &t, bool rowwise) {
using namespace transformer_engine;
void *scale_ptr = rowwise ? t.rowwise_scale_inv_dptr()
: t.columnwise_scale_inv_dptr();
if (!scale_ptr) return;
const NVTEShape scale_shape = rowwise ? t.rowwise_scale_inv_shape()
: t.columnwise_scale_inv_shape();
const NVTEShape data_shape = rowwise ? t.rowwise_shape()
: t.columnwise_shape();
size_t num_scales = 1;
for (size_t d = 0; d < scale_shape.ndim; d++) num_scales *= scale_shape.data[d];
uint8_t *d_tmp = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_tmp, num_scales));
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);
if (rowwise) {
input_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_rowwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_rowwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_rowwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
} else {
input_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
input_tw.set_columnwise_scale_inv(scale_ptr, DType::kFloat8E8M0, scale_shape);
output_tw.set_columnwise_data(nullptr, t.dtype(), data_shape);
output_tw.set_columnwise_scale_inv(d_tmp, DType::kFloat8E8M0, scale_shape);
}
nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);
NVTE_CHECK_CUDA(cudaDeviceSynchronize());
NVTE_CHECK_CUDA(cudaMemcpy(scale_ptr, d_tmp, num_scales, cudaMemcpyDeviceToDevice));
NVTE_CHECK_CUDA(cudaFree(d_tmp));
}

std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool use_mxfp8) {
auto [atol, rtol] = getTolerances(type);

Expand All @@ -318,6 +361,12 @@ std::pair<double, double> getTestTolerances(const DType type, bool use_fp8, bool
else if (use_fp8) {
atol = 1e-3;
rtol = std::max(rtol, 1e-2);
// Relax for gfx1250
cudaDeviceProp prop;
(void)cudaGetDeviceProperties(&prop, 0);
if (prop.major == 12 && type == DType::kBFloat16) {
rtol = std::max(rtol, 5e-2);
}
}
else if (type == DType::kBFloat16) {
//relax for certain prime number TN gemm
Expand Down Expand Up @@ -496,6 +545,31 @@ void performTest(const TestParams& params) {
#endif
Tensor Workspace("Workspace", TShape{ workspace_size }, DType::kByte);

//perform the reference gemm on GPU (before swizzle, which modifies scales in-place)
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (use_mxfp8 && prop.major == 12) {
if (!a_colwise) swizzle_mxfp8_scales(A, true);
if (a_colwise) swizzle_mxfp8_scales(A, false);
if (!b_colwise) swizzle_mxfp8_scales(B, true);
if (b_colwise) swizzle_mxfp8_scales(B, false);
}

//perform the gemm in GPU
nvte_cublas_gemm(A.data(),
B.data(),
Expand All @@ -517,23 +591,6 @@ void performTest(const TestParams& params) {
pre_gelu_out.to_cpu();
}

//perform the reference gemm on GPU
Tensor RefD("RefD", TShape{ params.n, params.m }, dtype);
Tensor RefPreGeluOut;

if (params.use_gelu) {
RefPreGeluOut = Tensor("RefPreGeluOut", TShape{ params.n, params.m }, gelu_type);
}

run_reference<A_Type, B_Type, Bias_Type, Gelu_Type, D_Type>(
params,
A,
B,
params.use_bias ? &bias : nullptr,
D,
RefD,
params.use_gelu ? &RefPreGeluOut : nullptr);

// check if error message happens in running
(void)cudaDeviceSynchronize();
auto err = cudaGetLastError();
Expand Down Expand Up @@ -582,6 +639,17 @@ void performDqTest(const TestParams &params) {
GTEST_SKIP() << "MXFP8 is not supported in current config";
}

// hipBLASLt on gfx950 produces incorrect results for certain small MXFP8
// GEMMs with non-TN layouts.
if (prop.major == 9 && prop.minor == 5) {
const bool is_NN = !params.transa && !params.transb;
const bool is_NT = !params.transa && params.transb;
if ((is_NN && params.m == 64) ||
(is_NT && params.m > 32 && params.m <= 128 && params.n <= 64)) {
GTEST_SKIP() << "hipBLASLt MXFP8 non-TN GEMM with small M/N is not supported on gfx950";
}
}

DType ref_type = dtype;
TShape a_shape = params.transa ? TShape{params.m, params.k} : TShape{params.k, params.m};
TShape b_shape = params.transb ? TShape{params.k, params.n} : TShape{params.n, params.k};
Expand All @@ -605,6 +673,16 @@ void performDqTest(const TestParams &params) {
nvte_dequantize(A_fp8.data(), A_ref.data(), 0);
nvte_dequantize(B_fp8.data(), B_ref.data(), 0);

// On gfx1250, hipBLASLt MXFP8 kernels expect pre-swizzled scales.
if (prop.major == 12) {
const bool a_colwise = !params.transa;
const bool b_colwise = params.transb;
if (!a_colwise) swizzle_mxfp8_scales(A_fp8, true);
if (a_colwise) swizzle_mxfp8_scales(A_fp8, false);
if (!b_colwise) swizzle_mxfp8_scales(B_fp8, true);
if (b_colwise) swizzle_mxfp8_scales(B_fp8, false);
}

Tensor bias;
Tensor pre_gelu_out;

Expand Down
180 changes: 180 additions & 0 deletions tests/cpp/operator/test_swizzle.cu
Original file line number Diff line number Diff line change
Expand Up @@ -166,3 +166,183 @@ INSTANTIATE_TEST_SUITE_P(
std::to_string(std::get<2>(info.param));
return name;
});

#ifdef __HIP_PLATFORM_AMD__

// MX pre-swizzle test (gfx1250 Tensile 3D layout)
//
// Tensile 3D: {K_scale, M}.reshape({K_scale, padM/4, 4}).permute({1, 0, 2})
// For source (m, k): dst = (m/4) * (K*4) + k*4 + (m%4)

// CPU reference for Tensile 3D MX scale pre-swizzle.
// Row-major input [M, K], output is a flat permuted array.
void compute_ref_mx_swizzle_row(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127; // E8M0 identity: 2^0 = 1.0
if (m < orig_M && k < orig_K) {
val = h_input[m * orig_K + k];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

void compute_ref_mx_swizzle_col(const uint8_t *h_input, uint8_t *h_output,
const int M, const int K,
const int orig_M, const int orig_K) {
constexpr int GROUP = 4;
for (int m = 0; m < M; m++) {
for (int k = 0; k < K; k++) {
uint8_t val = 127;
if (m < orig_M && k < orig_K) {
val = h_input[k * orig_M + m];
}
int group = k / GROUP;
int within = k % GROUP;
int dst = group * (M * GROUP) + m * GROUP + within;
h_output[dst] = val;
}
}
}

static size_t roundup_sz(size_t val, size_t mult) {
return ((val + mult - 1) / mult) * mult;
}

class MxSwizzleTestSuite
: public ::testing::TestWithParam<
std::tuple<std::pair<int, int>, bool>> {};

TEST_P(MxSwizzleTestSuite, TestMxSwizzle) {
using namespace transformer_engine;
using namespace test;

cudaDeviceProp prop;
cudaGetDeviceProperties(&prop, 0);
if (prop.major < 12) {
GTEST_SKIP() << "MXFP8 pre-swizzle is only supported on gfx1250";
}

const auto dims = std::get<0>(GetParam());
const bool rowwise = std::get<1>(GetParam());

// Original (unpadded) scale dimensions
const size_t orig_M = dims.first;
const size_t orig_K = dims.second;

// Padded dimensions: K-tiled layout requires K_scale padded to multiple of 4
const size_t M = orig_M;
const size_t K = roundup_sz(orig_K, 4);

// Allocate host input (unpadded) and fill with random data
const size_t input_size = orig_M * orig_K;
std::unique_ptr<uint8_t[]> h_input(new uint8_t[input_size]);
std::mt19937 rng(42);
for (size_t i = 0; i < input_size; i++) {
h_input[i] = static_cast<uint8_t>(rng() % 256);
}

// Allocate device input
uint8_t *d_input = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_input, input_size));
NVTE_CHECK_CUDA(cudaMemcpy(d_input, h_input.get(), input_size, cudaMemcpyHostToDevice));

// Allocate device output (padded size)
const size_t output_size = M * K;
uint8_t *d_output = nullptr;
NVTE_CHECK_CUDA(cudaMalloc(&d_output, output_size));
NVTE_CHECK_CUDA(cudaMemset(d_output, 0, output_size));

// Build TensorWrapper for input and output
TensorWrapper input_tw(NVTE_MXFP8_1D_SCALING);
TensorWrapper output_tw(NVTE_MXFP8_1D_SCALING);
output_tw.set_with_gemm_swizzled_scales(true);

// Data shape must be consistent with scale shape for validation.
// Scale shapes use padded K; data shapes use unpadded dims
// (kernel derives original_M/K from them).
if (rowwise) {
std::vector<size_t> data_shape_in = {orig_M, orig_K * 32};
std::vector<size_t> data_shape_out = {M, K * 32};
std::vector<size_t> scale_shape_in = {M, K};
std::vector<size_t> scale_shape_out = {M, K};
input_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_rowwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_rowwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_rowwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
} else {
std::vector<size_t> data_shape_in = {orig_K * 32, orig_M};
std::vector<size_t> data_shape_out = {K * 32, M};
std::vector<size_t> scale_shape_in = {K, M};
std::vector<size_t> scale_shape_out = {K, M};
input_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_in);
input_tw.set_columnwise_scale_inv(d_input, DType::kFloat8E8M0, scale_shape_in);
output_tw.set_columnwise_data(nullptr, DType::kFloat8E4M3, data_shape_out);
output_tw.set_columnwise_scale_inv(d_output, DType::kFloat8E8M0, scale_shape_out);
}

nvte_swizzle_scaling_factors(input_tw.data(), output_tw.data(), 0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());

// Copy output back to host
std::unique_ptr<uint8_t[]> h_output(new uint8_t[output_size]);
NVTE_CHECK_CUDA(cudaMemcpy(h_output.get(), d_output, output_size, cudaMemcpyDeviceToHost));

// Compute reference
std::unique_ptr<uint8_t[]> h_ref(new uint8_t[output_size]);
memset(h_ref.get(), 0, output_size);
if (rowwise) {
compute_ref_mx_swizzle_row(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
} else {
compute_ref_mx_swizzle_col(h_input.get(), h_ref.get(), M, K, orig_M, orig_K);
}

// Compare
compareResults("mx_swizzle", h_output.get(), h_ref.get(), output_size);

cudaFree(d_input);
cudaFree(d_output);
}

namespace {

// Scale dimensions (M_scale, K_scale).
// K_scale will be padded to multiple of 4 by the test.
std::vector<std::pair<int, int>> mx_scale_dims = {
{4, 4}, // minimal
{8, 4}, // small
{32, 8}, // medium
{64, 16}, // larger
{96, 8}, // non-power-of-2 M
{128, 32}, // big
{256, 64}, // bigger
{512, 128}, // stress inter-tile
{1024, 256}, // large
{4096, 256}, // max stress
};

} // namespace

INSTANTIATE_TEST_SUITE_P(
OperatorTest,
MxSwizzleTestSuite,
::testing::Combine(
::testing::ValuesIn(mx_scale_dims),
::testing::Values(true, false)
),
[](const testing::TestParamInfo<MxSwizzleTestSuite::ParamType>& info) {
std::string name = "M" + std::to_string(std::get<0>(info.param).first) +
"_K" + std::to_string(std::get<0>(info.param).second) +
(std::get<1>(info.param) ? "_row" : "_col");
return name;
});

#endif // __HIP_PLATFORM_AMD__
Loading
Loading