From f7917d80a3cf768da309357ece977b97b111f7f7 Mon Sep 17 00:00:00 2001 From: "Tessa (livepeer-tessa)" Date: Sun, 12 Apr 2026 18:25:10 +0000 Subject: [PATCH] fix(lora): validate LoRA dimensions against model at load time (#922) Add dimension validation in parse_lora_weights() so a LoRA trained for a different model size (e.g. Wan2.1-5B, in_features=5120) is rejected with a user-friendly ValueError when loaded into the 1.3B model (in_features=1536), rather than loading silently and crashing 150+ times at inference. Before: mat1/mat2 shape mismatch RuntimeError deep in peft/torch at inference After: ValueError at load time naming the layer, expected vs actual dims, and a plain-language hint about model architecture mismatch Also adds test_lora_dimension_validation.py covering: - compatible LoRA loads without error - 5B LoRA on 1.3B model raises ValueError - error message is user-friendly (names layer + dimensions) - out_features mismatch is also caught - 5B LoRA on 5B model is fine Signed-off-by: Tessa (livepeer-tessa) --- src/scope/core/pipelines/wan2_1/lora/utils.py | 17 ++++ tests/test_lora_dimension_validation.py | 87 +++++++++++++++++++ 2 files changed, 104 insertions(+) create mode 100644 tests/test_lora_dimension_validation.py diff --git a/src/scope/core/pipelines/wan2_1/lora/utils.py b/src/scope/core/pipelines/wan2_1/lora/utils.py index e34f4b63d..45ff50c1e 100644 --- a/src/scope/core/pipelines/wan2_1/lora/utils.py +++ b/src/scope/core/pipelines/wan2_1/lora/utils.py @@ -410,6 +410,23 @@ def parse_lora_weights( f"parse_lora_weights: Matched base_key='{base_key}' -> model_key='{model_key}'" ) + # Validate LoRA dimensions against the model weight before injecting. + # lora_A shape: [rank, in_features] — in_features must match model weight dim 1 + # lora_B shape: [out_features, rank] — out_features must match model weight dim 0 + # (model weight shape is [out_features, in_features] for nn.Linear) + model_weight = model_state.get(model_key) + if model_weight is not None and lora_A.ndim == 2 and lora_B.ndim == 2: + lora_in = lora_A.shape[1] # LoRA expects this input dimension + lora_out = lora_B.shape[0] # LoRA expects this output dimension + model_out, model_in = model_weight.shape[0], model_weight.shape[1] + if lora_in != model_in or lora_out != model_out: + raise ValueError( + f"LoRA dimension mismatch at layer '{base_key}': " + f"LoRA expects ({lora_out}×{lora_in}) but model layer is ({model_out}×{model_in}). " + f"This LoRA was likely trained for a different model size (e.g. Wan2.1-5B vs 1.3B). " + f"Please use a LoRA that matches the loaded model architecture." + ) + # Extract alpha and rank alpha = None if alpha_key and alpha_key in lora_state: diff --git a/tests/test_lora_dimension_validation.py b/tests/test_lora_dimension_validation.py new file mode 100644 index 000000000..6f4c0e12c --- /dev/null +++ b/tests/test_lora_dimension_validation.py @@ -0,0 +1,87 @@ +"""Tests for LoRA dimension validation in parse_lora_weights. + +Regression test for issue #922: a LoRA trained for Wan2.1-5B (in_features=5120) +was silently loaded into the Wan2.1-1.3B model (in_features=1536) and only +failed 156 times at inference time with an inscrutable RuntimeError. +""" + +import pytest +import torch + +from scope.core.pipelines.wan2_1.lora.utils import parse_lora_weights + + +def _make_model_state(in_features: int, out_features: int = 256) -> dict: + """Minimal model state dict with one linear layer.""" + return { + "blocks.0.self_attn.q.weight": torch.zeros(out_features, in_features), + } + + +def _make_lora_state(rank: int, in_features: int, out_features: int = 256) -> dict: + """Minimal PEFT-format LoRA state targeting the same layer.""" + return { + "diffusion_model.blocks.0.self_attn.q.lora_A.weight": torch.zeros(rank, in_features), + "diffusion_model.blocks.0.self_attn.q.lora_B.weight": torch.zeros(out_features, rank), + } + + +class TestLoRADimensionValidation: + """Verify parse_lora_weights raises a clear error on dimension mismatch.""" + + def test_compatible_lora_loads_successfully(self): + """LoRA matching the model's dimensions should parse without error.""" + model_state = _make_model_state(in_features=1536) + lora_state = _make_lora_state(rank=32, in_features=1536) + + mapping = parse_lora_weights(lora_state, model_state) + + assert len(mapping) == 1 + key = "blocks.0.self_attn.q.weight" + assert key in mapping + assert mapping[key]["rank"] == 32 + + def test_incompatible_lora_raises_value_error(self): + """LoRA trained for 5B (in_features=5120) must not silently load into 1.3B (in_features=1536).""" + model_state = _make_model_state(in_features=1536) # 1.3B model + lora_state = _make_lora_state(rank=32, in_features=5120) # 5B LoRA + + with pytest.raises(ValueError, match="LoRA dimension mismatch"): + parse_lora_weights(lora_state, model_state) + + def test_error_message_is_user_friendly(self): + """The error message should name the layer and the dimension sizes.""" + model_state = _make_model_state(in_features=1536) + lora_state = _make_lora_state(rank=32, in_features=5120) + + with pytest.raises(ValueError) as exc_info: + parse_lora_weights(lora_state, model_state) + + msg = str(exc_info.value) + assert "blocks.0.self_attn.q" in msg, "Layer name should appear in error" + assert "5120" in msg, "LoRA in_features should appear in error" + assert "1536" in msg, "Model in_features should appear in error" + assert "model size" in msg.lower() or "architecture" in msg.lower(), ( + "Error should hint at model size mismatch" + ) + + def test_out_features_mismatch_also_caught(self): + """LoRA with wrong output dimension should also be rejected.""" + model_state = _make_model_state(in_features=1536, out_features=256) + # LoRA with matching in_features but wrong out_features + lora_state = { + "diffusion_model.blocks.0.self_attn.q.lora_A.weight": torch.zeros(32, 1536), + "diffusion_model.blocks.0.self_attn.q.lora_B.weight": torch.zeros(512, 32), # wrong + } + + with pytest.raises(ValueError, match="LoRA dimension mismatch"): + parse_lora_weights(lora_state, model_state) + + def test_compatible_5b_lora_on_5b_model(self): + """LoRA trained for 5B on a 5B model should load fine.""" + model_state = _make_model_state(in_features=5120, out_features=5120) + lora_state = _make_lora_state(rank=32, in_features=5120, out_features=5120) + + mapping = parse_lora_weights(lora_state, model_state) + + assert len(mapping) == 1