feat(e8p): block-diagonal RHT — remove pow2-padding bloat, keep TC decode #18

Open
cnygaard wants to merge 6 commits into main from feat/e8p-block-diag

@cnygaard

Owner

Summary

Makes the E8P codebook path block-diagonal instead of forcing a full power-of-2 Hadamard. Previously e8p padded every linear dim to the next power of two, which inflated models whose dims are not powers of two (gemma-4-31B "3bpw" came out at 27 GB, larger than shell-5bpw). The pow2 requirement was never real: the E8P tensor-core kernels only need K % 64 == 0 && N % 16 == 0, and the block structure lives only in the RHT. Block-diag instead pads each dim to a sum of pow2 blocks (each ≥ 64 cols / 16 rows), with near-zero padding, while the TC decode + RVQ are untouched.

gemma-4-31B e8p 3bpw: ~27 → ~11 GB (qidxs 22.5 → 10.7 GiB, 2.1×). Footprint verified exact: predicted == measured on a real SmolLM2-360M checkpoint.
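
To make the source of the bloat concrete, the padded area of a single weight matrix can be compared under the two schemes. This is only a rough sketch: `next_pow2` and `round_up` are hypothetical helpers (not names from the repo), and it assumes the gemma-4-31B dims quoted in the commit message (H=5376, I=21504) with K as the column dim and N as the row dim. It covers one H×I matrix, so it is not expected to reproduce the whole-model 2.1× figure exactly.

```python
def next_pow2(d: int) -> int:
    """Smallest power of two >= d (the legacy e8p padding rule)."""
    return 1 << (d - 1).bit_length()


def round_up(d: int, multiple: int) -> int:
    """Smallest multiple of `multiple` that is >= d."""
    return -(-d // multiple) * multiple


# Dims quoted in the PR for gemma-4-31B; treating K=H and N=I is an assumption.
H, I = 5376, 21504

pow2_area = next_pow2(H) * next_pow2(I)
# Block-diag only needs K % 64 == 0 and N % 16 == 0.
blockdiag_area = round_up(H, 64) * round_up(I, 16)
ratio = pow2_area / blockdiag_area

print(f"pow2 pad:       {next_pow2(H)} x {next_pow2(I)}")
print(f"block-diag pad: {round_up(H, 64)} x {round_up(I, 16)}")
print(f"area ratio:     {ratio:.2f}x")
```

Both dims already satisfy the tile multiples, so the block-diag padding is zero for this matrix, while the pow2 rule pads each dim by a factor of 32/21.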

What's in here

  • Quantize (hadamard.py, rht.py, quantize_model.py): _block_decompose_min(d, mb) floors blocks to the qidxs-tile multiples; RHT(e8p=True) uses it. Drops the old is_e8p → pow2 force.
  • C++ (glq_cuda.cu, glq_bindings.cpp): thin glq_input/output_rht_blockdiag_cuda host helpers (single block → full RHT, else multiblock); glq_fused_linear_e8p_cuda gains 4 block params and just swaps its two RHT calls. Decode kernels unchanged, still one fused op so FULL cudagraph still pays off.
  • Op surface + serving (custom_ops.py, linear_method.py, quantized_linear.py): block tensors threaded through the fused_linear_e8p op/fake; vLLM precomputes them as frozen constants.
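
The decomposition idea can be sketched in a few lines. This is an assumption about how such a function could work, not the code in hadamard.py: `block_decompose_min_sketch` is a hypothetical name, and it simply rounds the dim up to a multiple of `min_block` and takes the set bits of the result as pow2 blocks. That choice matches the block list quoted in this PR for n_pad 960 ([512, 256, 128, 64]), but the real `_block_decompose_min` may pick blocks differently.

```python
def block_decompose_min_sketch(d: int, min_block: int) -> list[int]:
    """Split a dim into power-of-two blocks, each >= min_block.

    Hypothetical sketch: pad d up to a multiple of min_block, then use the
    binary expansion of the padded size as the block list (largest first).
    """
    assert min_block > 0 and min_block & (min_block - 1) == 0, "min_block must be pow2"
    padded = -(-d // min_block) * min_block  # round up to a multiple of min_block
    blocks = []
    bit = 1 << (padded.bit_length() - 1)
    while bit >= min_block:
        if padded & bit:
            blocks.append(bit)
        bit >>= 1
    return blocks


for dim, floor in [(960, 64), (5376, 64), (21504, 16), (16384, 64)]:
    blocks = block_decompose_min_sketch(dim, floor)
    print(dim, floor, blocks, "padded to", sum(blocks))
```

Because the padded size is a multiple of `min_block`, no bit below `min_block` is set, so every block meets the floor and the sum stays a multiple of 64 (cols) or 16 (rows). A dim that is already a power of two yields a single block, which is the legacy full-Hadamard case.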

Fixes folded in

  • vLLM FULL-cudagraph serving: buffer sizing at the block-diag (checkpoint) shape so the loader copies in place; block RHT tensors precomputed in process_weights_after_loading so inductor lifts them as constants (no CPU-tensor-at-capture abort).
  • HF-eager / large-batch OOB: HF-eager fell back to a multi-op path that ran the single-block RHT on a non-pow2 m_pad, causing a shared-mem OOB at batch ≥ ~640 (found with compute-sanitizer). It now routes through the raw pybind fused op when torch.ops.glq isn't registered.
  • Back-compat: _forward_e8p re-derives n_pad/m_pad/blocks from the loaded Qidxs_e8p shape, so the branch loads both legacy pow2 e8p checkpoints (what stock pip glq produces) and new block-diag ones, on every load path incl. accelerate dispatch.

Validation (SmolLM2-360M, SmolLM3-3B; RTX PRO 6000 Blackwell, vLLM 0.23)

  • Serving: SmolLM2-360M e8p 4bpw block-diag serves on vLLM FULL cudagraph (51/51 decode graphs), coherent.

  • Throughput (SmolLM3-3B 2bpw, _tps_vllm.py, decode 256): e8p block-diag reaches 138 tok/s at B=1 and 3394 tok/s aggregate (106 per sequence) at B=32, which is 1.5–1.9× faster than the shell codebook at the same quality.

  • Quality (WikiText-2, full test, non-overlap 2048-chunk, fp16 — same harness for all):

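    | bpw  | rotation    | PPL   | vs bf16 |
    |------|-------------|-------|---------|
    | bf16 |             | 9.11  |         |
    | 3    | pow2 (full) | 9.70  | +6.5%   |
    | 3    | block-diag  | 9.95  | +9.2%   |
    | 2    | pow2 (full) | 12.13 | +33%    |
    | 2    | block-diag  | 13.19 | +45%    |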

    Block-diag's weaker per-block rotation costs ~1 PPL at 2bpw but only ~0.26 PPL at 3bpw (the RVQ residual masks it) — so block-diag is near-free at ≥3bpw, and the pow2-vs-block-diag choice is really only a 2bpw concern. e8p ≈ shell quality at matched bpw. A GLQ_E8P_POW2=1 quantize toggle keeps the legacy full-Hadamard path available.
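
The relative figures can be re-derived from the PPL numbers reported above. The snippet below is just that arithmetic; the helper names are illustrative. Note that the 3bpw gap computed from the rounded table values is 0.25, slightly below the ~0.26 quoted in the text, which presumably comes from unrounded measurements.

```python
BF16_PPL = 9.11

# (bpw, rotation) -> WikiText-2 PPL, copied from the table in this PR.
PPL = {
    (3, "pow2"): 9.70,
    (3, "block-diag"): 9.95,
    (2, "pow2"): 12.13,
    (2, "block-diag"): 13.19,
}


def pct_over_baseline(value: float, baseline: float = BF16_PPL) -> float:
    """Relative PPL increase over the bf16 baseline, in percent."""
    return (value / baseline - 1.0) * 100.0


def blockdiag_gap(bpw: int) -> float:
    """PPL cost of block-diag relative to the full pow2 rotation at one bpw."""
    return PPL[(bpw, "block-diag")] - PPL[(bpw, "pow2")]


for (bpw, rotation), value in PPL.items():
    print(f"{bpw}bpw {rotation:<10} {value:6.2f}  +{pct_over_baseline(value):.1f}%")
for bpw in (3, 2):
    print(f"{bpw}bpw block-diag gap: {blockdiag_gap(bpw):.2f} PPL")
```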

Notes

  • e8p is not published yet, so block-diag becomes the e8p default; _e8p_pad == _glq_pad for already-pow2 dims, so legacy pow2 checkpoints take the unchanged single-block path.
  • Follow-up (not in this PR): DRY-refactor the shell glq_fused_linear_block_diag_cuda to call the new shared *_blockdiag_cuda host helpers; re-quantize the existing private 12B/31B e8p checkpoints to block-diag to reap the footprint win.

cnygaard added 5 commits June 23, 2026 19:40
e8p forced a full pow2 Hadamard, padding every linear dim to the next power of
two. On dims that sit just above a pow2 (gemma-4-31B H=5376, I=21504) this
inflates the qidxs ~2.1x — "3bpw" gemma-4-31B is 27 GB (bigger than shell-5bpw).
The pow2 requirement is a myth here: the e8p TC kernels only need K%64==0 / N%16==0,
and the block structure lives only in the RHT. A sum of pow2 FHT blocks each >=64
(cols) / >=16 (rows) is automatically a multiple of 64/16.

This commit lands the quantize-side foundation:
- hadamard.py: `_block_decompose_min(n, min_block)` — block-diag decomposition with a
  pow2 floor so sum stays a multiple of min_block (validated: 31B qidxs 22.6->10.8 GB).
- rht.py: `RHT(e8p=True)` uses the floored decomposition (n>=64, m>=16). Default off.
- quantize_model.py: e8p now block-diags instead of forcing pow2.
- quantized_linear.py: E8RHTLinear sizes e8p buffers at the block-diag dims.

WIP — the inference/serving wiring is NOT landed yet (load-time block re-derivation,
multiblock-RHT dispatch in `_e8p_linear_apply` and `glq_fused_linear_e8p_cuda`, the
op-schema + vLLM plumbing). e8p checkpoints quantized on this branch will not load/serve
until that lands. Work branch; e8p is not public yet.

Lands the inference side so block-diag e8p checkpoints load and serve. The e8p
fused op stays a thin orchestrator: the multiblock-vs-full RHT dispatch is lifted
out of glq_fused_linear_block_diag_cuda into two shared host helpers
(glq_input_rht_blockdiag_cuda / glq_output_rht_blockdiag_cuda; single block = the
existing full pow2 RHT), and glq_fused_linear_e8p_cuda just swaps its two RHT calls
to them and gains blocks_n/blocks_m/*_meta params. The TC decode kernels are unchanged.
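
The single-block-versus-multiblock dispatch can be illustrated on the CPU. The following is a minimal pure-Python sketch under stated assumptions, not the CUDA helpers: `fwht` and `rht_blockdiag_sketch` are hypothetical names, the transform is a normalized fast Walsh-Hadamard transform, and the random sign flips that make the RHT "randomized" are omitted.

```python
import math


def fwht(vec: list[float]) -> list[float]:
    """Normalized fast Walsh-Hadamard transform of a power-of-two-length list."""
    n = len(vec)
    assert n > 0 and n & (n - 1) == 0, "length must be a power of two"
    out = list(vec)
    h = 1
    while h < n:
        for start in range(0, n, 2 * h):
            for i in range(start, start + h):
                a, b = out[i], out[i + h]
                out[i], out[i + h] = a + b, a - b  # butterfly step
        h *= 2
    scale = 1.0 / math.sqrt(n)
    return [x * scale for x in out]


def rht_blockdiag_sketch(vec: list[float], blocks: list[int]) -> list[float]:
    """Apply an independent Hadamard transform to each contiguous block."""
    assert sum(blocks) == len(vec), "blocks must tile the vector exactly"
    if len(blocks) == 1:
        return fwht(vec)  # single block: identical to the full pow2 transform
    out: list[float] = []
    offset = 0
    for size in blocks:
        out.extend(fwht(vec[offset:offset + size]))
        offset += size
    return out


x = [float(i % 7 - 3) for i in range(960)]
blocks = [512, 256, 128, 64]
y = rht_blockdiag_sketch(x, blocks)
roundtrip = rht_blockdiag_sketch(y, blocks)
print(max(abs(a - b) for a, b in zip(x, roundtrip)))
```

Each normalized block is orthogonal and its own inverse, so the block-diagonal transform stays orthogonal as a whole; it just mixes values only within a block rather than across the full dim.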

- glq_cuda.cu: extract the two block-diag RHT host helpers; thread blocks through
  the thin e8p fused op.
- glq_bindings.cpp + glq_vllm/custom_ops.py: extend the fused_linear_e8p op
  schema/fake with the 4 block tensors.
- quantized_linear.py: register e8p block-meta; pass blocks from _forward_e8p +
  _e8p_linear_apply (multi-op fallback rejects block-diag — fused op required).
- glq_vllm/linear_method.py: derive + pass per-shard block tensors in
  _glq_apply_e8p (n_pad/m_pad already come from the loaded Qidxs_e8p shape).

Validated: SmolLM2-360M e8p 4bpw re-quantized block-diag (non-pow2 n_pad 960/2560,
blocks [512,256,128,64]) loads in HF, fused block-diag vs dense rel 3.9e-3, identical
greedy, coherent ("Jupiter"). RTX PRO 6000, glq ext rebuilds + fused_linear_e8p
registers with the new signature. (Shell block-diag op left inlined; DRY refactor to
the shared helpers is a follow-up. 31B footprint + vLLM serving validation next.)

Complete the block-diag e8p inference path for vLLM. Two fixes were needed
to serve a block-diag (non-pow2-padded) e8p checkpoint under FULL cudagraph:

1. create_weights buffer sizing. The e8p buffers (Qidxs_e8p/SU/SV, the fused
   n_pad) were sized at the next power of two (_glq_pad). A block-diag
   checkpoint stores them at the block-diag padded dim (sum of pow2 blocks each
   >= 64/16), so the shapes mismatched and the loader could not copy in place.
   For a managed nn.Parameter that forces param.data = empty_like(loaded) onto
   CPU, and vLLM reverts the post-load .data->GPU move, leaving a CPU tensor at
   cudagraph capture ("qidxs_e8p must be a CUDA tensor"). New _e8p_pad(d,
   min_block) sizes them at the checkpoint shape so the in-place copy_ branch is
   taken (equals _glq_pad for already-pow2 dims, so legacy checkpoints are
   unchanged).

2. Block-tensor construction site. The block-diag RHT tensors (blocks_n/_m +
   packed meta) were built lazily inside apply() via torch.tensor(...). dynamo
   traced that CPU allocation into the captured graph, and the resulting CPU
   tensor tripped "Cannot copy between CPU and CUDA tensors during CUDA graph
   capture". Build them once in process_weights_after_loading (before compile),
   held on each meta dict, mirroring the shell block-diag path (_glq_bd_meta) so
   inductor lifts them as frozen constants. _bn/_bm stay on CPU (the C++ op
   reads block sizes host-side); the packed _bnm/_bmm are GPU int32.

Validated: SmolLM2-360M e8p 4bpw block-diag serves on vLLM 0.23 with FULL
cudagraph (51/51 decode graphs captured, no CPU-tensor abort), coherent output
(2+2=4, primes 2/3/5/7/11). Legacy pow2 e8p checkpoints take the single-block
== full-RHT path unchanged.

The block-diagonal e8p HF-eager path (and any non-vLLM context) crashed with a
CUDA illegal memory access on prefills/PPL at batch ≥ ~640.

Root cause: _e8p_linear_apply only took the fused-op fast path when the
torch.ops.glq custom ops were registered — which happens via
glq_vllm/custom_ops.py, imported only on the vLLM path. A plain
`import glq.hf_integration` run left _ops = None, so the code fell to the
multi-op fallback, which calls the single-block glq_output_rht_cuda with the
full (non-power-of-2) m_pad. That kernel uses __builtin_ctz(m_pad) as log_m and
runs a power-of-2 butterfly, so for a block-diag dim (e.g. SmolLM3 gate/up
m_pad=11008) it reads past the shared-memory buffer. It only *faulted* at high
occupancy (B ≥ ~640 packs enough CTAs that the OOB smem read crosses an unmapped
partition), so small-B decode/serving on vLLM never tripped it. The fallback
guard also only checked block-diagonal blocks_n, missing block-diagonal blocks_m.
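
The mismatch is easy to reproduce with integer arithmetic. The sketch below emulates a count-trailing-zeros in Python (the `ctz` and `is_pow2` helpers are illustrative, not repo code) and shows that it equals log2 only for powers of two. It does not model the kernel's shared-memory access pattern, only the wrong `log_m`.

```python
def ctz(x: int) -> int:
    """Count trailing zero bits of a positive integer."""
    assert x > 0
    return (x & -x).bit_length() - 1


def is_pow2(x: int) -> bool:
    return x > 0 and x & (x - 1) == 0


for m_pad in (16384, 11008):
    log_m = ctz(m_pad)
    implied = 1 << log_m  # the transform size a pow2 butterfly would assume
    print(f"m_pad={m_pad} is_pow2={is_pow2(m_pad)} log_m={log_m} implied size={implied}")
```

For m_pad=11008 (43 × 256) this gives log_m=8, so a power-of-two butterfly would be set up for a 256-element transform while the actual dim is 11008. A guard of the `is_pow2` kind on both blocks_n and blocks_m is what the hardened fallback check amounts to.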

Fix: resolve the fused op from the raw pybind binding
(cuda.glq_fused_linear_e8p_cuda) when the torch.op isn't registered. Both call
the same block-diagonal-correct kernel (glq_output_rht_blockdiag_cuda), so
block-diag e8p now works in HF eager and at large batch. Harden the fallback
guard to also reject multi-block blocks_m.

Verified: SmolLM2-360M/SmolLM3-3B e8p block-diag forward (1, 2048) returns clean
logits (was illegal-memory-access); compute-sanitizer-clean.

…-diag

E8RHTLinear.__init__ sized n_pad/m_pad/blocks from in/out_features assuming the
new block-diagonal layout, and nothing re-derived them from the actual checkpoint
on load. A block-diag checkpoint happened to match, but a legacy *pow2* e8p
checkpoint (n_pad/m_pad padded to the next power of two — the format stock
pip glq 0.6.5 produces, and what the already-uploaded 12B/31B private e8p
checkpoints use) mismatched: e.g. SmolLM3 down_proj loaded with n_pad=16384 while
__init__ computed 11008, so the TC kernel's K==Qidxs.cols*64 check failed.

Fix: in _forward_e8p, re-derive n_pad/m_pad + block decomposition from the loaded
Qidxs_e8p shape once (actual_n=cols*64, actual_m=rows*16). _block_decompose(pow2)
== [pow2] → single full-Hadamard block, so both legacy pow2 and new block-diag
checkpoints resolve correctly, on every load path incl. accelerate device_map
dispatch (which bypasses _load_from_state_dict). Restores back-compat for pow2
e8p checkpoints on the block-diag branch.
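
The re-derivation can be sketched without any tensor library. This is an illustrative sketch only: `rederive_e8p_layout` and `pow2_blocks` are hypothetical names, the shape is passed as plain integers, and the row count of 128 (m_pad=2048) is an assumed value for illustration. The only facts taken from this PR are the 64-column / 16-row tile factors and the two n_pad values for SmolLM3 down_proj.

```python
def pow2_blocks(d: int) -> list[int]:
    """Set bits of d as power-of-two blocks, largest first."""
    return [1 << i for i in range(d.bit_length() - 1, -1, -1) if d & (1 << i)]


def rederive_e8p_layout(qidxs_rows: int, qidxs_cols: int):
    """Recover padded dims and block lists from a loaded qidxs shape.

    Assumes each qidxs column covers 64 input elements and each row
    covers 16 output elements, as stated in the commit message.
    """
    n_pad = qidxs_cols * 64
    m_pad = qidxs_rows * 16
    return n_pad, m_pad, pow2_blocks(n_pad), pow2_blocks(m_pad)


# Legacy pow2 checkpoint: n_pad=16384 -> one full-Hadamard block.
print(rederive_e8p_layout(qidxs_rows=128, qidxs_cols=256))
# Block-diag checkpoint: n_pad=11008 -> several pow2 blocks.
print(rederive_e8p_layout(qidxs_rows=128, qidxs_cols=172))
```

Because a power-of-two dim decomposes to a single block, the same code path handles both checkpoint formats without a format flag.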

Also adds a GLQ_E8P_POW2 quantize toggle (glq/rht.py) to produce the legacy full
pow2 Hadamard e8p for A/B. Measured on SmolLM3-3B 2bpw (wikitext2, same harness):
pow2 e8p 12.13 PPL / 9.95 dB SQNR vs block-diag e8p 13.19 / 9.11 dB — the full
Hadamard is a stronger rotation and quantizes ~1 PPL better, so block-diag trades
quality for footprint (use pow2 for near-pow2 dims, block-diag for far-from-pow2).

Verified: pow2 SmolLM3-3B e8p forward (1, 2048) clean; PPL runs.
cnygaard changed the title from "feat(e8p): block-diagonal RHT — kill pow2-padding bloat, keep TC decode" to "feat(e8p): block-diagonal RHT — remove pow2-padding bloat, keep TC decode" on Jun 23, 2026
Block-diagonal E8P: the E8P codebook path no longer forces a full power-of-2
Hadamard. It pads each linear dim to a sum of pow2 blocks (each >= 64 cols /
16 rows) instead of the next power of two, so non-pow2 models stop bloating —
gemma-4-31B E8P 3bpw drops from ~27 GB to ~11 GB (qidxs 22.5 -> 10.7 GiB) while
the tensor-core decode + RVQ are untouched, still one fused op so FULL cudagraph
still pays off. Serves on vLLM 0.23 FULL cudagraph; 1.5-1.9x faster decode than
the shell codebook at matched quality. Folds in three fixes surfaced by
validation: vLLM cudagraph buffer sizing, an HF-eager / large-batch shared-mem
OOB (single-block RHT on a non-pow2 m_pad), and pow2<->block-diag back-compat
loading (re-derives dims from the loaded Qidxs_e8p shape, so both legacy pow2
and new block-diag checkpoints load on every path). A GLQ_E8P_POW2 quantize
toggle keeps the legacy full-Hadamard E8P available — measured ~1 PPL better at
2bpw, ~0.26 at 3bpw, so block-diag is near-free at >=3bpw.

Also includes first-class E8P support in the batch_quantization tool (#17).
@cnygaard

Owner Author

Folded the 0.6.6 version bump into this PR (commit b0e6e36): glq/__init__.py + pyproject.toml → 0.6.6. Merging ships 0.6.6 — block-diagonal E8P plus the first-class E8P batch_quantization from #17. After merge: tag v0.6.6 on main and build/publish.
