[WIP] Refactor pytorch quant modules, Support MoE#2466
Conversation
…ent) This commit checkpoints the in-progress MoE quantization work before a larger refactor that deletes QuantLinear/QuantEmbedding in favour of storing every quantized weight (2D linear, 2D embedding, 3D MoE experts) as a QuantTensor nn.Parameter on the original host module. Included so far: - New olive/common/quant/patterns.py for re: prefix matching in modules_to_not_convert / overrides. - New olive/common/quant/tensor.py with QuantTensor wrapper subclass (_make_wrapper_subclass + __torch_function__ + __torch_dispatch__), supporting 2D and 3D layouts. - LayerWrapper.get_experts() / get_router() accessors. - 3D quantize helpers in olive/common/quant/utils.py. - moe field on OliveHfQuantizationConfig. - _process_model_before_weight_loading skips ModuleList(Expert) subtrees when moe=False, fixing a latent silent-quantization bug for Mixtral / PhiMoE / Qwen2/3-MoE. - Fused-3D MoE support in prepare_model / finalize via QuantTensor parameters; current save layout uses _qweight buffer suffixes — to be replaced in the upcoming refactor with the canonical <param>.qweight/.scales/.qzeros layout. - ModelBuilder raises NotImplementedError for Olive-quantized MoE checkpoints (Mobius is the intended consumer). - Test additions: test/common/quant/test_patterns.py, test/common/quant/test_tensor.py, TestOliveHfQuantizerMoE / TestRegexOverrides in test_hf_utils.py, test/passes/pytorch/test_quant_utils.py for flatten helper, test_olive_quantized_model_raises_for_moe in test_model_builder.py. - 294 tests pass; lintrunner clean (--skip PYLINT). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Switch Olive's native quantization representation to a single design: every quantized weight is an nn.Parameter(QuantTensor) on the original host nn.Linear / nn.Embedding / fused-3D experts module, with sibling <pname>_qweight / _scales / _qzeros buffers aliasing the QuantTensor's inner tensors. Save: a state-dict hook drops the QuantTensor parameter entry; the buffers already carry the data (plain Tensors, safetensors-friendly). Load: HF's loader fills the buffers natively via dotted paths; a post-load helper re-binds the QuantTensor inner refs to the freshly loaded buffer storage. QuantLinear / QuantEmbedding (olive/common/quant/nn.py) are kept only as ONNX-exportable wrappers used by make_export_compatible_quant; they are no longer the runtime representation. * New olive/common/quant/state_dict.py with install_quant_tensor_param and refresh_quant_tensor_refs helpers. * OliveHfQuantizer rewritten for the new layout (placeholder install before weight load + ref refresh after). * finalize() in passes/pytorch/quant_utils.py installs QuantTensor params via install_quant_tensor_param (replaces the old flatten_quant_tensor_params helper). * prepare_model skips modules whose weight is already a QuantTensor, so composing multiple Rtn passes on top of a partially quantized model works. * make_export_compatible_quant detects nn.Linear / nn.Embedding whose weight is a QuantTensor and swaps them with QuantLinear / QuantEmbedding wrappers before any model dtype casting, preserving the existing com.microsoft::MatMulNBits / com.microsoft::GatherBlockQuantized symbolic export path. * OliveQuantizedModel (model_builder.py) normalizes the new <dotted>.weight_qweight key layout back to the legacy <dotted>.qweight layout for the existing genai loader, and raises NotImplementedError for moe=True checkpoints. * Tests updated to assert against QuantTensor weight instead of isinstance(module, QuantLinear); legacy tie_quant_modules tests removed; new install_quant_tensor_param test suite added. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
…er to N-D * Remove olive/common/quant/nn.py (QuantModule, QuantLinear, QuantEmbedding) entirely. The only purpose of those modules was ONNX export, which is now handled by reusing the existing QuantLinearNbit from olive/common/hf/quant.py and a new parallel QuantEmbeddingNbit (com.microsoft::GatherBlockQuantized symbolic) in the same file. * Add QuantLinearNbit.from_quant_tensor / QuantEmbeddingNbit.from_quant_tensor factories so make_export_compatible_quant can swap any nn.Linear / nn.Embedding whose weight is a QuantTensor into the export wrappers. * Generalize WeightQuantizer (get_num_groups, get_qparam_shape, find_qparams, quantize, dequantize, _reshape_tensor) and pack_to_uint8 / unpack_from_uint8 to operate on any N-D tensor; quantization is always along the last dim, leading dims are preserved. * Drop quantize_along_leading_dim / pack_to_uint8_along_last / unpack_from_uint8_along_last and the explicit 3D leading-dim loops in QuantTensor.from_float and _dequantize. * Delete test/common/quant/test_nn.py; add N-D tests for the generalized quantizer + pack helpers. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Instead of pre-walking the safetensors dict to rewrite ``<dotted>.weight_qweight`` -> ``<dotted>.qweight``, derive the destination attribute name inside ``set_tensor`` once we already know ``submodule`` is a ``QuantizedTensorModule``. Strip any of the known Olive buffer suffixes (``QWEIGHT_SUFFIX``, ``SCALES_SUFFIX``, ``QZEROS_SUFFIX`` from ``olive.common.quant.state_dict``) from the last path component to produce the bare ``qweight`` / ``scales`` / ``qzeros`` attribute that the genai ``QuantizedTensorModule`` expects. Also drops internal dev-iteration version labels from comments and docstrings in olive/common/quant and olive/passes/onnx. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Both QuantTensor's 2D layout and QuantLinearNbit's MatMulNBits
buffer layout pack the quantization axis as uint8 with the same
in-byte order (low nibble = elem[2j], high nibble = elem[2j+1] for
4-bit, etc.). They differ only in qweight rank: QuantTensor uses
(out, in / pack_factor), QuantLinearNbit uses
(out, n_blocks, blob_size) where n_blocks * blob_size ==
in / pack_factor. So the conversion is a pure reshape; the previous
unpack -> .t() -> from_tensors round-trip is unnecessary.
scales and qzeros buffer shapes also match exactly between the two
layouts, so they are copied as-is. For symmetric weights
(QuantTensor.qzeros is None) we fill the QuantLinearNbit.qzeros
buffer with the packed midq pattern that the contrib op expects.
Verified numerically: F.linear via QuantTensor and the
dequantize-from-buffers path through QuantLinearNbit produce
bit-identical outputs across {4,8} bits, {symmetric, asymmetric},
{groupwise, per-channel}.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.
There was a problem hiding this comment.
Pull request overview
This PR refactors Olive’s PyTorch weight quantization implementation from dedicated quantized nn.Module wrappers (QuantLinear/QuantEmbedding) to a QuantTensor parameter representation, and extends the quantization pipeline to better handle MoE models (including fused 3D expert weights and configurable expert skipping).
Changes:
- Introduces
QuantTensor(atorch.Tensorwrapper subclass) + state-dict helpers to store quantized weights as aliased buffers while keeping eager execution working viaF.linear/F.embeddingdispatch. - Updates quantization utilities to (a) support N-D quantization along the last dimension and (b) add MoE-aware selection/skip logic plus fused-3D expert parameter quantization.
- Adds regex-aware pattern matching helpers for
overridesandmodules_to_not_convert, updates ONNX export compatibility and ModelBuilder loading behavior, and refreshes/expands unit tests accordingly.
Reviewed changes
Copilot reviewed 20 out of 20 changed files in this pull request and generated 1 comment.
Show a summary per file
| File | Description |
|---|---|
| test/passes/pytorch/test_rtn.py | Updates RTN tests to validate quantization via QuantTensor weights rather than quantized module classes. |
| test/passes/pytorch/test_gptq.py | Updates GPTQ tests to validate QuantTensor-backed weights. |
| test/passes/pytorch/test_quant_utils.py | Adds offline unit tests for install_quant_tensor_param (incl. fused 3D expert parameters). |
| test/passes/onnx/test_model_builder.py | Adds coverage ensuring Olive-quantized MoE checkpoints error out in ModelBuilder. |
| test/common/quant/test_utils.py | Updates assertions and adds tests for N-D quantization/packing helpers. |
| test/common/quant/test_tensor.py | Adds unit tests for QuantTensor semantics, dispatch behavior, dtype/device propagation, slicing, and ONNX export guards. |
| test/common/quant/test_patterns.py | Adds tests for regex/literal override + skip matching helpers. |
| test/common/quant/test_hf_utils.py | Updates HF quantizer tests for QuantTensor and adds MoE-specific behavior tests. |
| test/common/quant/test_nn.py | Removes tests for deleted QuantLinear/QuantEmbedding module wrappers. |
| olive/passes/pytorch/rtn.py | Enables MoE option exposure in RTN pass default config. |
| olive/passes/pytorch/quant_utils.py | Refactors prepare/finalize to install QuantTensor params, adds MoE expert handling and skip patterns. |
| olive/passes/onnx/model_builder.py | Maps Olive buffer suffixes to expected QuantizedTensorModule attrs and rejects MoE Olive checkpoints explicitly. |
| olive/common/quant/utils.py | Extends WeightQuantizer + pack/unpack to work on N-D tensors along the last dimension. |
| olive/common/quant/tensor.py | Adds QuantTensor implementation with eager dispatch + export guards + tensor transform propagation. |
| olive/common/quant/state_dict.py | Adds buffer naming + state-dict hook + install/refresh helpers for QuantTensor parameters. |
| olive/common/quant/patterns.py | Adds centralized literal/regex matching helpers for overrides and skip patterns. |
| olive/common/quant/hf_utils.py | Refactors HF quantizer to use QuantTensor placeholders and MoE-aware skipping; updates tie logic. |
| olive/common/hf/wrapper.py | Adds LayerWrapper.get_experts / get_router helpers for MoE structure discovery. |
| olive/common/hf/quant.py | Adds conversion from Olive QuantTensor weights to ONNX-exportable QuantLinearNbit/QuantEmbeddingNbit wrappers. |
| olive/common/quant/nn.py | Removes legacy quantized module wrapper implementations. |
Comments suppressed due to low confidence (1)
olive/common/quant/hf_utils.py:429
tie_quant_word_embeddingswill currently tie the output embedding to the input embedding as long as the input has aQuantTensorweight. This can unintentionally quantize the output embedding even whenlm_head=False(or leavedstwithout Olive’s state_dict hook), and it no longer validates shape/dtype compatibility before aliasing buffers. Please guard so tying only happens when both src and dst already have compatibleQuantTensorweights/buffers (and keep the previous shape/dtype assertions) to avoid silently changing quantization behavior or breaking save/load.
def tie_quant_word_embeddings(model: PreTrainedModel) -> None:
"""Tie the input and output embeddings when they share a quantized weight.
Both modules' ``weight`` ``nn.Parameter`` is set to the **same**
``nn.Parameter(QuantTensor)`` object, and the underlying
``weight_qweight`` / ``weight_scales`` / ``weight_qzeros`` buffers
are tied (aliased to the input embedding's buffers). This preserves
the standard HF tied-weights semantics for the quantized layout.
"""
src = model.get_input_embeddings()
dst = model.get_output_embeddings()
if src is None or dst is None:
return
# The input embedding owns the canonical QuantTensor + buffers.
src_param = src._parameters.get("weight")
if src_param is None or not isinstance(src_param.data, QuantTensor):
return
qname, sname, zname = buffer_names("weight")
# tie buffers
for n in (qname, sname, zname):
src_buf = src._buffers.get(n)
if src_buf is None:
continue
dst._buffers[n] = src_buf
# tie the QuantTensor parameter itself (same Python Parameter object,
# so both modules see the same .data and the same inner tensors).
dst._parameters["weight"] = src_param
The ORT contrib MatMulNBits / GatherBlockQuantized ops treat a missing zero-points input as midq for unsigned quantization, matching Olive's symmetric-quantization convention. Drop the synthetic packed-midq buffer that was previously emitted for symmetric weights and instead omit the input entirely: * QuantLinearNbit gains a has_qzeros flag (default True for back-compat); pack/from_tensors/from_quant_tensor pass through None as needed. * QuantLinearTorchFunction (TorchScript + dynamo) skips the qzeros input when None, inserting an empty placeholder only when g_idx must be positionally aligned. * QuantEmbeddingTorchFunction.symbolic gains the missing dynamo arg exposed by the new symmetric-embedding export path. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
g_idx alongside a missing qzeros is not a real combination in Olive (GPTQ always produces qzeros), so skip the empty-tensor placeholder and just omit qzeros from the input list entirely. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
There was a problem hiding this comment.
Pull request overview
Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.
Comments suppressed due to low confidence (1)
olive/common/hf/quant.py:175
- In the dynamo ONNX-export path, the placeholder inserted when
qzeros is None and g_idx is not Noneis created astorch.zeros(0, dtype=torch.uint8)with no device specified. Ifx/qweight/scalesare on CUDA, this can trigger device-mismatch errors during export graph construction. Create the placeholder onx.device(and ideally matchqweight's device) to keep tensor args consistent.
tensor_args.append(g_idx)
attrs = {
"K": in_features,
"N": out_features,
"bits": bits,
"block_size": group_size,
"accuracy_level": 4,
- Replace duplicated model walks in hf_utils._process_model_before_weight_loading and quant_utils.prepare_model with a shared iter_quant_targets helper that returns a list of (module, dotted_name, param_name, shape, dtype, device, kind) entries. Selection rules (lm_head/embeds/moe category flags, skip patterns, extra_skip_modules, already-quantized) live in one place. - QuantLinearNbit/QuantEmbeddingNbit: raise instead of synthesising a placeholder when g_idx is supplied alongside symmetric quantization. - tie_quant_word_embeddings: require both input and output embeddings to already be QuantTensor-backed with matching shape/dtype before tying. - Fix CodeQL mismatched-assignment false positives in QuantTensor dispatch (index args directly), fix ruff D205/D401/PLW0108/A002 warnings. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Neither attribute is read anywhere — the post-walk loop that produced the literal skip-name list was a leftover from before the refactor. The configured patterns already live on quantization_config. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Stop tagging modules with quant_info / quant_info_3d and stop branching on 2D vs 3D in the quantization passes. The quantizer already operates along the last dim regardless of rank, so a single iteration over parameters that carry a quant_info attribute is enough. - QuantTarget slims to (module, module_name, pname, full_name) with a .param property; the caller reads shape/dtype/device from the parameter directly. No more 'kind' field. - prepare_model writes target.param.quant_info in one pass — both 2D linear/embedding weights and fused experts parameters use the same code path. The quant_info_3d dict-stash on experts modules is gone. - finalize iterates every parameter that has quant_info, calls QuantTensor.from_float (already rank-generic), and installs in place. - GPTQ and AutoClip read module.weight.quant_info; module discovery uses hasattr(module.weight, 'quant_info') instead of a module-level attribute. - HF placeholder install pulls shape/dtype/device off target.param and the placeholder builder is now rank-generic. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
The dataclass had four fields and a one-line .param property, used by two callers. A plain tuple is shorter, matches how the layerwise quantization loop already iterates over (module, pname, param, info) tuples, and removes the unused module_name field and dead for_each_target helper. QuantTarget remains as a type alias for the public signature. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
* Filter fused-MoE params to 2D/3D ranks in iter_quant_targets so a 1D bias-like parameter fails at selection time instead of much later in finalize. * refresh_quant_tensor_refs: also check isinstance(param.data, QuantTensor) for forward-compat with future torch versions that may not return the underlying subclass from nn.Parameter(). * OliveHfQuantizationConfig: replace bare '# pylint: disable' with the specific super-init-not-called rule; use output.get(k) in to_dict. * finalize: log a warning when moe=True that the resulting checkpoint isn't directly ONNX-exportable via the Olive conversion pass — it must be consumed by an MoE-aware model builder. * Add regression test that _module_weight_has_quant_info ignores nn.LayerNorm / nn.Conv2d / unmarked nn.Linear (defends GPTQ/AutoClip discovery against future drift). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
- __torch_dispatch__ clone/contiguous now forwards extra args/kwargs. - iter_quant_targets skips ALL nn.Embedding when embeds=False (positional / token-type embeddings like GPT-2 wpe are no longer silently quantized). - WeightQuantizer assertion message: 2/4/8-bit (was 4/8-bit). - tie_quant_word_embeddings: mark dst aliased buffers non-persistent so safetensors save emits one copy of qweight/scales/qzeros. - finalize: group selected params by host module so each module's to(device)/to(cpu) cycle runs once for MoE experts modules carrying multiple 3D weight params. - state_dict: add ensure_state_dict_hooks(model) defensive walk that installs the save hook on every host module that owns a QuantTensor parameter (idempotent). - Add test_forward_parity.py: bit-exact eager parity for full models (embedding + linears) and fused 3D MoE forwards, plus end-to-end ONNX export -> onnxruntime numerical parity for Olive-quantized nn.Linear via make_export_compatible_quant. - Enable pylint by adding file-level protected-access disables on the files that intentionally touch nn.Module._parameters. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Describe your changes
Checklist before requesting a review
lintrunner -a(Optional) Issue link