feat(e8p): compressed batched tensor-core GEMM for B>1 decode #15

Merged: cnygaard merged 3 commits into main from feat/e8p-batched-gemm on Jun 22, 2026

@cnygaard

Owner

Summary

Fixes the B>1 (batched) E8P decode cliff. The fused E8P linear's B>1 branch fully decompressed every weight to dense fp16 on each step, then ran a dense matmul against only B token-vectors — the decompress never amortized, leaving E8P ~17× behind bf16 at B=32.

Adds glq_decode_matmul_e8p_kernel: it packs up to 8 tokens into the decode GEMV's otherwise-idle mma.sync.m16n8k16 n-columns and bit-expands each weight tile once, reusing it across the whole batch. At B=1 only the n=0 column carries output; n=1..7 are free for the batch.

The fragment layout was derived and cross-checked against glq_decompress_packed_e8p_kernel (the ground-truth weight→memory layout): mma-row g (w0x) → output row 2g, row g+8 (w1x) → 2g+1; the per-token B-fragment is 8 consecutive uint32 at tok[32*iik + 8t + r] (replacing the GEMV's shuffle-broadcast).
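
The index mappings quoted above can be written down as a small host-side check. This is only a sketch of the arithmetic as stated in this description, not the kernel code: the function names are hypothetical, and the ranges t in [0,4) and r in [0,8) are an assumption (the text gives only the formula).

```cpp
#include <array>
#include <cassert>
#include <cstdint>

// Row interleave as stated in the description: within a 16-row tile,
// mma row g (0..7) holds output row 2g and mma row g+8 holds output row 2g+1.
constexpr int mma_row_to_output_row(int mma_row) {
    return mma_row < 8 ? 2 * mma_row : 2 * (mma_row - 8) + 1;
}

// Offset of one uint32 of the per-token B-fragment, copied from the formula
// tok[32*iik + 8t + r]. Assumption: t is in [0,4) and r is in [0,8).
constexpr int b_fragment_offset(int iik, int t, int r) {
    return 32 * iik + 8 * t + r;
}

// True if the 16 mma rows land on the 16 output rows exactly once each.
inline bool row_map_is_permutation() {
    std::array<int, 16> hits{};
    for (int m = 0; m < 16; ++m) hits[mma_row_to_output_row(m)]++;
    for (int h : hits) {
        if (h != 1) return false;
    }
    return true;
}
```

Under the assumed ranges, the 32 (t, r) pairs cover offsets 32*iik through 32*iik + 31 without gaps, which is consistent with each thread reading 8 consecutive words.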

Changes

  • glq_e8p.cu: new glq_decode_matmul_e8p_kernel + glq_matmul_e8p host launcher (n-tiled by ceil(B/8), zero-pad + OOB guard on the trailing partial tile).
  • glq_cuda.cu: wire into glq_fused_linear_e8p_cuda's B>1 branch. The dense decompress path is kept behind env GLQ_E8P_DENSE_B1 as a guarded fallback / A-B reference. Covers 2bpw (stage-1), 4bpw (kernel run twice), 3bpw (batch tiled through the existing lookupmatmul_e81b_k8 instead of dense decompress).
  • glq_bindings.cpp: expose glq_matmul_e8p for parity testing.
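
As a rough illustration of the launcher's n-tiling (ceil(B/8) tiles, zero-padding, and an out-of-bounds guard on the last partial tile), here is a CPU reference sketch. It is not the CUDA launcher: `num_n_tiles` and `tiled_matmul_reference` are hypothetical names, the weights are plain dense floats rather than packed E8P codes, and the row-major layouts are assumptions chosen for readability.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

constexpr int kTokensPerTile = 8;  // n-columns of one m16n8k16 tile

// Number of 8-token tiles needed to cover a batch of B tokens: ceil(B / 8).
constexpr int num_n_tiles(int B) {
    return (B + kTokensPerTile - 1) / kTokensPerTile;
}

// y[b*N + n] = sum_k W[n*K + k] * x[b*K + k], walking the batch in tiles of 8.
// Each weight is read once per tile and reused for every token in that tile.
std::vector<float> tiled_matmul_reference(const std::vector<float>& W,
                                          const std::vector<float>& x,
                                          int B, int N, int K) {
    std::vector<float> y(static_cast<std::size_t>(B) * N, 0.0f);
    for (int tile = 0; tile < num_n_tiles(B); ++tile) {
        for (int n = 0; n < N; ++n) {
            float acc[kTokensPerTile] = {};
            for (int k = 0; k < K; ++k) {
                const float w = W[n * K + k];  // loaded once, shared by 8 columns
                for (int c = 0; c < kTokensPerTile; ++c) {
                    const int b = tile * kTokensPerTile + c;
                    const float xv = (b < B) ? x[b * K + k] : 0.0f;  // zero-pad
                    acc[c] += w * xv;
                }
            }
            for (int c = 0; c < kTokensPerTile; ++c) {
                const int b = tile * kTokensPerTile + c;
                if (b < B) y[b * N + n] = acc[c];  // OOB guard on partial tile
            }
        }
    }
    return y;
}
```

With B=9 the sketch runs two tiles, and the second tile carries one real token and seven zero-padded columns whose results are never stored.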

The op signature, custom-op registration, and Python wiring are unchanged — only the C++ internals of the B>1 path change, so the fused-cudagraph integration is untouched.

Validation (gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000 Blackwell)

| Gate | Result |
| --- | --- |
| synthetic parity, B=1..32 (incl. 3/7/9) | rel 1e-7 vs dense-decompress and vs per-token GEMV |
| B=32 decode (512 tok) | 2.5 → 8.0 per-seq tok/s (3.2×; 79.8 → 254.9 aggregate) |
| B=1 decode | 24.3 tok/s, no regression (still the untouched GEMV) |
| 5-prompt coherence smoke | passes (stage-1 batched + e81b batched stage-2) |
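
The B=32 figures above are related by simple arithmetic: per-sequence throughput is the aggregate divided by the batch size, and the speedup is the ratio of the two aggregates. A minimal sketch of that check, with hypothetical helper names:

```cpp
#include <cassert>
#include <cmath>

// Per-sequence tokens/s given aggregate tokens/s across `batch` sequences.
constexpr double per_seq_tok_s(double aggregate_tok_s, int batch) {
    return aggregate_tok_s / batch;
}

// Throughput ratio of the new path over the old one.
constexpr double speedup(double before_tok_s, double after_tok_s) {
    return after_tok_s / before_tok_s;
}
```

For example, `per_seq_tok_s(79.8, 32)` is about 2.49 and `per_seq_tok_s(254.9, 32)` is about 7.97, which round to the 2.5 and 8.0 reported, and `speedup(79.8, 254.9)` is about 3.19.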

This does not beat bf16 on a 96 GB GPU (GLQ is compute/L2-bound there). The win is removing the 17× pathology so that batched E8P serving is viable, plus the footprint advantage (6.25 GB vs 24 GB).

Notes / follow-ups

  • The kernel uses an atomicAdd cross-warp reduction, matching the existing E8P GEMV (run-to-run det_max ~5e-5). A scratch+reduce pass would make it bit-exact if E8P determinism is wanted.
  • 4bpw E8P stage-2 runs the same validated kernel twice; wired but not yet e2e-tested on a real 4bpw checkpoint.
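
For context on why an atomicAdd reduction is not bit-exact run-to-run: floating-point addition is not associative, so the result depends on the order in which the partial sums arrive, and atomics do not fix that order. The host-side sketch below illustrates the effect with deliberately extreme values; `sum_in_order` is a hypothetical helper, not part of the kernel.

```cpp
#include <cassert>
#include <vector>

// Accumulate the partial sums in the order given by `order`, in fp32,
// mimicking partial results arriving at an accumulator in some sequence.
float sum_in_order(const std::vector<float>& parts, const std::vector<int>& order) {
    float acc = 0.0f;
    for (int i : order) acc += parts[i];
    return acc;
}
```

With parts {1e8f, 1.0f, -1e8f}, the order 0,1,2 absorbs the 1.0f into 1e8f (the fp32 spacing there is 8) and returns 0, while the order 0,2,1 cancels the large terms first and returns 1. A fixed summation order removes this variation, which is what a scratch+reduce pass provides.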

cnygaard added 3 commits June 22, 2026 21:23
The fused E8P linear's B>1 branch fully decompressed every weight to dense
fp16 on each step, then ran a dense matmul against only B token-vectors —
17x behind bf16 at B=32 because the decompress never amortized. Add
glq_decode_matmul_e8p_kernel, which packs up to 8 tokens into the decode
GEMV's otherwise-idle mma.sync n-columns and bit-expands each weight tile
once, reusing it across the whole batch. The fragment layout is cross-checked
against glq_decompress_packed_e8p_kernel (the ground-truth weight->memory
layout): mma-row g (w0x) -> output row 2g, row g+8 (w1x) -> 2g+1; the
per-token B-fragment is 8 consecutive uint32 at tok[32*iik + 8t + r].

- glq_matmul_e8p host launcher: n-tiled by ceil(B/8), zero-pad + OOB guard
  on the trailing partial tile
- wire into glq_fused_linear_e8p_cuda's B>1 branch; the dense decompress
  path is kept behind env GLQ_E8P_DENSE_B1 as a guarded fallback / A-B ref
- covers 2bpw (stage-1), 4bpw (kernel run twice), and 3bpw (batch tiled
  through the existing lookupmatmul_e81b_k8 instead of dense decompress)
- expose glq_matmul_e8p pybind for parity testing

Validated on gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000:
B=32 decode 512tok 2.5 -> 8.0 per-seq tok/s (3.2x, 79.8 -> 254.9 aggregate);
B=1 24.3 tok/s (no regression, still the untouched GEMV); synthetic parity
rel 1e-7 vs both dense-decompress and per-token GEMV across B=1..32; 5-prompt
coherence smoke passes. Kernel uses atomicAdd cross-warp reduction, matching
the existing E8P GEMV (run-to-run det_max ~5e-5).

The batched kernel reduced across its warps with atomicAdd, so batched decode
was not bit-exact run-to-run (det_max ~5e-5) — unlike GLQ's shell kernels.
Replace it with the same scratch+reduce pattern: E8P_MM_WARPS (8) warps split K
and each writes its own scratch plane (scratch[warpId*B*N + token*N + row], no
atomics — planes disjoint by warpId, rows disjoint by block, every slot written
once), then a local glq_reduce_splits_2d_e8p_kernel sums the planes in fixed
order. The reduce is kept in this TU since the build has no -rdc=true.
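
The scratch+reduce pattern described above can be sketched on the CPU as two passes. This is an illustration under assumptions, not the kernel: the function names are hypothetical, weights are dense floats, each "warp" takes a contiguous K-slice (the real K-split may be interleaved), and K is assumed divisible by the warp count. Only the scratch indexing scratch[warpId*B*N + token*N + row] and the fixed-order sum are taken from the commit message.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Pass 1: warp w reduces its own K-slice and writes plane w of the scratch.
// Every slot is written exactly once, so no atomics are needed.
std::vector<float> split_k_scratch(const std::vector<float>& W,   // N x K
                                   const std::vector<float>& x,   // B x K
                                   int B, int N, int K, int num_warps) {
    std::vector<float> scratch(static_cast<std::size_t>(num_warps) * B * N, 0.0f);
    const int slice = K / num_warps;  // assumption: K % num_warps == 0
    for (int w = 0; w < num_warps; ++w) {
        for (int token = 0; token < B; ++token) {
            for (int row = 0; row < N; ++row) {
                float acc = 0.0f;
                for (int k = w * slice; k < (w + 1) * slice; ++k) {
                    acc += W[row * K + k] * x[token * K + k];
                }
                scratch[w * B * N + token * N + row] = acc;
            }
        }
    }
    return scratch;
}

// Pass 2: sum the planes in a fixed order (w = 0, 1, ...), so the result is
// identical on every run. BN is B*N for the batched case.
std::vector<float> reduce_planes(const std::vector<float>& scratch,
                                 int num_warps, int BN) {
    std::vector<float> out(static_cast<std::size_t>(BN), 0.0f);
    for (int i = 0; i < BN; ++i) {
        float acc = 0.0f;
        for (int w = 0; w < num_warps; ++w) acc += scratch[w * BN + i];
        out[i] = acc;
    }
    return out;
}
```

Because the summation order no longer depends on scheduling, repeated calls with the same inputs produce bitwise-identical output, at the cost of one extra pass over the scratch buffer.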

Dropping the K-split warp count from 32 to 8 bounds the scratch to NW*B*N floats
(NW = E8P_MM_WARPS = 8), so it stays below the packed-weight traffic at B=32 while
keeping enough K parallelism.
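
To make the scratch-size trade-off concrete, the sketch below compares scratch bytes with packed-weight bytes for one layer. The helper names and the 4096 x 4096 shape are hypothetical, the scratch is assumed to be fp32 (the B=1 commit describes it as floats), and the packed size is approximated as N*K*bpw/8, ignoring any codebook or scale overhead.

```cpp
#include <cassert>
#include <cstdint>

// Scratch footprint: one fp32 plane of B*N outputs per K-split warp.
constexpr std::int64_t scratch_bytes(std::int64_t num_warps, std::int64_t B,
                                     std::int64_t N) {
    return num_warps * B * N * static_cast<std::int64_t>(sizeof(float));
}

// Approximate packed weight size for an N x K layer at a given bit width.
constexpr std::int64_t packed_weight_bytes(std::int64_t N, std::int64_t K,
                                           std::int64_t bits_per_weight) {
    return N * K * bits_per_weight / 8;
}
```

For the hypothetical 4096 x 4096 layer at 3bpw and B=32, the packed weights are about 6 MiB; 8 warps need 4 MiB of scratch, which is below that, while 32 warps would need 16 MiB, which exceeds it.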

gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000:
parity det_max=0 (bit-exact) across B=1..32 with rel still 1e-7 vs dense and the
GEMV; B=32 decode 8.0 -> 7.8 per-seq (-2.5% for full determinism); B=1 24.0
(unchanged, still the GEMV). Also validated the 4bpw two-E8P-stage batched path
end-to-end on a SmolLM2-360M e8p-4bpw checkpoint (batched vs dense rel 3.2e-3,
identical greedy, coherent).

The B=1 GEMV reduced across its 32 warps with atomicAdd, the last E8P path that
was not deterministic run-to-run. Apply the same scratch+reduce treatment as the
batched kernel: keep the 32 warps (scratch is only 32*N floats, tiny for B=1, and
32 warps preserve the K-parallelism the latency-bound matvec needs), each writes
its own plane scratch[warpId*N + row] (no atomics), then the shared local
glq_reduce_splits_2d_e8p_kernel (moved above the GEMV) sums the planes in fixed
order with BN=N. The reduce is one extra kernel inside glq_decode_matvec_e8p (not
a new torch.op), so it's captured in the fused-op FULL cudagraph.

gemma-4-12B-it e8p-3bpw, vLLM 0.23 FULL cudagraph, RTX PRO 6000:
GEMV run-to-run det_max=0 (bit-exact) across B=1..32 with rel still 1e-7 vs dense;
B=1 decode 24.0 -> 23.9 tok/s (no regression — reduce captured in the graph);
5-prompt coherence smoke passes. The whole GLQ E8P + shell kernel surface is now
deterministic.
@cnygaard merged commit e2d8157 into main on Jun 22, 2026
3 checks passed