Update flashMLA implementation on MACA #27

Open
zhan3916 wants to merge 2 commits into main from mcFlashMLA

zhan3916 (Collaborator) commented Jul 1, 2026
1. Support DeepSeek V3 and DeepSeek V3.2.
2. Improve kernel performance.

zhan3916 requested a review from chinadzh July 1, 2026 09:50

gemini-code-assist (Bot) left a comment

Code Review

This pull request refactors the FlashMLA implementation, introducing support for sparse decoding, sparse prefill attention, and context parallelism, while removing obsolete benchmark scripts and instantiations. The code review highlights several critical issues, including compilation errors in flash_fwd_dispatch_template.h due to a missing is_bf16 member in SparsePrefillParams, potential division-by-zero bugs and a backwards error message in get_attn_impl_meta, and a potential null pointer dereference in fwd_kvcache_mla when indices_all_valid_per_q is not provided. Additionally, the reviewer noted redundant duplicate checks on seqlens_k and a minor typo in flash_api.cpp.

Comment on lines +74 to +75
assert(params.is_bf16 && "sparse prefill only support bf16");
constexpr static int kBlockM = 64;

critical

params.is_bf16 is accessed here, but the SparsePrefillParams struct does not have a member named is_bf16. This will cause a compilation failure. Since sparse_prefill_fwd already enforces kBFloat16 via TORCH_CHECK at the API level, this assertion is redundant and should be removed.

Suggested change
- assert(params.is_bf16 && "sparse prefill only support bf16");
- constexpr static int kBlockM = 64;
+ constexpr static int kBlockM = 64;
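
As an illustration of why the assertion cannot compile, here is a minimal sketch that detects at compile time whether a params type carries an is_bf16 member. The struct names and their fields are hypothetical stand-ins, not the definitions from this PR, and dtype_is_acceptable is an illustrative helper rather than anything in the code base.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-ins: the real structs in the PR have different fields.
struct DenseParamsSketch { bool is_bf16; int d; };
struct SparsePrefillParamsSketch { int topk; int d; };

// Primary template: assume the member is absent.
template <typename T, typename = void>
struct has_is_bf16 : std::false_type {};

// Specialization selected only when `t.is_bf16` is a valid expression.
template <typename T>
struct has_is_bf16<T, std::void_t<decltype(std::declval<const T&>().is_bf16)>>
    : std::true_type {};

// Returns the flag when the type has one; otherwise returns true, on the
// assumption that the dtype was already validated at the API boundary.
template <typename Params>
bool dtype_is_acceptable(const Params& params) {
    if constexpr (has_is_bf16<Params>::value) {
        return params.is_bf16;
    } else {
        return true;
    }
}
```

Writing params.is_bf16 directly against the sparse struct, as the flagged assertion does, is rejected by the compiler because no such member exists; the sketch only makes that distinction visible.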

Comment on lines +88 to +89
assert(params.is_bf16 && "sparse prefill only support bf16");
constexpr static int kBlockM = 64;

critical

params.is_bf16 is accessed here, but the SparsePrefillParams struct does not have a member named is_bf16. This will cause a compilation failure. Since sparse_prefill_fwd already enforces kBFloat16 via TORCH_CHECK at the API level, this assertion is redundant and should be removed.

Suggested change
- assert(params.is_bf16 && "sparse prefill only support bf16");
- constexpr static int kBlockM = 64;
+ constexpr static int kBlockM = 64;

Comment on lines +73 to +92
TORCH_CHECK(h_q_.has_value());
int h_q = h_q_.value();
TORCH_CHECK(h_q % h_k == 0, "h_k must be divisible by h_q.");
int s_q = num_q_tokens_per_head_k * h_k / h_q;
// BF16/FP16 + Sparse MLA
return {
std::max((sm_count/2) / h_k / (mctlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
5,
block_n // block_n
};
}
} else {
TORCH_CHECK(!is_fp8_kvcache, "FP8 KV Cache is not supported.");
// Dense BF16/FP8 MLA
return {
std::max(sm_count / h_k / mctlass::ceil_div(num_q_tokens_per_head_k, block_m), 1),
5,
block_n,
};
}

high

There are multiple potential division-by-zero bugs and a backwards error message in this block:

  1. The error message "h_k must be divisible by h_q." is backwards; it should be "h_q must be divisible by h_k.".
  2. If h_k is 0, h_q % h_k will cause a division-by-zero crash.
  3. If s_q is 0 (which happens if num_q_tokens_per_head_k * h_k < h_q), then mctlass::ceil_div(h_q/h_k, 2*64) * s_q is 0, leading to a division-by-zero crash on line 79.
  4. If h_k is 0, std::max(sm_count / h_k / ...) on line 88 will cause a division-by-zero crash.

We should add defensive checks to ensure h_k > 0, h_q > 0, and s_q > 0 before performing these divisions.

            TORCH_CHECK(h_q_.has_value());
            int h_q = h_q_.value();
            TORCH_CHECK(h_k > 0, "h_k must be greater than 0.");
            TORCH_CHECK(h_q > 0, "h_q must be greater than 0.");
            TORCH_CHECK(h_q % h_k == 0, "h_q must be divisible by h_k.");
            int s_q = num_q_tokens_per_head_k * h_k / h_q;
            TORCH_CHECK(s_q > 0, "s_q must be greater than 0 to avoid division by zero.");
            // BF16/FP16 + Sparse MLA
            return {
                std::max((sm_count/2) / h_k / (mctlass::ceil_div(h_q/h_k, 2*64) * s_q), 1),
                5,
                block_n // block_n
            };
        }
    } else {
        TORCH_CHECK(!is_fp8_kvcache, "FP8 KV Cache is not supported.");
        TORCH_CHECK(h_k > 0, "h_k must be greater than 0.");
        // Dense BF16/FP8 MLA
        return {
            std::max(sm_count / h_k / mctlass::ceil_div(num_q_tokens_per_head_k, block_m), 1),
            5,
            block_n,
        };
    }
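
The guards above can be exercised outside of PyTorch with a small standalone sketch. It assumes plain int arguments, uses a local ceil_div as a hypothetical stand-in for mctlass::ceil_div, and throws std::invalid_argument where the PR would use TORCH_CHECK. The function name sparse_num_sm_parts is illustrative only.

```cpp
#include <algorithm>
#include <stdexcept>

// Hypothetical stand-in for mctlass::ceil_div; assumes b > 0.
constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

// Sketch of the sparse-branch computation with every guard placed before
// the first division. Exceptions replace TORCH_CHECK in this sketch.
int sparse_num_sm_parts(int sm_count, int h_k, int h_q,
                        int num_q_tokens_per_head_k) {
    if (h_k <= 0) throw std::invalid_argument("h_k must be greater than 0.");
    if (h_q <= 0) throw std::invalid_argument("h_q must be greater than 0.");
    if (h_q % h_k != 0)
        throw std::invalid_argument("h_q must be divisible by h_k.");
    // Integer division: this is 0 whenever num_q_tokens_per_head_k * h_k < h_q.
    const int s_q = num_q_tokens_per_head_k * h_k / h_q;
    if (s_q <= 0) throw std::invalid_argument("s_q must be greater than 0.");
    // At least 1 once the guards above have passed, so the division is safe.
    const int blocks_per_head = ceil_div(h_q / h_k, 2 * 64) * s_q;
    return std::max((sm_count / 2) / h_k / blocks_per_head, 1);
}
```

With the guards in place, an input such as h_k = 1, h_q = 128, num_q_tokens_per_head_k = 64 is rejected with a message instead of reaching the division with s_q equal to 0.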

Comment on lines +134 to +139
bool is_sparse_attn = indices.has_value();
int topk = is_sparse_attn ? indices->size(-1) : -1;
TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");
TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");
TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->dtype() == torch::kBool, "indices_all_valid_per_q must have dtype bool");
TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->stride(-1) == 1, "indices_all_valid_per_q must have contiguous last dimension");

high

If is_sparse_attn is true, indices_all_valid_per_q is dereferenced via -> on lines 138 and 139, but nothing checks that it actually holds a value (i.e., is not std::nullopt). Calling -> on an empty optional is undefined behavior and will most likely crash. Add a check that indices_all_valid_per_q.has_value() whenever is_sparse_attn is true.

    bool is_sparse_attn = indices.has_value();
    int topk = is_sparse_attn ? indices->size(-1) : -1;
    TORCH_CHECK(!is_sparse_attn || indices->dtype() == torch::kInt32, "indices must have dtype int32");
    TORCH_CHECK(!is_sparse_attn || indices->stride(-1) == 1, "indices must have contiguous last dimension");
    TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q.has_value(), "indices_all_valid_per_q must be provided when indices is provided");
    TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->dtype() == torch::kBool, "indices_all_valid_per_q must have dtype bool");
    TORCH_CHECK(!is_sparse_attn || indices_all_valid_per_q->stride(-1) == 1, "indices_all_valid_per_q must have contiguous last dimension");
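
The ordering matters: the presence check has to run before any -> access. A minimal sketch of that pattern with std::optional follows. The IndexList and ValidMask aliases are hypothetical stand-ins for the tensor arguments, and validate_sparse_args is an illustrative name, not a function from the PR.

```cpp
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical stand-ins for the tensor arguments in the PR.
using IndexList = std::vector<int>;
using ValidMask = std::vector<bool>;

// Returns topk (here: the size of indices) for sparse attention, -1 for dense.
int validate_sparse_args(
    const std::optional<IndexList>& indices,
    const std::optional<ValidMask>& indices_all_valid_per_q) {
    const bool is_sparse_attn = indices.has_value();
    if (!is_sparse_attn) return -1;
    // Presence check first: operator-> on an empty optional is undefined
    // behavior, so it must never be reached without a value.
    if (!indices_all_valid_per_q.has_value()) {
        throw std::invalid_argument(
            "indices_all_valid_per_q must be provided when indices is provided");
    }
    // From here on, indices_all_valid_per_q-> is safe to use.
    return static_cast<int>(indices->size());
}
```

In the dense case both arguments may be empty and no dereference happens; in the sparse case a missing mask is reported as an error rather than read.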

Comment on lines +161 to +163
if (is_sparse_attn){
TORCH_CHECK(num_heads_ori >= 64 || seqlen_q_ori == 1, "sparse decoding head q must greter than 64 when seqlen q > 1");
}

medium

Fix the typo "greter" to "greater" and make the error message grammatically correct. Because the check is num_heads_ori >= 64, the message should say "at least 64" rather than "greater than 64".

    if (is_sparse_attn){
        TORCH_CHECK(num_heads_ori >= 64 || seqlen_q_ori == 1, "sparse decoding head q must be at least 64 when seqlen q > 1");
    }

Comment on lines +374 to +383
CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
if (is_sparse_attn)
TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");

CHECK_DEVICE(seqlens_k);
TORCH_CHECK(seqlens_k.is_contiguous());
TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);

medium

The checks on seqlens_k on lines 374-376 and lines 380-382 are identical, so the second block is redundant and can be removed.

    CHECK_DEVICE(seqlens_k);
    TORCH_CHECK(seqlens_k.is_contiguous());
    TORCH_CHECK(seqlens_k.dtype() == torch::kInt32);
    if (is_sparse_attn)
        TORCH_CHECK(h_q.has_value(), "num_heads_q must be provided when topk is provided");
