Update flashMLA implementation on MACA #27
Code Review
This pull request refactors the FlashMLA implementation, introducing support for sparse decoding, sparse prefill attention, and context parallelism, while removing obsolete benchmark scripts and instantiations. The code review highlights several critical issues, including compilation errors in flash_fwd_dispatch_template.h due to a missing is_bf16 member in SparsePrefillParams, potential division-by-zero bugs and a backwards error message in get_attn_impl_meta, and a potential null pointer dereference in fwd_kvcache_mla when indices_all_valid_per_q is not provided. Additionally, the reviewer noted redundant duplicate checks on seqlens_k and a minor typo in flash_api.cpp.
| assert(params.is_bf16 && "sparse prefill only support bf16"); | ||
| constexpr static int kBlockM = 64; |
params.is_bf16 is accessed here, but the SparsePrefillParams struct does not have a member named is_bf16. This will cause a compilation failure. Since sparse_prefill_fwd already enforces kBFloat16 via TORCH_CHECK at the API level, this assertion is redundant and should be removed.
    constexpr static int kBlockM = 64;
| TORCH_CHECK(h_q_.has_value()); | ||
| int h_q = h_q_.value(); | ||
| TORCH_CHECK(h_q % h_k == 0, "h_k must be divisible by h_q."); | ||
| int s_q = num_q_tokens_per_head_k * h_k / h_q; | ||
| // BF16/FP16 + Sparse MLA | ||
| return { | ||
| std::max((sm_count/2) / h_k / (mctlass::ceil_div(h_q/h_k, 2*64) * s_q), 1), | ||
| 5, | ||
| block_n // block_n | ||
| }; | ||
| } | ||
| } else { | ||
| TORCH_CHECK(!is_fp8_kvcache, "FP8 KV Cache is not supported."); | ||
| // Dense BF16/FP8 MLA | ||
| return { | ||
| std::max(sm_count / h_k / mctlass::ceil_div(num_q_tokens_per_head_k, block_m), 1), | ||
| 5, | ||
| block_n, | ||
| }; | ||
| } |
There are multiple potential division-by-zero bugs and a backwards error message in this block:
- The error message "h_k must be divisible by h_q." is backwards; it should be "h_q must be divisible by h_k.".
- If h_k is 0, h_q % h_k will cause a division-by-zero crash.
- If s_q is 0 (which happens if num_q_tokens_per_head_k * h_k < h_q), then mctlass::ceil_div(h_q/h_k, 2*64) * s_q is 0, leading to a division-by-zero crash on line 79.
- If h_k is 0, std::max(sm_count / h_k / ...) on line 88 will cause a division-by-zero crash.
We should add defensive checks to ensure h_k > 0, h_q > 0, and s_q > 0 before performing these divisions.
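The zero-divisor conditions listed above can be exercised outside the extension. The following is a minimal sketch, not the project's code: `ceil_div` is a local stand-in for `mctlass::ceil_div`, `sparse_num_sm_parts` is a hypothetical helper name, and it returns `std::nullopt` where the real code would use `TORCH_CHECK`.

```cpp
#include <algorithm>
#include <cassert>
#include <optional>

// Local stand-in for mctlass::ceil_div (assumption: plain integer ceiling division).
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Hypothetical helper mirroring the sparse branch: returns the partition count,
// or std::nullopt when an input would produce a zero divisor.
std::optional<int> sparse_num_sm_parts(int sm_count, int h_k, int h_q,
                                       int num_q_tokens_per_head_k) {
    if (h_k <= 0 || h_q <= 0) return std::nullopt;  // guards h_q % h_k and "/ h_k"
    if (h_q % h_k != 0) return std::nullopt;        // h_q must be divisible by h_k
    const int s_q = num_q_tokens_per_head_k * h_k / h_q;
    if (s_q <= 0) return std::nullopt;              // guards the final divisor
    const int divisor = ceil_div(h_q / h_k, 2 * 64) * s_q;
    assert(divisor > 0);  // holds because h_q / h_k >= 1 and s_q >= 1
    return std::max((sm_count / 2) / h_k / divisor, 1);
}
```

With `h_k = 1`, `h_q = 128` and `num_q_tokens_per_head_k = 64`, integer division gives `s_q = 64 / 128 = 0`, which is the case the third bullet describes; the sketch rejects it before any division takes place.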
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_k > 0, "h_k must be greater than 0.");
TORCH_CHECK(h_q > 0, "h_q must be greater than 0.");
TORCH_CHECK(h_q % h_k == 0, "h_q must be divisible by h_k.");
int s_q = num_q_tokens_per_head_k * h_k / h_q;
TORCH_CHECK(s_q > 0, "s_q must be greater than 0 to avoid division by zero.");
// BF16/FP16 + Sparse MLA
return {
std::max((sm_count/2) / h_k / (mctlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
5,
block_n // block_n
};
}
} else {
TORCH_CHECK(!is_fp8_kvcache, "FP8 KV Cache is not supported.");
TORCH_CHECK(h_k > 0, "h_k must be greater than 0.");
// Dense BF16/FP8 MLA
return {
std::max(sm_count / h_k / mctlass::ceil_div(num_q_tokens_per_head_k, block_m), 1),
5,
block_n,
        };
| bool is_sparse_attn = indices.has_value(); | ||
| int topk = is_sparse_attn ? indices->size(-1) : -1; | ||
| TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32"); | ||
| TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension"); | ||
| TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->dtype() == torch::kBool, "indices_all_valid_per_q must have dtype bool"); | ||
| TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->stride(-1) == 1, "indices_all_valid_per_q must have contiguous last dimension"); |
If is_sparse_attn is true, indices_all_valid_per_q is dereferenced via -> on lines 138 and 139. However, there is no check to ensure that indices_all_valid_per_q actually has a value (i.e., is not std::nullopt). Dereferencing an empty optional is undefined behavior and will most likely crash. We should add a defensive check to ensure indices_all_valid_per_q.has_value() when is_sparse_attn is true.
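The ordering issue described above (check `has_value()` before any `->` access) can be sketched in plain C++. This is an illustration under assumptions: `FakeTensor` and `validate_sparse_args` are hypothetical names, and returning an error string stands in for `TORCH_CHECK` throwing.

```cpp
#include <cassert>
#include <optional>
#include <string>

// Hypothetical stand-in for a tensor; it carries only the properties checked here.
struct FakeTensor {
    bool dtype_is_bool;
    bool last_dim_contiguous;
};

// Returns an empty string when the arguments are consistent, otherwise a message.
std::string validate_sparse_args(
    const std::optional<FakeTensor>& indices,
    const std::optional<FakeTensor>& indices_all_valid_per_q) {
    const bool is_sparse_attn = indices.has_value();
    if (!is_sparse_attn) return "";  // dense path: nothing to validate here

    // The presence check must come first; operator-> on an empty optional is
    // undefined behavior.
    if (!indices_all_valid_per_q.has_value())
        return "indices_all_valid_per_q must be provided when indices is provided";
    if (!indices_all_valid_per_q->dtype_is_bool)
        return "indices_all_valid_per_q must have dtype bool";
    if (!indices_all_valid_per_q->last_dim_contiguous)
        return "indices_all_valid_per_q must have contiguous last dimension";
    return "";
}
```

Passing `indices` without `indices_all_valid_per_q` yields the "must be provided" message rather than reaching the dtype or stride checks.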
bool is_sparse_attn = indices.has_value();
int topk = is_sparse_attn ? indices->size(-1) : -1;
TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");
TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");
TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q.has_value(), "indices_all_valid_per_q must be provided when indices is provided");
TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->dtype() == torch::kBool, "indices_all_valid_per_q must have dtype bool");
    TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->stride(-1) == 1, "indices_all_valid_per_q must have contiguous last dimension");
| if (is_sparse_attn){ | ||
| TORCH_CHECK(num_heads_ori >= 64 || seqlen_q_ori == 1, "sparse decoding head q must greter than 64 when seqlen q > 1"); | ||
| } |
| CHECK_DEVICE(seqlens_k); | ||
| TORCH_CHECK(seqlens_k.is_contiguous()); | ||
| TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); | ||
| if (is_sparse_attn) | ||
| TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided"); | ||
|
|
||
| CHECK_DEVICE(seqlens_k); | ||
| TORCH_CHECK(seqlens_k.is_contiguous()); | ||
| TORCH_CHECK(seqlens_k.dtype() == torch::kInt32); | ||
|
|
The checks on seqlens_k (lines 374-376 and lines 380-382) are completely identical and redundant. We can safely remove the duplicate block to clean up the code.
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
if (is_sparse_attn)
TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");
1. Support DeepSeek V3 and DeepSeek V3.2.
2. Improve kernel performance.