diff --git a/src/art/megatron/model_support/handlers/qwen3_moe.py b/src/art/megatron/model_support/handlers/qwen3_moe.py index 45656f774..61aa8fcac 100644 --- a/src/art/megatron/model_support/handlers/qwen3_moe.py +++ b/src/art/megatron/model_support/handlers/qwen3_moe.py @@ -1,5 +1,8 @@ +import re from typing import Any, Sequence +import torch + from art.megatron.model_support.handlers.default_dense import DefaultMoeHandler from art.megatron.model_support.handlers.qwen3_common import ( install_qwen3_text_preprocess_patch, @@ -11,12 +14,28 @@ "alltoall_dispatch_preprocess", "deepep_permute_restore", ) +_QWEN3_FUSED_MOE_KEY_RE = re.compile( + r"^(?P<prefix>.*\.mlp\.experts)\." + r"(?:(?P<base_layer>base_layer)\.)?(?P<lora>lora_[AB])\.weight$" +) +_QWEN3_EXPERT_MOE_KEY_RE = re.compile( + r"^.*\.mlp\.experts\.\d+\." + r"(?:gate_proj|up_proj|down_proj)\.lora_[AB]\.weight$" +) class Qwen3MoeHandler(DefaultMoeHandler): key = "qwen3_moe" native_vllm_lora_status = "validated" + def to_vllm_lora_tensors( + self, + tensors: dict[str, torch.Tensor], + *, + adapter_config: dict[str, Any], + ) -> tuple[dict[str, torch.Tensor], dict[str, Any]]: + return _to_vllm_lora_tensors(tensors, adapter_config=adapter_config) + def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None: install_qwen3_text_preprocess_patch(model_chunks) @@ -29,3 +48,142 @@ def compile_workaround_config( QWEN3_MOE_HANDLER = Qwen3MoeHandler() + + +def _qwen3_moe_config(adapter_config: dict[str, Any]) -> dict[str, Any]: + config = dict(adapter_config) + target_modules = list(config.get("target_modules") or []) + if "experts" not in target_modules: + target_modules.append("experts") + config["target_modules"] = target_modules + return config + + +def _packed_lora_b_by_expert( + tensor: torch.Tensor, + *, + num_experts: int, + rank: int, +) -> torch.Tensor: + return tensor.reshape(tensor.shape[0], rank, num_experts).permute(2, 0, 1) + + +def _clone(tensor: torch.Tensor) -> torch.Tensor: + return tensor.clone().contiguous() +
def _expand_fused_moe_lora(
    prefix: str,
    slots: dict[str, torch.Tensor],
    *,
    rank: int,
) -> dict[str, torch.Tensor]:
    """Expand one fused Qwen3 MoE LoRA block into per-expert tensors.

    Args:
        prefix: Key prefix of the fused block. It is used to build the output
            keys (``{prefix}.{expert}.<proj>.lora_[AB].weight``) and appears
            in every error message.
        slots: The four fused 2D tensors of the block, keyed by
            ``"base_layer.lora_A"`` / ``"base_layer.lora_B"`` (handled as the
            fused gate/up pair) and ``"lora_A"`` / ``"lora_B"`` (handled as
            the down projection).
        rank: Per-expert LoRA rank. Must be positive.

    Returns:
        A new dict with six entries per expert: ``lora_A`` and ``lora_B`` for
        each of ``gate_proj``, ``up_proj`` and ``down_proj``. Every value is
        an independent contiguous copy produced by ``_clone``.

    Raises:
        RuntimeError: If a slot is missing, a tensor is not 2D, ``rank`` is
            not positive, or the tensor shapes are inconsistent with each
            other for the inferred number of experts.
    """
    try:
        gate_up_a = slots["base_layer.lora_A"]
        gate_up_b = slots["base_layer.lora_B"]
        down_a = slots["lora_A"]
        down_b = slots["lora_B"]
    except KeyError as exc:
        raise RuntimeError(f"Incomplete Qwen3 MoE LoRA block for {prefix}") from exc

    if any(t.ndim != 2 for t in (gate_up_a, gate_up_b, down_a, down_b)):
        raise RuntimeError(f"Qwen3 MoE LoRA tensors for {prefix} must be 2D")
    # Reject a bad rank up front: rank 0 would escape as ZeroDivisionError
    # from the modulo below, and a negative rank can satisfy the modulo check
    # while producing a negative expert count.
    if rank <= 0:
        raise RuntimeError(f"{prefix}: LoRA rank must be positive, got {rank}")
    if gate_up_a.shape[0] % rank != 0:
        raise RuntimeError(
            f"{prefix}: gate/up lora_A shape {tuple(gate_up_a.shape)} "
            f"is not divisible by rank {rank}"
        )
    if gate_up_b.shape[0] % 2 != 0:
        raise RuntimeError(
            f"{prefix}: gate/up lora_B rows {gate_up_b.shape[0]} are not even"
        )

    # The expert count is inferred from the gate/up lora_A rows; every other
    # tensor is then checked against it.
    num_experts = gate_up_a.shape[0] // rank
    expected_rank_cols = num_experts * rank
    # gate and up halves are stacked along dim 0 of the fused lora_B.
    intermediate = gate_up_b.shape[0] // 2
    if gate_up_b.shape[1] != expected_rank_cols:
        raise RuntimeError(
            f"{prefix}: gate/up lora_B shape {tuple(gate_up_b.shape)} does not "
            f"match {num_experts} experts at rank {rank}"
        )
    if down_a.shape != (expected_rank_cols, intermediate):
        raise RuntimeError(
            f"{prefix}: down lora_A shape {tuple(down_a.shape)} does not match "
            f"expected {(expected_rank_cols, intermediate)}"
        )
    if down_b.shape[1] != expected_rank_cols:
        raise RuntimeError(
            f"{prefix}: down lora_B shape {tuple(down_b.shape)} does not match "
            f"{num_experts} experts at rank {rank}"
        )

    # Unpack the fused lora_B matrices so that index 0 selects the expert.
    gate_up_b_by_expert = _packed_lora_b_by_expert(
        gate_up_b,
        num_experts=num_experts,
        rank=rank,
    )
    down_b_by_expert = _packed_lora_b_by_expert(
        down_b,
        num_experts=num_experts,
        rank=rank,
    )

    expanded: dict[str, torch.Tensor] = {}
    for expert in range(num_experts):
        # lora_A rows are laid out expert-major: `rank` consecutive rows each.
        rows = slice(expert * rank, (expert + 1) * rank)
        gate_b, up_b = gate_up_b_by_expert[expert].split(intermediate, dim=0)
        expert_prefix = f"{prefix}.{expert}"
        # gate_proj and up_proj share the fused lora_A rows; each key gets its
        # own clone so the two entries never alias the same storage.
        expanded[f"{expert_prefix}.gate_proj.lora_A.weight"] = _clone(gate_up_a[rows])
        expanded[f"{expert_prefix}.gate_proj.lora_B.weight"] = _clone(gate_b)
        expanded[f"{expert_prefix}.up_proj.lora_A.weight"] = _clone(gate_up_a[rows])
        expanded[f"{expert_prefix}.up_proj.lora_B.weight"] = _clone(up_b)
        expanded[f"{expert_prefix}.down_proj.lora_A.weight"] = _clone(down_a[rows])
        expanded[f"{expert_prefix}.down_proj.lora_B.weight"] = _clone(
            down_b_by_expert[expert]
        )
    return expanded


def _to_vllm_lora_tensors(
    tensors: dict[str, torch.Tensor],
    *,
    adapter_config: dict[str, Any],
) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
    """Rewrite fused Qwen3 MoE LoRA tensors into the per-expert key layout.

    Keys matching ``_QWEN3_FUSED_MOE_KEY_RE`` are grouped by their block
    prefix and expanded with ``_expand_fused_moe_lora``; all other tensors are
    passed through unchanged (same tensor objects).

    Args:
        tensors: LoRA state dict keyed by parameter name.
        adapter_config: Adapter configuration. ``adapter_config["r"]`` is read
            as the LoRA rank, but only when at least one fused block is found.

    Returns:
        ``(tensors, config)``:

        * No fused keys and no per-expert keys: both inputs are returned
          as-is.
        * No fused keys but some key matches ``_QWEN3_EXPERT_MOE_KEY_RE``:
          the input tensors with the config from ``_qwen3_moe_config``.
        * Otherwise: a new dict holding the expanded per-expert tensors plus
          the pass-through tensors, with the config from
          ``_qwen3_moe_config``.

    Raises:
        RuntimeError: If a fused block is incomplete or malformed, or if a
            pass-through key collides with a key produced by the expansion.
        KeyError: If fused blocks are present but ``adapter_config`` has no
            ``"r"`` entry.
    """
    grouped: dict[str, dict[str, torch.Tensor]] = {}
    fused_keys: set[str] = set()
    for key, tensor in tensors.items():
        match = _QWEN3_FUSED_MOE_KEY_RE.match(key)
        if match is None:
            continue
        slot = match.group("lora")
        if match.group("base_layer"):
            slot = f"base_layer.{slot}"
        grouped.setdefault(match.group("prefix"), {})[slot] = tensor
        # Remember exactly which input keys were consumed by the expansion.
        fused_keys.add(key)

    if not grouped:
        # Already per-expert (or not MoE at all): tensors need no rewriting.
        if any(_QWEN3_EXPERT_MOE_KEY_RE.match(key) for key in tensors):
            return tensors, _qwen3_moe_config(adapter_config)
        return tensors, adapter_config

    rank = int(adapter_config["r"])
    transformed: dict[str, torch.Tensor] = {}
    for prefix, slots in grouped.items():
        transformed.update(_expand_fused_moe_lora(prefix, slots, rank=rank))

    for key, tensor in tensors.items():
        if key in fused_keys:
            continue
        if key in transformed:
            raise RuntimeError(f"Duplicate Qwen3 LoRA tensor after conversion: {key}")
        transformed[key] = tensor
    return transformed, _qwen3_moe_config(adapter_config)
b/src/art/megatron/model_support/registry.py @@ -21,6 +21,8 @@ "down_proj", ) +_QWEN3_MOE_TARGET_MODULES = (*_DENSE_TARGET_MODULES, "experts") + _QWEN3_5_DENSE_TARGET_MODULES = ( "q_proj", "k_proj", @@ -61,7 +63,7 @@ "Qwen/Qwen3-30B-A3B-Instruct-2507", "Qwen/Qwen3-235B-A22B-Instruct-2507", ), - default_target_modules=_DENSE_TARGET_MODULES, + default_target_modules=_QWEN3_MOE_TARGET_MODULES, native_vllm_lora_status=QWEN3_MOE_HANDLER.native_vllm_lora_status, ) diff --git a/tests/integration/megatron/lora/test_lora_disk_codecs.py b/tests/integration/megatron/lora/test_lora_disk_codecs.py index aea01d7dc..05fe4457e 100644 --- a/tests/integration/megatron/lora/test_lora_disk_codecs.py +++ b/tests/integration/megatron/lora/test_lora_disk_codecs.py @@ -11,6 +11,7 @@ QWEN3_5_MOE_HANDLER, QWEN3_MOE_HANDLER, ) +from art.megatron.model_support.lora_disk import normalize_lora_checkpoint_to_vllm from art.megatron.weights.merge import load_lora_adapter_state_dict, merge_lora_adapter from art.utils.convert_moe_lora import convert_checkpoint_if_needed @@ -204,6 +205,65 @@ def _qwen3_moe_lora_tensors(prefix: str, *, rank: int = 2) -> dict[str, torch.Te return tensors +def _pack_lora_b_by_expert(blocks: list[torch.Tensor]) -> torch.Tensor: + stacked = torch.stack(blocks, dim=0) + return stacked.permute(1, 2, 0).reshape(stacked.shape[1], -1).contiguous() + + +def _qwen3_fused_moe_fixture( + prefix: str, + *, + rank: int = 2, +) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]: + hidden = 3 + intermediate = 4 + num_experts = 2 + gate_up_a = torch.arange( + num_experts * rank * hidden, + dtype=torch.float32, + ).reshape(num_experts * rank, hidden) + down_a = ( + torch.arange( + num_experts * rank * intermediate, + dtype=torch.float32, + ).reshape(num_experts * rank, intermediate) + + 100 + ) + gate_up_b_blocks = [ + torch.arange( + 2 * intermediate * rank, + dtype=torch.float32, + ).reshape(2 * intermediate, rank) + + 200 + + expert * 100 + for expert in range(num_experts) + 
] + down_b_blocks = [ + torch.arange(hidden * rank, dtype=torch.float32).reshape(hidden, rank) + + 500 + + expert * 100 + for expert in range(num_experts) + ] + fused = { + f"{prefix}.base_layer.lora_A.weight": gate_up_a, + f"{prefix}.base_layer.lora_B.weight": _pack_lora_b_by_expert(gate_up_b_blocks), + f"{prefix}.lora_A.weight": down_a, + f"{prefix}.lora_B.weight": _pack_lora_b_by_expert(down_b_blocks), + } + expected: dict[str, torch.Tensor] = {} + for expert in range(num_experts): + rows = slice(expert * rank, (expert + 1) * rank) + gate_b, up_b = gate_up_b_blocks[expert].split(intermediate, dim=0) + expert_prefix = f"{prefix}.{expert}" + expected[f"{expert_prefix}.gate_proj.lora_A.weight"] = gate_up_a[rows].clone() + expected[f"{expert_prefix}.gate_proj.lora_B.weight"] = gate_b + expected[f"{expert_prefix}.up_proj.lora_A.weight"] = gate_up_a[rows].clone() + expected[f"{expert_prefix}.up_proj.lora_B.weight"] = up_b + expected[f"{expert_prefix}.down_proj.lora_A.weight"] = down_a[rows].clone() + expected[f"{expert_prefix}.down_proj.lora_B.weight"] = down_b_blocks[expert] + return fused, expected + + def test_peft_fused_moe_checkpoint_converts_to_vllm_3d_layout(tmp_path: Path) -> None: prefix = "base_model.model.model.layers.0.mlp.experts" peft_tensors = { @@ -266,6 +326,67 @@ def test_peft_fused_moe_checkpoint_converts_to_vllm_3d_layout(tmp_path: Path) -> assert "target_parameters" not in adapter_config +def test_qwen3_fused_identity_normalizes_to_per_expert_vllm_layout( + tmp_path: Path, +) -> None: + prefix = "base_model.model.model.layers.0.mlp.experts" + rank = 2 + fused, expected = _qwen3_fused_moe_fixture(prefix, rank=rank) + _save_adapter( + tmp_path, + { + f"{prefix}.base_layer.lora_A.weight": fused[ + f"{prefix}.base_layer.lora_B.weight" + ].T.contiguous(), + f"{prefix}.base_layer.lora_B.weight": fused[ + f"{prefix}.base_layer.lora_A.weight" + ].T.contiguous(), + f"{prefix}.lora_A.weight": fused[f"{prefix}.lora_B.weight"].T.contiguous(), + 
f"{prefix}.lora_B.weight": fused[f"{prefix}.lora_A.weight"].T.contiguous(), + }, + { + "r": rank, + "lora_alpha": 4, + "target_modules": ["q_proj"], + "target_parameters": [ + "model.layers.0.mlp.experts.gate_up_proj", + "model.layers.0.mlp.experts.down_proj", + ], + }, + ) + + convert_checkpoint_if_needed(str(tmp_path)) + normalize_lora_checkpoint_to_vllm( + tmp_path, + handler=QWEN3_MOE_HANDLER, + adapter_config=_config("Qwen/Qwen3-30B-A3B", rank=rank), + ) + + converted = load_file(tmp_path / "adapter_model.safetensors") + _assert_tensors_equal(converted, expected) + adapter_config = json.loads((tmp_path / "adapter_config.json").read_text()) + assert "experts" in adapter_config["target_modules"] + loaded_modules = _assert_stock_vllm_loads( + tmp_path, + expected_modules={ + "experts.0.gate_proj", + "experts.0.up_proj", + "experts.0.down_proj", + "experts.1.gate_proj", + "experts.1.up_proj", + "experts.1.down_proj", + }, + ) + assert loaded_modules == [ + "model.layers.0.mlp.experts.0.down_proj", + "model.layers.0.mlp.experts.0.gate_proj", + "model.layers.0.mlp.experts.0.up_proj", + "model.layers.0.mlp.experts.1.down_proj", + "model.layers.0.mlp.experts.1.gate_proj", + "model.layers.0.mlp.experts.1.up_proj", + ] + + def test_qwen35_and_qwen36_vllm_canonical_roundtrip_and_stock_loader(tmp_path: Path): art_prefix = "base_model.model.model.layers.0" original = _qwen35_moe_art_tensors(art_prefix)