[WIP] Generic quantization support for PEFT methods #3117
BenjaminBossan wants to merge 16 commits into huggingface:main
Conversation
Right now, if a new PEFT method wants to add support for quantized layers, it requires a significant amount of work. Notably, the method needs to implement dedicated layer classes for each quantization method (e.g. one class for bnb 4bit, one for bnb 8bit, one for AWQ, ...). The result of that is that, at the moment, most PEFT methods don't support any, or only very few, quantization methods, even though the amount of actual logic required to support these methods is quite contained. This PR is a suggestion of how to solve the issue. If this approach is accepted, with a few extra lines, we should be able to support all quantization methods in all PEFT methods. The PR is not in a finished state; more to follow. Right now, only VeRA and MiSS have been updated as a POC.
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
no reason why it would be nn.Linear instead of nn.Module like the other PEFT methods
+ some docstring cleanups
Pull request overview
This PR introduces a generic “quantization backend” abstraction so PEFT tuner layers can support multiple quantization frameworks without needing per-backend layer subclasses, and wires it into VeRA and MiSS as an initial proof-of-concept.
Changes:
- Add `QuantizationBackend` implementations + backend resolution (`resolve_quantization_backend`) and a helper to surface backend info in the module repr.
- Extend `BaseTunerLayer` with `get_base_weight`/`set_base_weight` to centralize dequantize/requantize handling for merge/unmerge.
- Surface quantization backend info in `get_layer_status()`/`get_model_status()` and add tests (including a new quantization matrix test file).
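For orientation, here is a minimal sketch of what such a backend abstraction could look like. The class and function names mirror the summary above, but the exact signatures and dispatch logic in the PR may differ:

```python
import torch
import torch.nn as nn


class QuantizationBackend:
    """Sketch of a default backend: hides dequantize/requantize details from tuner layers."""

    def __init__(self, base_layer: nn.Module) -> None:
        self.base_layer = base_layer

    def get_base_weight(self) -> torch.Tensor:
        # For a plain nn.Linear there is nothing to dequantize.
        return self.base_layer.weight.data

    def set_base_weight(self, weight: torch.Tensor) -> None:
        # A bnb/torchao backend would requantize here instead of assigning directly.
        self.base_layer.weight.data = weight


def resolve_quantization_backend(base_layer: nn.Module) -> QuantizationBackend:
    # Hypothetical resolution logic: a real implementation would dispatch on the base
    # layer type (bnb Linear4bit/Linear8bitLt, torchao, ...) and return the matching backend.
    return QuantizationBackend(base_layer)
```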
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 5 comments.
Summary per file:
| File | Description |
|---|---|
| `tests/test_tuners_utils.py` | Adds coverage for layer/model status reporting of `quantization_backend`. |
| `tests/test_quantization.py` | Adds a PEFT-method × quant-backend matrix test suite (bnb/torchao loaders and core behavioral checks). |
| `src/peft/utils/quantization_utils.py` | Introduces quantization backend classes, backend resolution logic, and repr helper. |
| `src/peft/utils/__init__.py` | Exports the new quantization helpers from `peft.utils`. |
| `src/peft/tuners/vera/model.py` | Updates VeRA injection to remove bnb-specific module creation and to forward torchao merge metadata. |
| `src/peft/tuners/vera/layer.py` | Hooks VeRA layers into the new backend mechanism for merge/unmerge and forward safety cloning. |
| `src/peft/tuners/vera/bnb.py` | Removes VeRA's dedicated bitsandbytes layer implementations (intended to be superseded by generic backend support). |
| `src/peft/tuners/tuners_utils.py` | Adds a `quantization_backend` attribute and centralized base-weight getters/setters to `BaseTunerLayer`. |
| `src/peft/tuners/miss/model.py` | Forwards torchao merge metadata during MiSS injection. |
| `src/peft/tuners/miss/layer.py` | Hooks MiSS layers into the new backend mechanism for merge/unmerge and forward safety cloning. |
| `src/peft/peft_model.py` | Extends tuner status dataclasses and status functions to report quantization backend consistency. |
Comments suppressed due to low confidence (1)
src/peft/tuners/vera/model.py:259
- This PR removes the dedicated VeRA bitsandbytes layer implementations, but `peft.tuners.vera` still has lazy attribute resolution for `Linear8bitLt`/`Linear4bit` via `from .bnb import ...` (see `src/peft/tuners/vera/__init__.py`). With `vera/bnb.py` deleted, those imports will raise at runtime, and existing tests/imports that reference `peft.tuners.vera.Linear8bitLt` will break. Please update the VeRA package exports to match the new generic quantization approach (either provide compatible aliases or remove the lazy attributes).
@staticmethod
def _create_new_module(vera_config, vera_A, vera_B, adapter_name, target, **kwargs):
bias = kwargs.pop("bias", False)
if isinstance(target, BaseTunerLayer):
target_base_layer = target.get_base_layer()
else:
target_base_layer = target
if isinstance(target_base_layer, torch.nn.Linear):
if kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to True but the target module is `torch.nn.Linear`. "
"Setting fan_in_fan_out to False."
)
kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = False
elif isinstance(target_base_layer, Conv1D):
kwargs["is_target_conv_1d_layer"] = True
if not kwargs["fan_in_fan_out"]:
warnings.warn(
"fan_in_fan_out is set to False but the target module is `Conv1D`. Setting fan_in_fan_out to True."
)
kwargs["fan_in_fan_out"] = vera_config.fan_in_fan_out = True
else:
raise ValueError(
f"Target module {target} is not supported. Currently, only the following modules are supported: "
"`torch.nn.Linear`, `transformers.pytorch_utils.Conv1D`."
)
Also: skip bnb4bit + CPU
Conv not fully fleshed out
Forgot to update some tests
rotated_weight = torch.transpose(rotated_weight, 0, 1)
scaled_rotated_weight = rotated_weight * boft_scale
x_rotated = x @ boft_rotation
This is a reformulation of the forward path of BOFT that avoids using the base layer weight directly. This was necessary because calling torch.mm(boft_rotation, orig_weight) can fail with quantized weights. Instead, we should make a forward pass and let the quantized layer handle the details. I ran the BOFT tests with the old and the new implementation and added an assert that they are identical (up to precision).
Regarding runtime, I checked the MetaMath benchmark and got 147 sec for 250 steps (116 sec for 1 eval run) using the main branch, and 138 sec (108 sec) using the code from this branch. So the new code seems to be on par with, or possibly slightly faster than, the old one.
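A rough sketch of the reformulation (the function and variable names are illustrative, and bias/dtype handling is omitted): instead of materializing the rotated weight with `torch.mm(boft_rotation, orig_weight)`, the activations are rotated first so the (possibly quantized) base layer performs the matmul itself:

```python
import torch
import torch.nn as nn


def boft_style_forward(
    x: torch.Tensor, base_layer: nn.Module, boft_rotation: torch.Tensor, boft_scale: torch.Tensor
) -> torch.Tensor:
    """Sketch of the reformulated forward path; bias and dtype handling are omitted."""
    # Old path (for comparison): builds the rotated weight explicitly, which fails for
    # quantized weights because base_layer.weight is not a plain tensor:
    #   rotated_weight = torch.mm(boft_rotation, base_layer.weight.transpose(0, 1)).transpose(0, 1)
    #   result = F.linear(x, rotated_weight * boft_scale)
    # New path: rotate the activations, let the (possibly quantized) base layer do the matmul,
    # then apply the per-output-channel scale to the output.
    x_rotated = x @ boft_rotation
    return base_layer(x_rotated) * boft_scale.view(1, -1)
```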
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
not stale |
Problem
Right now, if a new PEFT method wants to add support for quantized layers, it requires a significant amount of work. Notably, the method needs to implement dedicated layer classes for each quantization method (e.g. one class for bnb 4bit, one for bnb 8bit, one for AWQ, ...). These classes typically are >90% boilerplate and the actual difference between implementations of these classes is minimal.
The result of that is that, at the moment, most PEFT methods don't support any, or only very few, quantization methods, even though the amount of actual logic required to support these methods is relatively small.
Suggested solution
This PR is a suggestion of how to solve the issue. If this approach is accepted, with a few extra lines, we should be able to support all quantization methods in all PEFT methods. The general approach is to add an attribute to each PEFT layer, `self.quantization_backend`, which supports these methods:

- `get_base_weight`
- `set_base_weight`

When the PEFT layers use these methods to access and write to the base layer weight, and the weight is quantized, the new classes will deal with that correctly. This means that we no longer need a dedicated layer class to deal with quantized layers; the normal layer class will do. E.g. for MiSS, the normal `miss.Linear` class can deal with bnb layers, so there is no need to add a `miss/bnb.py` module with dedicated layers.

A few rewrites in the existing PEFT methods are required to support this new quantization backend class, but the total amount of code needed for that is considerably smaller than adding new classes for each quantization method.
Furthermore, these quantization backend classes are agnostic with regard to the PEFT method. Therefore, with M PEFT methods and N quantization methods, we no longer need MxN implementations to support quantization but only M+N.
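To make this concrete, here is a rough sketch of how merge/unmerge on a PEFT layer could look once the layer only talks to the base weight through the backend. The signatures are illustrative, and `get_delta_weight` is just a stand-in for the method-specific delta computation:

```python
class SomePeftLinear:
    # Sketch only: the actual signatures in the PR may differ.
    def merge(self, adapter_name: str) -> None:
        base_weight = self.quantization_backend.get_base_weight()       # dequantizes if needed
        delta = self.get_delta_weight(adapter_name)                     # PEFT-method specific
        self.quantization_backend.set_base_weight(base_weight + delta)  # requantizes if needed

    def unmerge(self, adapter_name: str) -> None:
        base_weight = self.quantization_backend.get_base_weight()
        delta = self.get_delta_weight(adapter_name)
        self.quantization_backend.set_base_weight(base_weight - delta)
```

The point is that neither `merge` nor `unmerge` needs to know which quantization scheme (if any) is in use; that knowledge lives entirely in the backend.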
Migration
For LoRA, we have already implemented the layer classes for each supported quantization method. For the sake of consistency, it could still make sense to migrate LoRA to the new approach if it's accepted. This needs to be accompanied by detailed regression testing to ensure that everything keeps working. I would only suggest deprecating and removing abandoned quantization methods (perhaps for a v1.0 release).
The bigger issue, however, is that packages that depend on PEFT may break with this change. As an example, if they detect quantized layers via `isinstance` checks, those checks would break, as all layers would just be the normal `lora.Linear`, `lora.Conv2d`, etc. The approach here would most likely involve deprecating the import of these classes. I think it's also possible to "cheat" `isinstance` and pretend that there is inheritance when there isn't, but I'd like to avoid that.

Anyway, this is out of scope for this PR and will be addressed in the future.
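For reference, the `isinstance` "cheat" mentioned above would presumably be done with a custom metaclass implementing `__instancecheck__`, roughly like this (purely illustrative, not part of this PR; the attribute checked here is hypothetical):

```python
class _Linear4bitMeta(type):
    def __instancecheck__(cls, instance):
        # Pretend that plain lora.Linear layers wrapping a bnb 4-bit base layer are
        # instances of this class, even though there is no real inheritance.
        return getattr(instance, "quantization_backend_name", None) == "bnb-4bit"


class Linear4bit(metaclass=_Linear4bitMeta):
    """Compatibility shim whose only purpose is to satisfy isinstance checks."""
```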
Scope
Updating all PEFT methods is too much for a single PR. This PR focuses on only three PEFT methods for now:

- VeRA
- MiSS
- BOFT: this required a reformulation of the `forward` step. Similar changes may be required for other methods too.