feat(e8p): compressed batched tensor-core GEMM for B>1 decode #15
Merged
The fused E8P linear's B>1 branch fully decompressed every weight to dense fp16 on each step, then ran a dense matmul against only B token-vectors; it was 17x behind bf16 at B=32 because the decompress never amortized. Add glq_decode_matmul_e8p_kernel, which packs up to 8 tokens into the decode GEMV's otherwise-idle mma.sync n-columns and bit-expands each weight tile once, reusing it across the whole batch.

The fragment layout is cross-checked against glq_decompress_packed_e8p_kernel (the ground-truth weight->memory layout): mma-row g (w0x) -> output row 2g, row g+8 (w1x) -> 2g+1; the per-token B-fragment is 8 consecutive uint32 at tok[32*iik + 8t + r].

- glq_matmul_e8p host launcher: n-tiled by ceil(B/8), zero-pad + OOB guard on the trailing partial tile
- wire into glq_fused_linear_e8p_cuda's B>1 branch; the dense decompress path is kept behind env GLQ_E8P_DENSE_B1 as a guarded fallback / A-B ref
- covers 2bpw (stage-1), 4bpw (kernel run twice), and 3bpw (batch tiled through the existing lookupmatmul_e81b_k8 instead of dense decompress)
- expose glq_matmul_e8p pybind for parity testing

Validated on gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000: B=32 decode 512tok 2.5 -> 8.0 per-seq tok/s (3.2x, 79.8 -> 254.9 aggregate); B=1 24.3 tok/s (no regression, still the untouched GEMV); synthetic parity rel 1e-7 vs both dense-decompress and per-token GEMV across B=1..32; 5-prompt coherence smoke passes.

Kernel uses atomicAdd cross-warp reduction, matching the existing E8P GEMV (run-to-run det_max ~5e-5).
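The fragment layout stated in this commit message can be written down as two small index helpers. This is only a host-side sketch in plain C++: the function names are hypothetical and do not come from the PR, and `iik`, `t`, and `r` are treated as opaque indices (the message only says that `r` covers 8 consecutive `uint32` values).

```cpp
#include <cstddef>

// Hypothetical helper (not from the PR): map an mma accumulator row (0..15)
// to its output row inside a 16-row tile, following the stated layout:
// mma-row g (w0x) -> output row 2g, mma-row g+8 (w1x) -> output row 2g+1.
constexpr int mma_row_to_output_row(int mma_row) {
    return mma_row < 8 ? 2 * mma_row : 2 * (mma_row - 8) + 1;
}

// Hypothetical helper (not from the PR): offset of the r-th uint32 of a
// token's B-fragment, i.e. the index expression tok[32*iik + 8t + r].
constexpr std::size_t b_fragment_offset(std::size_t iik, std::size_t t,
                                        std::size_t r) {
    return 32 * iik + 8 * t + r;
}

// Compile-time spot checks of the mapping.
static_assert(mma_row_to_output_row(0) == 0, "row 0 (w0x) -> output row 0");
static_assert(mma_row_to_output_row(8) == 1, "row 8 (w1x) -> output row 1");
static_assert(mma_row_to_output_row(15) == 15, "row 15 (w1x) -> output row 15");
static_assert(b_fragment_offset(0, 3, 7) == 31, "last uint32 of the first 32");
```

Under this reading, the two 8-row halves of the accumulator interleave into even and odd output rows, so each of the 16 output rows is produced exactly once per tile.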
The batched kernel reduced across its warps with atomicAdd, so batched decode was not bit-exact run-to-run (det_max ~5e-5), unlike GLQ's shell kernels. Replace it with the same scratch+reduce pattern: E8P_MM_WARPS (8) warps split K and each writes its own scratch plane (scratch[warpId*B*N + token*N + row]; no atomics, since planes are disjoint by warpId, rows are disjoint by block, and every slot is written once), then a local glq_reduce_splits_2d_e8p_kernel sums the planes in fixed order. The reduce is kept in this TU since the build has no -rdc=true. Dropping the K-split warp count from 32 to 8 bounds the scratch to NW*B*N (so it stays below the packed-weight traffic at B=32) while keeping enough K parallelism.

gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000: parity det_max=0 (bit-exact) across B=1..32 with rel still 1e-7 vs dense and the GEMV; B=32 decode 8.0 -> 7.8 per-seq (-2.5% for full determinism); B=1 24.0 (unchanged, still the GEMV).

Also validated the 4bpw two-E8P-stage batched path end-to-end on a SmolLM2-360M e8p-4bpw checkpoint (batched vs dense rel 3.2e-3, identical greedy, coherent).
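The scratch+reduce pattern described in this commit can be sketched on the host in plain C++. This is an illustration, not the kernel itself: the function names are hypothetical, the planes are summed in ascending warp order as an assumption (the message only says "fixed order"), and `std::vector` stands in for device memory.

```cpp
#include <cstddef>
#include <vector>

// Slot for one warp's partial sum, using the index expression from the
// commit message: scratch[warpId*B*N + token*N + row].
inline std::size_t scratch_index(int warp_id, int token, int row, int B, int N) {
    return (static_cast<std::size_t>(warp_id) * B + token) * N + row;
}

// Hypothetical host-side stand-in for the reduce pass: sum `num_planes`
// planes of `BN` floats each. Planes are always visited in ascending
// order, so the floating-point result is identical on every run.
std::vector<float> reduce_planes_fixed_order(const std::vector<float>& scratch,
                                             int num_planes, std::size_t BN) {
    std::vector<float> out(BN, 0.0f);
    for (std::size_t i = 0; i < BN; ++i) {
        float acc = 0.0f;
        for (int w = 0; w < num_planes; ++w) {
            acc += scratch[static_cast<std::size_t>(w) * BN + i];
        }
        out[i] = acc;
    }
    return out;
}
```

Because every (warp, token, row) triple maps to a distinct slot, no two writers can collide, which is why the write phase needs no atomics. For the B=1 GEMV the same reduce applies with `BN = N`.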
The B=1 GEMV reduced across its 32 warps with atomicAdd, the last E8P path that was not deterministic run-to-run. Apply the same scratch+reduce treatment as the batched kernel: keep the 32 warps (scratch is only 32*N floats, tiny for B=1, and 32 warps preserve the K-parallelism the latency-bound matvec needs), each writes its own plane scratch[warpId*N + row] (no atomics), then the shared local glq_reduce_splits_2d_e8p_kernel (moved above the GEMV) sums the planes in fixed order with BN=N. The reduce is one extra kernel inside glq_decode_matvec_e8p (not a new torch.op), so it's captured in the fused-op FULL cudagraph.

gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000: GEMV run-to-run det_max=0 (bit-exact) across B=1..32 with rel still 1e-7 vs dense; B=1 decode 24.0 -> 23.9 tok/s (no regression, since the reduce is captured in the graph); 5-prompt coherence smoke passes.

The whole GLQ E8P + shell kernel surface is now deterministic.
This was referenced Jun 22, 2026
Merged
Summary
Fixes the B>1 (batched) E8P decode cliff. The fused E8P linear's B>1 branch fully decompressed every weight to dense fp16 on each step, then ran a dense matmul against only B token-vectors — the decompress never amortized, leaving E8P ~17× behind bf16 at B=32.
Adds glq_decode_matmul_e8p_kernel: it packs up to 8 tokens into the decode GEMV's otherwise-idle mma.sync.m16n8k16 n-columns and bit-expands each weight tile once, reusing it across the whole batch. At B=1 only the n=0 column carries output; n=1..7 are free for the batch.

The fragment layout was derived and cross-checked against glq_decompress_packed_e8p_kernel (the ground-truth weight→memory layout): mma-row g (w0x) → output row 2g, row g+8 (w1x) → 2g+1; the per-token B-fragment is 8 consecutive uint32 at tok[32*iik + 8t + r] (replacing the GEMV's shuffle-broadcast).

Changes
- glq_e8p.cu: new glq_decode_matmul_e8p_kernel + glq_matmul_e8p host launcher (n-tiled by ceil(B/8), zero-pad + OOB guard on the trailing partial tile).
- glq_cuda.cu: wire into glq_fused_linear_e8p_cuda's B>1 branch. The dense decompress path is kept behind env GLQ_E8P_DENSE_B1 as a guarded fallback / A-B reference. Covers 2bpw (stage-1), 4bpw (kernel run twice), 3bpw (batch tiled through the existing lookupmatmul_e81b_k8 instead of dense decompress).
- glq_bindings.cpp: expose glq_matmul_e8p for parity testing.

The op signature, custom-op registration, and Python wiring are unchanged; only the C++ internals of the B>1 path change, so the fused-cudagraph integration is untouched.
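The n-tiling by ceil(B/8) with zero-padding and an out-of-bounds guard on the trailing partial tile can be illustrated with a small host-side sketch. Everything here is an assumption made for illustration: the helper names, the row-major `x[B][K]` and `y[B][N]` layouts, and the use of `std::vector` in place of device buffers are not taken from the PR.

```cpp
#include <cstddef>
#include <vector>

constexpr int kTokensPerTile = 8;  // the 8 n-columns of one mma tile

// Number of n-tiles needed to cover B tokens: ceil(B / 8).
constexpr int num_n_tiles(int B) {
    return (B + kTokensPerTile - 1) / kTokensPerTile;
}

// Copy the activations of tile `tile` from x[B][K] into an 8 x K buffer.
// Columns whose token index is >= B are left at zero (the zero-pad).
std::vector<float> load_tile_zero_padded(const std::vector<float>& x,
                                         int B, int K, int tile) {
    std::vector<float> frag(static_cast<std::size_t>(kTokensPerTile) * K, 0.0f);
    for (int n = 0; n < kTokensPerTile; ++n) {
        const int token = tile * kTokensPerTile + n;
        if (token >= B) continue;  // padded column stays zero
        for (int k = 0; k < K; ++k) {
            frag[static_cast<std::size_t>(n) * K + k] =
                x[static_cast<std::size_t>(token) * K + k];
        }
    }
    return frag;
}

// Write the tile's accumulators acc[8][N] into y[B][N], skipping padded
// columns so nothing is stored past the end of y (the OOB guard).
void store_tile_guarded(const std::vector<float>& acc, std::vector<float>& y,
                        int B, int N, int tile) {
    for (int n = 0; n < kTokensPerTile; ++n) {
        const int token = tile * kTokensPerTile + n;
        if (token >= B) continue;  // guard: no write for padded tokens
        for (int row = 0; row < N; ++row) {
            y[static_cast<std::size_t>(token) * N + row] =
                acc[static_cast<std::size_t>(n) * N + row];
        }
    }
}
```

With B=10, for example, the second tile holds two real tokens and six zero columns; the zeros contribute nothing to the product, and the guard keeps the six padded results from being written.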
Validation (gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000 Blackwell)
This does not beat bf16 on a 96 GB GPU (GLQ is compute/L2-bound there); the win is removing the 17× pathology so that batched E8P serving is viable, plus the footprint advantage (6.25 vs 24 GB).
Notes / follow-ups
The kernel uses atomicAdd cross-warp reduction, matching the existing E8P GEMV (run-to-run det_max ~5e-5). A scratch+reduce pass would make it bit-exact if E8P determinism is wanted.
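The run-to-run drift comes from floating-point addition not being associative: an atomicAdd reduction accumulates partial sums in whatever order the warps happen to arrive, while a fixed-order pass always uses the same order. A minimal host-side illustration follows; the helper name and the values are made up for the example and are not from the kernels.

```cpp
#include <vector>

// Hypothetical helper: accumulate `partials` in the order given by `order`,
// mimicking warps whose atomicAdd calls land in different sequences.
float sum_in_order(const std::vector<float>& partials,
                   const std::vector<int>& order) {
    float acc = 0.0f;
    for (int i : order) {
        acc += partials[i];
    }
    return acc;
}
```

For the partials `{1e8f, 1.0f, -1e8f}`, the order `{0, 1, 2}` absorbs the `1.0f` into `1e8f` before the cancellation, whereas `{0, 2, 1}` cancels first and keeps it, so on IEEE-754 single precision the two orders should give different sums from the same inputs. Fixing the order removes that source of variation.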