add block-sparse fp8 prefill attention (dim=128/192, two quant schemes)#45
Open
bcacdwk wants to merge 1 commit into
Open
add block-sparse fp8 prefill attention (dim=128/192, two quant schemes)#45bcacdwk wants to merge 1 commit into
bcacdwk wants to merge 1 commit into
Conversation
There was a problem hiding this comment.
Pull request overview
Adds FP8 prefill attention support for additional sparsity/layout/quantization combinations, including (1) paged KV-cache block-sparse prefill for dim=128 and (2) varlen block-sparse prefill for dim=192, and updates the Python/C++ entrypoints + tests to exercise the new quant_type dispatch.
Changes:
- Split FP8 paged-KV prefill into two quantization schemes selected by
quant_type, and updated the C++/Python bindings accordingly. - Added new CUDA warp-specialization kernels for block-sparse paged-KV FP8 (dim=128) and block-sparse varlen FP8 (dim=192).
- Reworked/expanded test coverage for the new kernels and quantization modes; removed the old single FP8 KV-cache prefill test.
Reviewed changes
Copilot reviewed 22 out of 23 changed files in this pull request and generated 7 comments.
Show a summary per file
| File | Description |
|---|---|
| tests/test_fuse_moe_pertensor.py | Updates the naive reference GEMM path used by MoE tests. |
| tests/test_attention_with_kvcache_qpertoken_perhead_kvpertensor_prefill_fp8.py | New tests for paged KV-cache FP8 prefill (dim=128) using per-tensor K/V scale scheme. |
| tests/test_attention_with_kvcache_qkpertoken_perhead_vperhead_prefill_fp8.py | New tests for paged KV-cache FP8 prefill (dim=128) using per-token/per-head K scale + per-head V scale scheme. |
| tests/test_attention_with_kvcache_prefill_fp8.py | Removes the previous unified FP8 KV-cache prefill test (replaced by scheme-specific tests). |
| tests/test_attention_blocksparse_qpertoken_perhead_kvpertensor_fp8.py | New block-sparse paged-KV FP8 prefill tests for the per-tensor K/V scheme. |
| tests/test_attention_blocksparse_qkpertoken_perhead_vperhead_fp8.py | New block-sparse paged-KV FP8 prefill tests for the per-token/per-head K + per-head V scheme. |
| tests/test_attention_blocksparse_prefill_fp8_dim192.py | New varlen (dim=192/128) FP8 block-sparse prefill tests. |
| src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h | Splits/renames FP8 KV-cache dim128 entrypoints by quantization scheme and updates stride parameters. |
| src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu | Implements scheme-specific FP8 KV-cache dim128 launch paths and updated tensor strides/TMA wiring. |
| src/attention/prefill/warp_spec_with_kvcache_dim128.h | Extends BF16 KV-cache dim128 interface to accept explicit K/V strides. |
| src/attention/prefill/warp_spec_with_kvcache_dim128.cu | Uses explicit K/V strides when constructing Cute tensors for BF16 KV-cache prefill. |
| src/attention/prefill/warp_spec_with_kvcache_blocksparse_fp8_dim128.h | New block-sparse paged-KV FP8 dim128 kernel entrypoints. |
| src/attention/prefill/warp_spec_with_kvcache_blocksparse_fp8_dim128.cu | New block-sparse paged-KV FP8 dim128 warp-spec kernel launchers (masked/unmasked). |
| src/attention/prefill/warp_spec_blocksparse_fp8_dim192.h | New block-sparse FP8 dim192 warp-spec entrypoint for varlen attention. |
| src/attention/prefill/warp_spec_blocksparse_fp8_dim192.cu | New block-sparse FP8 dim192 kernel launcher with optional block mask support. |
| src/attention/prefill/prefill.h | Adds new async APIs for scheme-specific FP8 KV-cache prefill and new blocksparse prefill variants. |
| src/attention/prefill/prefill.cc | Wires new scheme-specific FP8 + block-sparse prefill dispatch to warp-spec implementations. |
| src/attention/prefill/multi_stage_with_kvcache_dim128.h | Extends multi-stage BF16 KV-cache dim128 interface to accept explicit strides. |
| src/attention/prefill/multi_stage_with_kvcache_dim128.cu | Uses explicit K/V strides in multi-stage BF16 KV-cache dim128 kernel launch. |
| src/attention/prefill/config.h | Adds/renames config structs to support new FP8 KV-cache quantization modes and dim192 prefill. |
| src/attention/entry.cc | Updates Torch entrypoints: output validation, quant_type plumbing, and new blocksparse APIs. |
| hpc/attention.py | Adds QuantType, updates wrappers for new quant_type signature, and adds new blocksparse APIs + fake impls. |
Comments suppressed due to low confidence (1)
src/attention/prefill/config.h:467
- Type name appears to have a typo:
AttentionKVCachePrefillQKTokenVTenorFp8Config("Tenor"). Renaming to "Tensor" would improve clarity and avoid propagating the typo into more call sites.
template <typename Tin_, typename Tout_, typename TiledMmaQKAtom, typename TiledMmaPVAtom,
int kTileM_, int kTileN_, int kTileK_, int kTileV_, int kTileS_, int kBlockSize_,
int kScaleBlockSize_, int kStage_, int kWarpgroupM_ = 2, int kWarpgroupN_ = 1,
int kSwizzleQ = 128, int kSwizzleK = 128, int kSwizzleV = 128, int kSwizzleY = 128>
struct AttentionKVCachePrefillQKTokenVTenorFp8Config {
using Tin = Tin_;
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| .unsqueeze(0) | ||
| ) # (1, 1, num_seq_q, num_seq_kv) | ||
| else: | ||
| causal_mask = causal_mask.view(1, 1, num_seq_q, num_seq_kv) |
| .unsqueeze(0) | ||
| ) # (1, 1, num_seq_q, num_seq_kv) | ||
| else: | ||
| causal_mask = causal_mask.view(1, 1, num_seq_q, num_seq_kv) |
Comment on lines
+1
to
+8
| from enum import Enum | ||
| from typing import Optional | ||
|
|
||
| import torch | ||
| from torch import Tensor | ||
|
|
||
|
|
||
| class QuantType(Enum): |
Comment on lines
+11
to
+12
| QPERTENSOR_KPERTENSOR_VPERTENSOR = 2 | ||
| QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD_QKHADAMARD = 3 |
| expected_max_num_tile_m, ", Kb] where Kb = ceil(max_kv_len / kTileN=", kTileN, ")"); | ||
| num_tile_kv_in_mask = block_mask_tensor.size(3); | ||
|
|
||
| TORCH_CHECK(num_tile_kv_in_mask > 0, "block_mask Kb dim must be > 0"); |
| int kWarpgroupM_ = 2, int kWarpgroupN_ = 1, int kSwizzleQ = 128, int kSwizzleK = 128, | ||
| int kSwizzleV = 128, int kSwizzleY = 128> | ||
| struct AttentionKVCachePrefillFp8Config { | ||
| struct AttentionKVCachePrefillQTokenKVTenorFp8Config { |
Comment on lines
+167
to
+169
| TORCH_CHECK((quant_type == 0 || quant_type == 1), "quant_type only support 0/1"); | ||
| TORCH_CHECK((kscale.dtype().itemsize() == 4 || kscale.dtype().itemsize() == 1), | ||
| "kscale dtype must be float or fp8"); |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
quant_typeparameter:QPERTOKEN_PERHEAD_KPERTENSOR_VPERTENSOR(default)QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEADTest plan
make wheelbuilds clean on H20 / CUDA 13.0 / torch 2.10test_attention_*.pypass (404 cases)pytest tests/regression: 623 passed