[WIP] Refactor pytorch quant modules, Support MoE#2466

Closed
jambayk wants to merge 13 commits into main from jambayk/moe-quant

Conversation

@jambayk jambayk commented May 16, 2026

Describe your changes

Checklist before requesting a review

  • Add unit tests for this change.
  • Make sure all tests can pass.
  • Update documents if necessary.
  • Lint and apply fixes to your code by running lintrunner -a
  • Is this a user-facing change? If yes, give a description of this change to be included in the release notes.

(Optional) Issue link

Copilot AI and others added 5 commits May 15, 2026 22:01
…ent)

This commit checkpoints the in-progress MoE quantization work before a
larger refactor that deletes QuantLinear/QuantEmbedding in favour of
storing every quantized weight (2D linear, 2D embedding, 3D MoE experts)
as a QuantTensor nn.Parameter on the original host module.

Included so far:
- New olive/common/quant/patterns.py for re: prefix matching in
  modules_to_not_convert / overrides.
- New olive/common/quant/tensor.py with QuantTensor wrapper subclass
  (_make_wrapper_subclass + __torch_function__ + __torch_dispatch__),
  supporting 2D and 3D layouts.
- LayerWrapper.get_experts() / get_router() accessors.
- 3D quantize helpers in olive/common/quant/utils.py.
- moe field on OliveHfQuantizationConfig.
- _process_model_before_weight_loading skips ModuleList(Expert) subtrees
  when moe=False, fixing a latent silent-quantization bug for
  Mixtral / PhiMoE / Qwen2/3-MoE.
- Fused-3D MoE support in prepare_model / finalize via QuantTensor
  parameters; current save layout uses _qweight buffer suffixes — to be
  replaced in the upcoming refactor with the canonical
  <param>.qweight/.scales/.qzeros layout.
- ModelBuilder raises NotImplementedError for Olive-quantized MoE
  checkpoints (Mobius is the intended consumer).
- Test additions:
  test/common/quant/test_patterns.py, test/common/quant/test_tensor.py,
  TestOliveHfQuantizerMoE / TestRegexOverrides in test_hf_utils.py,
  test/passes/pytorch/test_quant_utils.py for flatten helper,
  test_olive_quantized_model_raises_for_moe in test_model_builder.py.
- 294 tests pass; lintrunner clean (--skip PYLINT).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
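
The `re:` prefix matching mentioned above for modules_to_not_convert / overrides could work roughly as in the sketch below. The helper name `matches_any`, the use of `re.fullmatch`, and the rule that a literal matches either the full dotted name or a dotted suffix are all assumptions made for illustration; they are not taken from olive/common/quant/patterns.py.

```python
import re

REGEX_PREFIX = "re:"  # prefix named in the commit message


def matches_any(module_name: str, patterns: list[str]) -> bool:
    """Return True if module_name matches any literal or re:-prefixed pattern."""
    for pattern in patterns:
        if pattern.startswith(REGEX_PREFIX):
            # Assumption: a regex pattern must match the whole dotted name.
            if re.fullmatch(pattern[len(REGEX_PREFIX):], module_name):
                return True
        elif module_name == pattern or module_name.endswith("." + pattern):
            # Assumption: a literal matches the full name or a dotted suffix.
            return True
    return False


# Hypothetical skip list mixing a literal and a regex entry.
skip = ["lm_head", r"re:model\.layers\.\d+\.mlp\.gate"]
router_skipped = matches_any("model.layers.3.mlp.gate", skip)
```

Keeping literals and regexes in one list means existing configs with plain module names keep working, and only entries that opt in with the prefix are compiled as patterns.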
Switch Olive's native quantization representation to a single design:
every quantized weight is an nn.Parameter(QuantTensor) on the original
host nn.Linear / nn.Embedding / fused-3D experts module, with sibling
<pname>_qweight / _scales / _qzeros buffers aliasing the QuantTensor's
inner tensors.

Save: a state-dict hook drops the QuantTensor parameter entry; the
buffers already carry the data (plain Tensors, safetensors-friendly).
Load: HF's loader fills the buffers natively via dotted paths; a
post-load helper re-binds the QuantTensor inner refs to the freshly
loaded buffer storage.

QuantLinear / QuantEmbedding (olive/common/quant/nn.py) are kept only
as ONNX-exportable wrappers used by make_export_compatible_quant; they
are no longer the runtime representation.

* New olive/common/quant/state_dict.py with install_quant_tensor_param
  and refresh_quant_tensor_refs helpers.
* OliveHfQuantizer rewritten for the new layout (placeholder install
  before weight load + ref refresh after).
* finalize() in passes/pytorch/quant_utils.py installs QuantTensor
  params via install_quant_tensor_param (replaces the old
  flatten_quant_tensor_params helper).
* prepare_model skips modules whose weight is already a QuantTensor,
  so composing multiple Rtn passes on top of a partially quantized
  model works.
* make_export_compatible_quant detects nn.Linear / nn.Embedding whose
  weight is a QuantTensor and swaps them with QuantLinear /
  QuantEmbedding wrappers before any model dtype casting, preserving
  the existing com.microsoft::MatMulNBits /
  com.microsoft::GatherBlockQuantized symbolic export path.
* OliveQuantizedModel (model_builder.py) normalizes the new
  <dotted>.weight_qweight key layout back to the legacy
  <dotted>.qweight layout for the existing genai loader, and raises
  NotImplementedError for moe=True checkpoints.
* Tests updated to assert against QuantTensor weight instead of
  isinstance(module, QuantLinear); legacy tie_quant_modules tests
  removed; new install_quant_tensor_param test suite added.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
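
As a rough illustration of the save side described above, the sketch below applies the same idea to a plain dict: any `<name>` entry is dropped when a sibling `<name>_qweight` buffer already carries the data. The function name `drop_quant_param_entries` is hypothetical, and the real hook is registered on the host module rather than called on a dict like this.

```python
QWEIGHT_SUFFIX = "_qweight"  # assumption: inferred from "<pname>_qweight" above


def drop_quant_param_entries(state_dict: dict) -> dict:
    """Drop '<name>' keys whose data already lives in '<name>_qweight' buffers."""
    quant_params = {
        key[: -len(QWEIGHT_SUFFIX)]
        for key in state_dict
        if key.endswith(QWEIGHT_SUFFIX)
    }
    return {k: v for k, v in state_dict.items() if k not in quant_params}


# Placeholder values stand in for tensors.
state = {
    "fc.weight": "<QuantTensor parameter>",
    "fc.weight_qweight": "<uint8 buffer>",
    "fc.weight_scales": "<float buffer>",
    "fc.bias": "<float parameter>",
}
saved = drop_quant_param_entries(state)
```

After the hook runs, only plain tensors remain in the saved dict, which is what makes the layout safetensors-friendly.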
…er to N-D

* Remove olive/common/quant/nn.py (QuantModule, QuantLinear,
  QuantEmbedding) entirely. The only purpose of those modules was
  ONNX export, which is now handled by reusing the existing
  QuantLinearNbit from olive/common/hf/quant.py and a new
  parallel QuantEmbeddingNbit (com.microsoft::GatherBlockQuantized
  symbolic) in the same file.
* Add QuantLinearNbit.from_quant_tensor / QuantEmbeddingNbit.from_quant_tensor
  factories so make_export_compatible_quant can swap any nn.Linear /
  nn.Embedding whose weight is a QuantTensor into the export wrappers.
* Generalize WeightQuantizer (get_num_groups, get_qparam_shape,
  find_qparams, quantize, dequantize, _reshape_tensor) and
  pack_to_uint8 / unpack_from_uint8 to operate on any N-D tensor;
  quantization is always along the last dim, leading dims are
  preserved.
* Drop quantize_along_leading_dim / pack_to_uint8_along_last /
  unpack_from_uint8_along_last and the explicit 3D leading-dim loops
  in QuantTensor.from_float and _dequantize.
* Delete test/common/quant/test_nn.py; add N-D tests for the
  generalized quantizer + pack helpers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
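
The shape bookkeeping behind "quantize along the last dim, preserve leading dims" can be sketched as follows. The function name `quant_shapes` is hypothetical, treating `group_size == -1` as per-channel is an assumption, and the exact qparam layout in WeightQuantizer may differ (for example, qzeros may be packed as well).

```python
def quant_shapes(shape, bits, group_size):
    """Return (packed_qweight_shape, scales_shape) for last-dim quantization."""
    *leading, last = shape
    if group_size == -1:
        # Assumption: -1 means one group spanning the whole last dim.
        group_size = last
    pack_factor = 8 // bits  # values stored per uint8
    if last % group_size or last % pack_factor:
        raise ValueError("last dim must divide evenly into groups and bytes")
    return (*leading, last // pack_factor), (*leading, last // group_size)


# A 2D linear weight and a fused 3D experts weight go through the same code.
linear_shapes = quant_shapes((8, 64), bits=4, group_size=32)
experts_shapes = quant_shapes((4, 8, 64), bits=4, group_size=32)
```

Because only the last dimension is touched, the 3D case needs no explicit loop over experts.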
Instead of pre-walking the safetensors dict to rewrite
``<dotted>.weight_qweight`` -> ``<dotted>.qweight``, derive the
destination attribute name inside ``set_tensor`` once we already
know ``submodule`` is a ``QuantizedTensorModule``. Strip any of the
known Olive buffer suffixes (``QWEIGHT_SUFFIX``, ``SCALES_SUFFIX``,
``QZEROS_SUFFIX`` from ``olive.common.quant.state_dict``) from the
last path component to produce the bare ``qweight`` / ``scales`` /
``qzeros`` attribute that the genai ``QuantizedTensorModule``
expects.

Also drops internal dev-iteration version labels from comments and
docstrings in olive/common/quant and olive/passes/onnx.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
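
A minimal sketch of the suffix stripping described above. The suffix values are inferred from the `<pname>_qweight / _scales / _qzeros` naming in the earlier commit and may not match the actual constants in olive.common.quant.state_dict; `target_attr` is a hypothetical helper name.

```python
# Assumption: values inferred from the "<pname>_qweight" buffer naming.
QWEIGHT_SUFFIX, SCALES_SUFFIX, QZEROS_SUFFIX = "_qweight", "_scales", "_qzeros"


def target_attr(key: str) -> str:
    """Map the last path component of a checkpoint key to a bare attribute name."""
    last = key.rsplit(".", 1)[-1]
    for suffix in (QWEIGHT_SUFFIX, SCALES_SUFFIX, QZEROS_SUFFIX):
        if last.endswith(suffix):
            return suffix[1:]  # "_qweight" -> "qweight"
    return last  # legacy keys and unrelated tensors pass through unchanged


attr = target_attr("model.layers.0.self_attn.q_proj.weight_qweight")
```

Deriving the name at assignment time handles both the new and the legacy key layout without a separate pass over the checkpoint.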
Both QuantTensor's 2D layout and QuantLinearNbit's MatMulNBits
buffer layout pack the quantization axis as uint8 with the same
in-byte order (low nibble = elem[2j], high nibble = elem[2j+1] for
4-bit, etc.). They differ only in qweight rank: QuantTensor uses
(out, in / pack_factor), QuantLinearNbit uses
(out, n_blocks, blob_size) where n_blocks * blob_size ==
in / pack_factor. So the conversion is a pure reshape; the previous
unpack -> .t() -> from_tensors round-trip is unnecessary.

scales and qzeros buffer shapes also match exactly between the two
layouts, so they are copied as-is. For symmetric weights
(QuantTensor.qzeros is None) we fill the QuantLinearNbit.qzeros
buffer with the packed midq pattern that the contrib op expects.

Verified numerically: F.linear via QuantTensor and the
dequantize-from-buffers path through QuantLinearNbit produce
bit-identical outputs across {4,8} bits, {symmetric, asymmetric},
{groupwise, per-channel}.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
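
The in-byte order and the "pure reshape" claim above can be illustrated with plain Python lists. `pack_4bit` and `to_blocks` are hypothetical helper names, and the sketch covers only the 4-bit case for a single row.

```python
def pack_4bit(values):
    """Pack 4-bit ints: low nibble = values[2j], high nibble = values[2j+1]."""
    assert len(values) % 2 == 0 and all(0 <= v < 16 for v in values)
    return [values[2 * j] | (values[2 * j + 1] << 4) for j in range(len(values) // 2)]


def to_blocks(row, blob_size):
    """Reshape one packed row of in / pack_factor bytes into (n_blocks, blob_size)."""
    assert len(row) % blob_size == 0
    return [row[i:i + blob_size] for i in range(0, len(row), blob_size)]


row = pack_4bit([1, 2, 3, 4, 5, 6, 7, 8])  # one output row, in = 8
blocks = to_blocks(row, blob_size=2)       # n_blocks = 2
flat = [byte for block in blocks for byte in block]
```

Flattening `blocks` gives back `row` byte for byte, so no unpack or transpose is involved in moving between the two ranks.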
@jambayk jambayk marked this pull request as ready for review May 16, 2026 00:09
Copilot AI review requested due to automatic review settings May 16, 2026 00:09
Comment thread olive/common/quant/tensor.py Fixed
Comment thread olive/common/quant/tensor.py Fixed
Comment thread olive/common/quant/tensor.py Fixed

@github-advanced-security github-advanced-security AI left a comment

lintrunner found more than 20 potential problems in the proposed changes. Check the Files changed tab for more details.


Copilot AI left a comment

Pull request overview

This PR refactors Olive’s PyTorch weight quantization implementation from dedicated quantized nn.Module wrappers (QuantLinear/QuantEmbedding) to a QuantTensor parameter representation, and extends the quantization pipeline to better handle MoE models (including fused 3D expert weights and configurable expert skipping).

Changes:

  • Introduces QuantTensor (a torch.Tensor wrapper subclass) + state-dict helpers to store quantized weights as aliased buffers while keeping eager execution working via F.linear/F.embedding dispatch.
  • Updates quantization utilities to (a) support N-D quantization along the last dimension and (b) add MoE-aware selection/skip logic plus fused-3D expert parameter quantization.
  • Adds regex-aware pattern matching helpers for overrides and modules_to_not_convert, updates ONNX export compatibility and ModelBuilder loading behavior, and refreshes/expands unit tests accordingly.

Reviewed changes

Copilot reviewed 20 out of 20 changed files in this pull request and generated 1 comment.

| File | Description |
| --- | --- |
| test/passes/pytorch/test_rtn.py | Updates RTN tests to validate quantization via QuantTensor weights rather than quantized module classes. |
| test/passes/pytorch/test_gptq.py | Updates GPTQ tests to validate QuantTensor-backed weights. |
| test/passes/pytorch/test_quant_utils.py | Adds offline unit tests for install_quant_tensor_param (incl. fused 3D expert parameters). |
| test/passes/onnx/test_model_builder.py | Adds coverage ensuring Olive-quantized MoE checkpoints error out in ModelBuilder. |
| test/common/quant/test_utils.py | Updates assertions and adds tests for N-D quantization/packing helpers. |
| test/common/quant/test_tensor.py | Adds unit tests for QuantTensor semantics, dispatch behavior, dtype/device propagation, slicing, and ONNX export guards. |
| test/common/quant/test_patterns.py | Adds tests for regex/literal override + skip matching helpers. |
| test/common/quant/test_hf_utils.py | Updates HF quantizer tests for QuantTensor and adds MoE-specific behavior tests. |
| test/common/quant/test_nn.py | Removes tests for deleted QuantLinear/QuantEmbedding module wrappers. |
| olive/passes/pytorch/rtn.py | Enables MoE option exposure in RTN pass default config. |
| olive/passes/pytorch/quant_utils.py | Refactors prepare/finalize to install QuantTensor params, adds MoE expert handling and skip patterns. |
| olive/passes/onnx/model_builder.py | Maps Olive buffer suffixes to expected QuantizedTensorModule attrs and rejects MoE Olive checkpoints explicitly. |
| olive/common/quant/utils.py | Extends WeightQuantizer + pack/unpack to work on N-D tensors along the last dimension. |
| olive/common/quant/tensor.py | Adds QuantTensor implementation with eager dispatch + export guards + tensor transform propagation. |
| olive/common/quant/state_dict.py | Adds buffer naming + state-dict hook + install/refresh helpers for QuantTensor parameters. |
| olive/common/quant/patterns.py | Adds centralized literal/regex matching helpers for overrides and skip patterns. |
| olive/common/quant/hf_utils.py | Refactors HF quantizer to use QuantTensor placeholders and MoE-aware skipping; updates tie logic. |
| olive/common/hf/wrapper.py | Adds LayerWrapper.get_experts / get_router helpers for MoE structure discovery. |
| olive/common/hf/quant.py | Adds conversion from Olive QuantTensor weights to ONNX-exportable QuantLinearNbit/QuantEmbeddingNbit wrappers. |
| olive/common/quant/nn.py | Removes legacy quantized module wrapper implementations. |
Comments suppressed due to low confidence (1)

olive/common/quant/hf_utils.py:429

  • tie_quant_word_embeddings will currently tie the output embedding to the input embedding as long as the input has a QuantTensor weight. This can unintentionally quantize the output embedding even when lm_head=False (or leave dst without Olive’s state_dict hook), and it no longer validates shape/dtype compatibility before aliasing buffers. Please guard so tying only happens when both src and dst already have compatible QuantTensor weights/buffers (and keep the previous shape/dtype assertions) to avoid silently changing quantization behavior or breaking save/load.
def tie_quant_word_embeddings(model: PreTrainedModel) -> None:
    """Tie the input and output embeddings when they share a quantized weight.

    Both modules' ``weight`` ``nn.Parameter`` is set to the **same**
    ``nn.Parameter(QuantTensor)`` object, and the underlying
    ``weight_qweight`` / ``weight_scales`` / ``weight_qzeros`` buffers
    are tied (aliased to the input embedding's buffers). This preserves
    the standard HF tied-weights semantics for the quantized layout.
    """
    src = model.get_input_embeddings()
    dst = model.get_output_embeddings()
    if src is None or dst is None:
        return

    # The input embedding owns the canonical QuantTensor + buffers.
    src_param = src._parameters.get("weight")
    if src_param is None or not isinstance(src_param.data, QuantTensor):
        return

    qname, sname, zname = buffer_names("weight")
    # tie buffers
    for n in (qname, sname, zname):
        src_buf = src._buffers.get(n)
        if src_buf is None:
            continue
        dst._buffers[n] = src_buf

    # tie the QuantTensor parameter itself (same Python Parameter object,
    # so both modules see the same .data and the same inner tensors).
    dst._parameters["weight"] = src_param

Comment thread olive/passes/pytorch/quant_utils.py Outdated
The ORT contrib MatMulNBits / GatherBlockQuantized ops treat a missing
zero-points input as midq for unsigned quantization, matching Olive's
symmetric-quantization convention. Drop the synthetic packed-midq
buffer that was previously emitted for symmetric weights and instead
omit the input entirely:

* QuantLinearNbit gains a has_qzeros flag (default True for back-compat);
  pack/from_tensors/from_quant_tensor pass through None as needed.
* QuantLinearTorchFunction (TorchScript + dynamo) skips the qzeros input
  when None, inserting an empty placeholder only when g_idx must be
  positionally aligned.
* QuantEmbeddingTorchFunction.symbolic gains the missing dynamo arg
  exposed by the new symmetric-embedding export path.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
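
The equivalence this commit relies on can be checked with a small dequantization sketch: for unsigned storage, leaving the zero point out and supplying midq explicitly give the same result. The function below is illustrative only and takes midq to be 2**(bits - 1), which is an assumption about the convention rather than code from Olive or ONNX Runtime.

```python
def dequantize(q, scale, bits, zero_point=None):
    """Dequantize unsigned values; a missing zero point falls back to midq."""
    if zero_point is None:
        zero_point = 1 << (bits - 1)  # assumption: midq = 2**(bits - 1)
    return [(v - zero_point) * scale for v in q]


implicit = dequantize([0, 8, 15], scale=0.5, bits=4)
explicit = dequantize([0, 8, 15], scale=0.5, bits=4, zero_point=8)
```

Since both calls agree, a symmetric checkpoint can omit the zero-points input instead of storing a buffer filled with the midq pattern.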
g_idx alongside a missing qzeros is not a real combination in Olive
(GPTQ always produces qzeros), so skip the empty-tensor placeholder
and just omit qzeros from the input list entirely.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment

Pull request overview

Copilot reviewed 20 out of 20 changed files in this pull request and generated 2 comments.

Comments suppressed due to low confidence (1)

olive/common/hf/quant.py:175

  • In the dynamo ONNX-export path, the placeholder inserted when qzeros is None and g_idx is not None is created as torch.zeros(0, dtype=torch.uint8) with no device specified. If x/qweight/scales are on CUDA, this can trigger device-mismatch errors during export graph construction. Create the placeholder on x.device (and ideally match qweight's device) to keep tensor args consistent.
                    tensor_args.append(g_idx)
                attrs = {
                    "K": in_features,
                    "N": out_features,
                    "bits": bits,
                    "block_size": group_size,
                    "accuracy_level": 4,

Comment thread olive/common/quant/tensor.py Outdated
Comment thread olive/passes/pytorch/quant_utils.py Outdated
- Replace duplicated model walks in hf_utils._process_model_before_weight_loading
  and quant_utils.prepare_model with a shared iter_quant_targets helper that
  returns a list of (module, dotted_name, param_name, shape, dtype, device,
  kind) entries. Selection rules (lm_head/embeds/moe category flags, skip
  patterns, extra_skip_modules, already-quantized) live in one place.
- QuantLinearNbit/QuantEmbeddingNbit: raise instead of synthesising a
  placeholder when g_idx is supplied alongside symmetric quantization.
- tie_quant_word_embeddings: require both input and output embeddings to
  already be QuantTensor-backed with matching shape/dtype before tying.
- Fix CodeQL mismatched-assignment false positives in QuantTensor dispatch
  (index args directly), fix ruff D205/D401/PLW0108/A002 warnings.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Neither attribute is read anywhere — the post-walk loop that produced
the literal skip-name list was a leftover from before the refactor.
The configured patterns already live on quantization_config.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment

Pull request overview

Copilot reviewed 22 out of 22 changed files in this pull request and generated 3 comments.

Comment thread olive/common/quant/selection.py Outdated
Comment thread olive/common/quant/hf_utils.py
Comment thread olive/common/quant/utils.py
Copilot AI added 2 commits May 16, 2026 01:10
Stop tagging modules with quant_info / quant_info_3d and stop branching
on 2D vs 3D in the quantization passes. The quantizer already operates
along the last dim regardless of rank, so a single iteration over
parameters that carry a quant_info attribute is enough.

- QuantTarget slims to (module, module_name, pname, full_name) with a
  .param property; the caller reads shape/dtype/device from the
  parameter directly. No more 'kind' field.
- prepare_model writes target.param.quant_info in one pass — both 2D
  linear/embedding weights and fused experts parameters use the same
  code path. The quant_info_3d dict-stash on experts modules is gone.
- finalize iterates every parameter that has quant_info, calls
  QuantTensor.from_float (already rank-generic), and installs in place.
- GPTQ and AutoClip read module.weight.quant_info; module discovery
  uses hasattr(module.weight, 'quant_info') instead of a module-level
  attribute.
- HF placeholder install pulls shape/dtype/device off target.param and
  the placeholder builder is now rank-generic.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

The dataclass had four fields and a one-line .param property, used by
two callers. A plain tuple is shorter, matches how the layerwise
quantization loop already iterates over (module, pname, param, info)
tuples, and removes the unused module_name field and dead
for_each_target helper.

QuantTarget remains as a type alias for the public signature.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

Copilot AI left a comment

Pull request overview

Copilot reviewed 24 out of 24 changed files in this pull request and generated 3 comments.

Comment thread olive/common/quant/hf_utils.py Outdated
Comment thread olive/passes/pytorch/quant_utils.py Outdated
Comment thread olive/common/quant/tensor.py
Copilot AI added 2 commits May 16, 2026 01:34
* Filter fused-MoE params to 2D/3D ranks in iter_quant_targets so a
  1D bias-like parameter fails at selection time instead of much later
  in finalize.
* refresh_quant_tensor_refs: also check isinstance(param.data,
  QuantTensor) for forward-compat with future torch versions that may
  not return the underlying subclass from nn.Parameter().
* OliveHfQuantizationConfig: replace bare '# pylint: disable' with the
  specific super-init-not-called rule; use output.get(k) in to_dict.
* finalize: log a warning when moe=True that the resulting checkpoint
  isn't directly ONNX-exportable via the Olive conversion pass — it
  must be consumed by an MoE-aware model builder.
* Add regression test that _module_weight_has_quant_info ignores
  nn.LayerNorm / nn.Conv2d / unmarked nn.Linear (defends GPTQ/AutoClip
  discovery against future drift).

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

- __torch_dispatch__ clone/contiguous now forwards extra args/kwargs.
- iter_quant_targets skips ALL nn.Embedding when embeds=False (positional
  / token-type embeddings like GPT-2 wpe are no longer silently quantized).
- WeightQuantizer assertion message: 2/4/8-bit (was 4/8-bit).
- tie_quant_word_embeddings: mark dst aliased buffers non-persistent so
  safetensors save emits one copy of qweight/scales/qzeros.
- finalize: group selected params by host module so each module's
  to(device)/to(cpu) cycle runs once for MoE experts modules carrying
  multiple 3D weight params.
- state_dict: add ensure_state_dict_hooks(model) defensive walk that
  installs the save hook on every host module that owns a QuantTensor
  parameter (idempotent).
- Add test_forward_parity.py: bit-exact eager parity for full models
  (embedding + linears) and fused 3D MoE forwards, plus end-to-end
  ONNX export -> onnxruntime numerical parity for Olive-quantized
  nn.Linear via make_export_compatible_quant.
- Enable pylint by adding file-level protected-access disables on the
  files that intentionally touch nn.Module._parameters.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
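
The "group selected params by host module" change in finalize can be sketched without torch. `group_by_module`, the `Host` stand-in class, and the parameter names `gate_up_proj` / `down_proj` are hypothetical; the sketch only shows the grouping step, not the device moves themselves.

```python
def group_by_module(targets):
    """Group (module, pname) pairs by host module identity, keeping first-seen order."""
    groups = {}
    for module, pname in targets:
        groups.setdefault(id(module), (module, []))[1].append(pname)
    return list(groups.values())


class Host:
    """Stand-in for an nn.Module that owns one or more quantizable parameters."""


experts, linear = Host(), Host()
targets = [(experts, "gate_up_proj"), (linear, "weight"), (experts, "down_proj")]
grouped = group_by_module(targets)
```

Iterating over `grouped` lets the caller move each host module to the device once, quantize all of its listed parameters, and then move it back, rather than repeating the cycle for every parameter.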
@jambayk jambayk closed this May 18, 2026
4 participants