add block-sparse fp8 prefill attention (dim=128/192, two quant schemes) #45

Open
bcacdwk wants to merge 1 commit into Tencent:main from bcacdwk:block-sparse-attention

bcacdwk commented May 25, 2026

Summary

  • Block-sparse paged-KV prefill attention for FP8, dim=128 (with kvcache)
  • Block-sparse varlen prefill attention for FP8, dim=192 (e.g. MLA)
  • Two quantization schemes via quant_type parameter:
    • QPERTOKEN_PERHEAD_KPERTENSOR_VPERTENSOR (default)
    • QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD
  • Updated existing dense FP8 prefill kernels to support both quant schemes.
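The scheme names encode the granularity of each FP8 scale. A minimal pure-Python sketch of the three granularities is given here for orientation only. It assumes the FP8 format is E4M3 (largest finite value 448) and that scales are derived from the absolute maximum; neither is stated in this PR, and the helper names are hypothetical, not part of `hpc/attention.py`.

```python
# Assumption: FP8 E4M3 target format, whose largest finite value is 448.
FP8_E4M3_MAX = 448.0


def amax(values):
    """Largest absolute value; a tiny floor avoids a zero scale."""
    return max(max(abs(v) for v in values), 1e-12)


def per_tensor_scale(x):
    """One scale for the whole [tokens][heads][dim] tensor (K/V in scheme 1)."""
    flat = [v for tok in x for head in tok for v in head]
    return amax(flat) / FP8_E4M3_MAX


def per_token_per_head_scale(x):
    """One scale per (token, head) pair (Q in both schemes, K in scheme 2)."""
    return [[amax(head) / FP8_E4M3_MAX for head in tok] for tok in x]


def per_head_scale(x):
    """One scale per head, shared across tokens (V in scheme 2)."""
    num_heads = len(x[0])
    return [
        amax([v for tok in x for v in tok[h]]) / FP8_E4M3_MAX
        for h in range(num_heads)
    ]


# Toy tensor: 2 tokens, 2 heads, head dim 2.
x = [
    [[1.0, -2.0], [0.5, 4.0]],
    [[8.0, 0.25], [-1.0, 1.0]],
]

print(per_tensor_scale(x))          # a single float
print(per_token_per_head_scale(x))  # nested list shaped [tokens][heads]
print(per_head_scale(x))            # list shaped [heads]
```

Finer granularity costs more scale storage but limits how far a single outlier token can degrade the precision of the others.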

Test plan

  • make wheel builds clean on H20 / CUDA 13.0 / torch 2.10
  • All 9 test_attention_*.py pass (404 cases)
  • Full pytest tests/ regression: 623 passed

Copilot AI review requested due to automatic review settings May 25, 2026 13:08

Copilot AI left a comment

Pull request overview

Adds FP8 prefill attention support for additional sparsity/layout/quantization combinations, including (1) paged KV-cache block-sparse prefill for dim=128 and (2) varlen block-sparse prefill for dim=192, and updates the Python/C++ entrypoints + tests to exercise the new quant_type dispatch.

Changes:

  • Split FP8 paged-KV prefill into two quantization schemes selected by quant_type, and updated the C++/Python bindings accordingly.
  • Added new CUDA warp-specialization kernels for block-sparse paged-KV FP8 (dim=128) and block-sparse varlen FP8 (dim=192).
  • Reworked/expanded test coverage for the new kernels and quantization modes; removed the old single FP8 KV-cache prefill test.
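To make the block-sparse semantics concrete, the sketch below expands a tile-level mask into the equivalent token-level mask. It assumes the common convention that query token `i` and KV token `j` may interact only when the tile containing them is enabled. The function name and tile sizes are hypothetical, and a real kernel would skip disabled tiles rather than materialize this mask.

```python
def expand_block_mask(block_mask, q_len, kv_len, tile_m, tile_n):
    """Expand a [num_tile_m][num_tile_kv] tile mask to a [q_len][kv_len] token mask.

    Token (i, j) is kept iff the tile (i // tile_m, j // tile_n) is enabled.
    """
    return [
        [bool(block_mask[i // tile_m][j // tile_n]) for j in range(kv_len)]
        for i in range(q_len)
    ]


# Toy case: 4 query tokens, 4 KV tokens, 2x2 tiles.
# The first query tile skips the second KV tile.
block_mask = [
    [1, 0],
    [1, 1],
]
token_mask = expand_block_mask(block_mask, q_len=4, kv_len=4, tile_m=2, tile_n=2)

# Fraction of (query, key) pairs that still need to be computed.
kept = sum(sum(row) for row in token_mask) / (4 * 4)
print(kept)
```

A dense reference attention with this token mask applied before the softmax is one way to check a block-sparse kernel's output in tests.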

Reviewed changes

Copilot reviewed 22 out of 23 changed files in this pull request and generated 7 comments.

| File | Description |
| --- | --- |
| tests/test_fuse_moe_pertensor.py | Updates the naive reference GEMM path used by MoE tests. |
| tests/test_attention_with_kvcache_qpertoken_perhead_kvpertensor_prefill_fp8.py | New tests for paged KV-cache FP8 prefill (dim=128) using per-tensor K/V scale scheme. |
| tests/test_attention_with_kvcache_qkpertoken_perhead_vperhead_prefill_fp8.py | New tests for paged KV-cache FP8 prefill (dim=128) using per-token/per-head K scale + per-head V scale scheme. |
| tests/test_attention_with_kvcache_prefill_fp8.py | Removes the previous unified FP8 KV-cache prefill test (replaced by scheme-specific tests). |
| tests/test_attention_blocksparse_qpertoken_perhead_kvpertensor_fp8.py | New block-sparse paged-KV FP8 prefill tests for the per-tensor K/V scheme. |
| tests/test_attention_blocksparse_qkpertoken_perhead_vperhead_fp8.py | New block-sparse paged-KV FP8 prefill tests for the per-token/per-head K + per-head V scheme. |
| tests/test_attention_blocksparse_prefill_fp8_dim192.py | New varlen (dim=192/128) FP8 block-sparse prefill tests. |
| src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.h | Splits/renames FP8 KV-cache dim128 entrypoints by quantization scheme and updates stride parameters. |
| src/attention/prefill/warp_spec_with_kvcache_fp8_dim128.cu | Implements scheme-specific FP8 KV-cache dim128 launch paths and updated tensor strides/TMA wiring. |
| src/attention/prefill/warp_spec_with_kvcache_dim128.h | Extends BF16 KV-cache dim128 interface to accept explicit K/V strides. |
| src/attention/prefill/warp_spec_with_kvcache_dim128.cu | Uses explicit K/V strides when constructing Cute tensors for BF16 KV-cache prefill. |
| src/attention/prefill/warp_spec_with_kvcache_blocksparse_fp8_dim128.h | New block-sparse paged-KV FP8 dim128 kernel entrypoints. |
| src/attention/prefill/warp_spec_with_kvcache_blocksparse_fp8_dim128.cu | New block-sparse paged-KV FP8 dim128 warp-spec kernel launchers (masked/unmasked). |
| src/attention/prefill/warp_spec_blocksparse_fp8_dim192.h | New block-sparse FP8 dim192 warp-spec entrypoint for varlen attention. |
| src/attention/prefill/warp_spec_blocksparse_fp8_dim192.cu | New block-sparse FP8 dim192 kernel launcher with optional block mask support. |
| src/attention/prefill/prefill.h | Adds new async APIs for scheme-specific FP8 KV-cache prefill and new blocksparse prefill variants. |
| src/attention/prefill/prefill.cc | Wires new scheme-specific FP8 + block-sparse prefill dispatch to warp-spec implementations. |
| src/attention/prefill/multi_stage_with_kvcache_dim128.h | Extends multi-stage BF16 KV-cache dim128 interface to accept explicit strides. |
| src/attention/prefill/multi_stage_with_kvcache_dim128.cu | Uses explicit K/V strides in multi-stage BF16 KV-cache dim128 kernel launch. |
| src/attention/prefill/config.h | Adds/renames config structs to support new FP8 KV-cache quantization modes and dim192 prefill. |
| src/attention/entry.cc | Updates Torch entrypoints: output validation, quant_type plumbing, and new blocksparse APIs. |
| hpc/attention.py | Adds QuantType, updates wrappers for new quant_type signature, and adds new blocksparse APIs + fake impls. |

Comments suppressed due to low confidence (1)

src/attention/prefill/config.h:467

  • Type name appears to have a typo: AttentionKVCachePrefillQKTokenVTenorFp8Config ("Tenor"). Renaming to "Tensor" would improve clarity and avoid propagating the typo into more call sites.
template <typename Tin_, typename Tout_, typename TiledMmaQKAtom, typename TiledMmaPVAtom,
          int kTileM_, int kTileN_, int kTileK_, int kTileV_, int kTileS_, int kBlockSize_,
          int kScaleBlockSize_, int kStage_, int kWarpgroupM_ = 2, int kWarpgroupN_ = 1,
          int kSwizzleQ = 128, int kSwizzleK = 128, int kSwizzleV = 128, int kSwizzleY = 128>
struct AttentionKVCachePrefillQKTokenVTenorFp8Config {
  using Tin = Tin_;

.unsqueeze(0)
) # (1, 1, num_seq_q, num_seq_kv)
else:
causal_mask = causal_mask.view(1, 1, num_seq_q, num_seq_kv)
Comment thread hpc/attention.py
Comment on lines +1 to +8
from enum import Enum
from typing import Optional

import torch
from torch import Tensor


class QuantType(Enum):
Comment thread hpc/attention.py
Comment on lines +11 to +12
QPERTENSOR_KPERTENSOR_VPERTENSOR = 2
QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD_QKHADAMARD = 3
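For context, the sketch below shows how a Python-side guard could mirror the `quant_type == 0 || quant_type == 1` check in `src/attention/entry.cc`. The values 2 and 3 come from the excerpt above. The values 0 and 1 for the two schemes named in the PR summary are assumptions, since those lines are not shown, and `check_quant_type` is a hypothetical helper rather than an API of this repository.

```python
from enum import Enum


class QuantType(Enum):
    # Assumed values: the PR excerpt does not show the first two members.
    QPERTOKEN_PERHEAD_KPERTENSOR_VPERTENSOR = 0
    QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD = 1
    # Values taken from the reviewed lines of hpc/attention.py.
    QPERTENSOR_KPERTENSOR_VPERTENSOR = 2
    QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD_QKHADAMARD = 3


# Schemes the C++ entrypoint check accepts (quant_type 0 or 1).
SUPPORTED = {
    QuantType.QPERTOKEN_PERHEAD_KPERTENSOR_VPERTENSOR,
    QuantType.QPERTOKEN_PERHEAD_KPERTOKEN_PERHEAD_VPERHEAD,
}


def check_quant_type(quant_type, kscale_itemsize):
    """Hypothetical guard: fail early in Python instead of inside TORCH_CHECK."""
    if quant_type not in SUPPORTED:
        raise ValueError(f"quant_type {quant_type.name} is not supported by this kernel")
    if kscale_itemsize not in (4, 1):
        raise ValueError("kscale dtype must be float (4 bytes) or fp8 (1 byte)")
    return quant_type.value  # the integer passed down to the C++ entrypoint


print(check_quant_type(QuantType.QPERTOKEN_PERHEAD_KPERTENSOR_VPERTENSOR, 4))
```

Validating in Python gives the caller an error that names the enum member rather than a bare integer.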
Comment thread src/attention/entry.cc
expected_max_num_tile_m, ", Kb] where Kb = ceil(max_kv_len / kTileN=", kTileN, ")");
num_tile_kv_in_mask = block_mask_tensor.size(3);

TORCH_CHECK(num_tile_kv_in_mask > 0, "block_mask Kb dim must be > 0");
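The shape rule in the check above, `Kb = ceil(max_kv_len / kTileN)`, can be reproduced on the caller side before building a block mask. This is a minimal sketch: the meaning of the two leading mask dimensions (taken here to be batch and heads), the exact-match comparison, and the tile sizes in the example are assumptions, and the helper names are hypothetical.

```python
def ceil_div(a, b):
    """Integer ceiling division without going through floats."""
    return -(-a // b)


def check_block_mask_shape(mask_shape, max_q_len, max_kv_len, tile_m, tile_n):
    """Return the expected (num_tile_m, Kb) or raise if mask_shape disagrees.

    Assumes a 4-D mask whose last two dims are (num_tile_m, Kb); the first two
    dims (assumed batch and heads) are not checked here.
    """
    if len(mask_shape) != 4:
        raise ValueError("block_mask must be 4-D")
    expected_m = ceil_div(max_q_len, tile_m)
    expected_kb = ceil_div(max_kv_len, tile_n)
    num_tile_m, num_tile_kv = mask_shape[2], mask_shape[3]
    if num_tile_kv <= 0:
        raise ValueError("block_mask Kb dim must be > 0")
    if (num_tile_m, num_tile_kv) != (expected_m, expected_kb):
        raise ValueError(
            f"block_mask tiles {(num_tile_m, num_tile_kv)} != expected {(expected_m, expected_kb)}"
        )
    return expected_m, expected_kb


# Example with hypothetical tile sizes: 300 query tokens, 1000 KV tokens.
print(check_block_mask_shape((2, 8, 5, 8), 300, 1000, tile_m=64, tile_n=128))
```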
int kWarpgroupM_ = 2, int kWarpgroupN_ = 1, int kSwizzleQ = 128, int kSwizzleK = 128,
int kSwizzleV = 128, int kSwizzleY = 128>
struct AttentionKVCachePrefillFp8Config {
struct AttentionKVCachePrefillQTokenKVTenorFp8Config {
Comment thread src/attention/entry.cc
Comment on lines +167 to +169
TORCH_CHECK((quant_type == 0 || quant_type == 1), "quant_type only support 0/1");
TORCH_CHECK((kscale.dtype().itemsize() == 4 || kscale.dtype().itemsize() == 1),
"kscale dtype must be float or fp8");
2 participants