From 7a26a01a5ad522bb7123765f955381c7ef36448a Mon Sep 17 00:00:00 2001
From: Hang Gao
Date: Thu, 2 Apr 2026 07:56:45 -0700
Subject: [PATCH] Fix int32 overflow in Triton _padded_copy pointer arithmetic
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

The _padded_copy and _binned_copy Triton kernels compute pointer
offsets as `offset * NUM_COLUMNS` in int32 arithmetic. In Triton,
int32 * int32 stays int32 with no promotion to int64, so once the
product reaches 2^31 it wraps negative, producing a backward pointer
that accesses memory before the start of the tensor and surfaces as
"CUDA error: an illegal memory access was encountered".

The bug triggers under expert parallelism at high token counts:
routing imbalance can let the all-to-all dispatch concentrate tokens
on a single rank. At hidden_size=4096, the overflow threshold is
offset >= 524,288 tokens on one rank (524,288 * 4,096 == 2^31).

Fix: cast the row index to tl.int64 before the multiplication in all
four Triton kernels (offset and index_b in the copy kernels,
index_out // TOP_K and index_x in the wgrad kernels). The
.to(tl.int64) adds one instruction per thread block, so the
performance impact is negligible.

This is the same class of bug as triton-lang/triton#832.
---
 megablocks/backend/kernels.py | 16 ++++++++--------
 1 file changed, 8 insertions(+), 8 deletions(-)

diff --git a/megablocks/backend/kernels.py b/megablocks/backend/kernels.py
index b584ceed..f553038b 100644
--- a/megablocks/backend/kernels.py
+++ b/megablocks/backend/kernels.py
@@ -82,8 +82,8 @@ def _padded_copy(
     # need to reduce the result. Using atomics is slow, so we
     # do the reduce step in a second kernel.
     offset = index_a // TOP_K if A_TO_B else index_a
-    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
-    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+    a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
+    b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
     offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
 
     # Load the scale, if requested.
@@ -258,8 +258,8 @@ def _padded_copy_wgrad(
 
     # Offset the input and output pointers.
     wgrad += index_out
-    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
-    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+    grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
+    x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
     offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
 
     acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
@@ -365,8 +365,8 @@ def _binned_copy(
     # need to reduce the result. Using atomics is slow, so we
     # do the reduce step in a second kernel.
     offset = index_a // TOP_K if A_TO_B else index_a
-    a += tl.multiple_of(offset * NUM_COLUMNS, NUM_COLUMNS)
-    b += tl.multiple_of(index_b * NUM_COLUMNS, NUM_COLUMNS)
+    a += tl.multiple_of(offset.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
+    b += tl.multiple_of(index_b.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
     offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
 
     # Load the scale, if requested.
@@ -500,8 +500,8 @@ def _binned_copy_wgrad(
 
     # Offset the input and output pointers.
     wgrad += index_out
-    grad += tl.multiple_of((index_out // TOP_K) * NUM_COLUMNS, NUM_COLUMNS)
-    x += tl.multiple_of(index_x * NUM_COLUMNS, NUM_COLUMNS)
+    grad += tl.multiple_of((index_out // TOP_K).to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
+    x += tl.multiple_of(index_x.to(tl.int64) * NUM_COLUMNS, NUM_COLUMNS)
     offsets = tl.max_contiguous(tl.arange(0, BLOCK_X), BLOCK_X)
 
     acc = tl.zeros((BLOCK_X,), dtype=tl.float32)
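
Note for reviewers (not part of the commit message): the wraparound is
plain fixed-width arithmetic, so it can be sanity-checked on the host
with NumPy standing in for Triton's int32 semantics. A minimal sketch,
with the constants chosen to match the hidden_size=4096 case above:

    import numpy as np

    NUM_COLUMNS = 4096              # hidden_size from the report
    offset = np.int32(524_288)      # first row index past the threshold

    # int32 * int32 wraps: 524_288 * 4_096 == 2**31, one past the int32
    # maximum, so the product wraps to -2**31 and the pointer walks
    # backward off the front of the tensor.
    with np.errstate(over="ignore"):
        wrapped = offset * np.int32(NUM_COLUMNS)
    print(wrapped)                  # -2147483648

    # Widening one operand first, as the patch does with .to(tl.int64),
    # keeps the product in range.
    print(np.int64(offset) * NUM_COLUMNS)  # 2147483648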