Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions src/art/megatron/model_support/handlers/qwen3_moe.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import re
from typing import Any, Sequence

import torch

from art.megatron.model_support.handlers.default_dense import DefaultMoeHandler
from art.megatron.model_support.handlers.qwen3_common import (
install_qwen3_text_preprocess_patch,
Expand All @@ -11,12 +14,28 @@
"alltoall_dispatch_preprocess",
"deepep_permute_restore",
)
# Matches fused expert LoRA keys of the form
# "<prefix>.mlp.experts[.base_layer].lora_{A,B}.weight". The named groups
# capture the key prefix (up to and including ".mlp.experts"), the optional
# "base_layer" segment, and which LoRA matrix ("lora_A"/"lora_B") the key holds.
_QWEN3_FUSED_MOE_KEY_RE = re.compile(
r"^(?P<prefix>.*\.mlp\.experts)\."
r"(?:(?P<base_layer>base_layer)\.)?(?P<lora>lora_[AB])\.weight$"
)
# Matches per-expert LoRA keys of the form
# "....mlp.experts.<index>.{gate,up,down}_proj.lora_{A,B}.weight".
_QWEN3_EXPERT_MOE_KEY_RE = re.compile(
r"^.*\.mlp\.experts\.\d+\."
r"(?:gate_proj|up_proj|down_proj)\.lora_[AB]\.weight$"
)


class Qwen3MoeHandler(DefaultMoeHandler):
key = "qwen3_moe"
native_vllm_lora_status = "validated"

def to_vllm_lora_tensors(
    self,
    tensors: dict[str, torch.Tensor],
    *,
    adapter_config: dict[str, Any],
) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
    """Convert LoRA tensors and their adapter config for vLLM.

    Delegates to the module-level ``_to_vllm_lora_tensors`` and returns its
    ``(tensors, adapter_config)`` pair unchanged.
    """
    converted = _to_vllm_lora_tensors(tensors, adapter_config=adapter_config)
    return converted

def install_preprocess_patch(self, model_chunks: Sequence[Any]) -> None:
    """Install the Qwen3 text preprocess patch on ``model_chunks``.

    Delegates to ``install_qwen3_text_preprocess_patch``; its return value
    is discarded.
    """
    install_qwen3_text_preprocess_patch(model_chunks)
    return None

Expand All @@ -29,3 +48,142 @@ def compile_workaround_config(


# Module-level handler instance for the "qwen3_moe" key.
QWEN3_MOE_HANDLER = Qwen3MoeHandler()


def _qwen3_moe_config(adapter_config: dict[str, Any]) -> dict[str, Any]:
config = dict(adapter_config)
target_modules = list(config.get("target_modules") or [])
if "experts" not in target_modules:
target_modules.append("experts")
config["target_modules"] = target_modules
return config


def _packed_lora_b_by_expert(
tensor: torch.Tensor,
*,
num_experts: int,
rank: int,
) -> torch.Tensor:
return tensor.reshape(tensor.shape[0], rank, num_experts).permute(2, 0, 1)


def _clone(tensor: torch.Tensor) -> torch.Tensor:
return tensor.clone().contiguous()


def _expand_fused_moe_lora(
prefix: str,
slots: dict[str, torch.Tensor],
*,
rank: int,
) -> dict[str, torch.Tensor]:
try:
gate_up_a = slots["base_layer.lora_A"]
gate_up_b = slots["base_layer.lora_B"]
down_a = slots["lora_A"]
down_b = slots["lora_B"]
except KeyError as exc:
raise RuntimeError(f"Incomplete Qwen3 MoE LoRA block for {prefix}") from exc

if (
gate_up_a.ndim != 2
or gate_up_b.ndim != 2
or down_a.ndim != 2
or down_b.ndim != 2
):
raise RuntimeError(f"Qwen3 MoE LoRA tensors for {prefix} must be 2D")
if gate_up_a.shape[0] % rank != 0:
raise RuntimeError(
f"{prefix}: gate/up lora_A shape {tuple(gate_up_a.shape)} "
f"is not divisible by rank {rank}"
)
if gate_up_b.shape[0] % 2 != 0:
raise RuntimeError(
f"{prefix}: gate/up lora_B rows {gate_up_b.shape[0]} are not even"
)
num_experts = gate_up_a.shape[0] // rank
expected_rank_cols = num_experts * rank
intermediate = gate_up_b.shape[0] // 2
if gate_up_b.shape[1] != expected_rank_cols:
raise RuntimeError(
f"{prefix}: gate/up lora_B shape {tuple(gate_up_b.shape)} does not "
f"match {num_experts} experts at rank {rank}"
)
if down_a.shape != (expected_rank_cols, intermediate):
raise RuntimeError(
f"{prefix}: down lora_A shape {tuple(down_a.shape)} does not match "
f"expected {(expected_rank_cols, intermediate)}"
)
if down_b.shape[1] != expected_rank_cols:
raise RuntimeError(
f"{prefix}: down lora_B shape {tuple(down_b.shape)} does not match "
f"{num_experts} experts at rank {rank}"
)

gate_up_b_by_expert = _packed_lora_b_by_expert(
gate_up_b,
num_experts=num_experts,
rank=rank,
)
down_b_by_expert = _packed_lora_b_by_expert(
down_b,
num_experts=num_experts,
rank=rank,
)
expanded: dict[str, torch.Tensor] = {}
for expert in range(num_experts):
rows = slice(expert * rank, (expert + 1) * rank)
gate_b, up_b = gate_up_b_by_expert[expert].split(intermediate, dim=0)
expert_prefix = f"{prefix}.{expert}"
expanded[f"{expert_prefix}.gate_proj.lora_A.weight"] = _clone(gate_up_a[rows])
expanded[f"{expert_prefix}.gate_proj.lora_B.weight"] = _clone(gate_b)
expanded[f"{expert_prefix}.up_proj.lora_A.weight"] = _clone(gate_up_a[rows])
expanded[f"{expert_prefix}.up_proj.lora_B.weight"] = _clone(up_b)
expanded[f"{expert_prefix}.down_proj.lora_A.weight"] = _clone(down_a[rows])
expanded[f"{expert_prefix}.down_proj.lora_B.weight"] = _clone(
down_b_by_expert[expert]
)
return expanded


def _to_vllm_lora_tensors(
    tensors: dict[str, torch.Tensor],
    *,
    adapter_config: dict[str, Any],
) -> tuple[dict[str, torch.Tensor], dict[str, Any]]:
    """Expand fused Qwen3 MoE LoRA tensors into per-expert tensors.

    Keys matching the fused pattern are grouped by prefix and each group is
    expanded with ``_expand_fused_moe_lora`` using rank ``adapter_config["r"]``.
    All other tensors are passed through untouched.

    When no fused keys are present the input ``tensors`` mapping is returned
    as-is; the config is only rewritten (to include "experts") if per-expert
    keys are already present.

    Raises:
        RuntimeError: If a pass-through key collides with an expanded key.
    """
    fused_blocks: dict[str, dict[str, torch.Tensor]] = {}
    for name, value in tensors.items():
        found = _QWEN3_FUSED_MOE_KEY_RE.match(name)
        if found is None:
            continue
        slot_name = found.group("lora")
        if found.group("base_layer"):
            slot_name = f"base_layer.{slot_name}"
        fused_blocks.setdefault(found.group("prefix"), {})[slot_name] = value

    if not fused_blocks:
        already_per_expert = any(
            _QWEN3_EXPERT_MOE_KEY_RE.match(name) for name in tensors
        )
        config = (
            _qwen3_moe_config(adapter_config) if already_per_expert else adapter_config
        )
        return tensors, config

    rank = int(adapter_config["r"])
    converted: dict[str, torch.Tensor] = {}
    consumed: set[str] = set()
    fused_suffixes = (
        "base_layer.lora_A.weight",
        "base_layer.lora_B.weight",
        "lora_A.weight",
        "lora_B.weight",
    )
    for block_prefix, block_slots in fused_blocks.items():
        converted.update(_expand_fused_moe_lora(block_prefix, block_slots, rank=rank))
        consumed.update(f"{block_prefix}.{suffix}" for suffix in fused_suffixes)

    # Carry over everything that was not part of a fused block.
    for name, value in tensors.items():
        if name in consumed:
            continue
        if name in converted:
            raise RuntimeError(f"Duplicate Qwen3 LoRA tensor after conversion: {name}")
        converted[name] = value
    return converted, _qwen3_moe_config(adapter_config)
4 changes: 3 additions & 1 deletion src/art/megatron/model_support/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
"down_proj",
)

# Dense LoRA target modules extended with "experts".
_QWEN3_MOE_TARGET_MODULES = (*_DENSE_TARGET_MODULES, "experts")

_QWEN3_5_DENSE_TARGET_MODULES = (
"q_proj",
"k_proj",
Expand Down Expand Up @@ -61,7 +63,7 @@
"Qwen/Qwen3-30B-A3B-Instruct-2507",
"Qwen/Qwen3-235B-A22B-Instruct-2507",
),
default_target_modules=_DENSE_TARGET_MODULES,
default_target_modules=_QWEN3_MOE_TARGET_MODULES,
native_vllm_lora_status=QWEN3_MOE_HANDLER.native_vllm_lora_status,
)

Expand Down
121 changes: 121 additions & 0 deletions tests/integration/megatron/lora/test_lora_disk_codecs.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
QWEN3_5_MOE_HANDLER,
QWEN3_MOE_HANDLER,
)
from art.megatron.model_support.lora_disk import normalize_lora_checkpoint_to_vllm
from art.megatron.weights.merge import load_lora_adapter_state_dict, merge_lora_adapter
from art.utils.convert_moe_lora import convert_checkpoint_if_needed

Expand Down Expand Up @@ -204,6 +205,65 @@ def _qwen3_moe_lora_tensors(prefix: str, *, rank: int = 2) -> dict[str, torch.Te
return tensors


def _pack_lora_b_by_expert(blocks: list[torch.Tensor]) -> torch.Tensor:
stacked = torch.stack(blocks, dim=0)
return stacked.permute(1, 2, 0).reshape(stacked.shape[1], -1).contiguous()


def _qwen3_fused_moe_fixture(
    prefix: str,
    *,
    rank: int = 2,
) -> tuple[dict[str, torch.Tensor], dict[str, torch.Tensor]]:
    """Build a tiny fused MoE LoRA block and its expected per-expert expansion.

    Uses 2 experts, hidden size 3 and intermediate size 4. Every tensor is a
    float32 ramp with a distinct offset so that misplaced values are easy to
    spot. Returns ``(fused, expected)`` keyed under ``prefix``.
    """
    hidden, intermediate, num_experts = 3, 4, 2
    packed_rows = num_experts * rank

    def ramp(rows: int, cols: int) -> torch.Tensor:
        return torch.arange(rows * cols, dtype=torch.float32).reshape(rows, cols)

    gate_up_a = ramp(packed_rows, hidden)
    down_a = ramp(packed_rows, intermediate) + 100
    gate_up_b_blocks = [
        ramp(2 * intermediate, rank) + 200 + expert * 100
        for expert in range(num_experts)
    ]
    down_b_blocks = [
        ramp(hidden, rank) + 500 + expert * 100 for expert in range(num_experts)
    ]

    fused = {
        f"{prefix}.base_layer.lora_A.weight": gate_up_a,
        f"{prefix}.base_layer.lora_B.weight": _pack_lora_b_by_expert(gate_up_b_blocks),
        f"{prefix}.lora_A.weight": down_a,
        f"{prefix}.lora_B.weight": _pack_lora_b_by_expert(down_b_blocks),
    }

    expected: dict[str, torch.Tensor] = {}
    for expert, (gate_up_b, down_b) in enumerate(
        zip(gate_up_b_blocks, down_b_blocks)
    ):
        start = expert * rank
        a_rows = slice(start, start + rank)
        # First half of the rows is the gate part, second half the up part.
        gate_b, up_b = gate_up_b.split(intermediate, dim=0)
        base = f"{prefix}.{expert}"
        expected.update(
            {
                f"{base}.gate_proj.lora_A.weight": gate_up_a[a_rows].clone(),
                f"{base}.gate_proj.lora_B.weight": gate_b,
                f"{base}.up_proj.lora_A.weight": gate_up_a[a_rows].clone(),
                f"{base}.up_proj.lora_B.weight": up_b,
                f"{base}.down_proj.lora_A.weight": down_a[a_rows].clone(),
                f"{base}.down_proj.lora_B.weight": down_b,
            }
        )
    return fused, expected


def test_peft_fused_moe_checkpoint_converts_to_vllm_3d_layout(tmp_path: Path) -> None:
prefix = "base_model.model.model.layers.0.mlp.experts"
peft_tensors = {
Expand Down Expand Up @@ -266,6 +326,67 @@ def test_peft_fused_moe_checkpoint_converts_to_vllm_3d_layout(tmp_path: Path) ->
assert "target_parameters" not in adapter_config


def test_qwen3_fused_identity_normalizes_to_per_expert_vllm_layout(
    tmp_path: Path,
) -> None:
    """Check that a saved fused adapter normalizes to per-expert tensors.

    Saves an adapter built from the fused fixture, runs
    ``convert_checkpoint_if_needed`` followed by
    ``normalize_lora_checkpoint_to_vllm``, then compares the tensors, the
    rewritten adapter config and the modules reported by the stock vLLM loader
    check.
    """
    prefix = "base_model.model.model.layers.0.mlp.experts"
    rank = 2
    fused, expected = _qwen3_fused_moe_fixture(prefix, rank=rank)

    # Each saved slot holds the transpose of the opposite (A <-> B) fused slot.
    source_slot_for = {
        "base_layer.lora_A": "base_layer.lora_B",
        "base_layer.lora_B": "base_layer.lora_A",
        "lora_A": "lora_B",
        "lora_B": "lora_A",
    }
    saved_tensors = {
        f"{prefix}.{dest}.weight": fused[f"{prefix}.{src}.weight"].T.contiguous()
        for dest, src in source_slot_for.items()
    }
    saved_config = {
        "r": rank,
        "lora_alpha": 4,
        "target_modules": ["q_proj"],
        "target_parameters": [
            "model.layers.0.mlp.experts.gate_up_proj",
            "model.layers.0.mlp.experts.down_proj",
        ],
    }
    _save_adapter(tmp_path, saved_tensors, saved_config)

    convert_checkpoint_if_needed(str(tmp_path))
    normalize_lora_checkpoint_to_vllm(
        tmp_path,
        handler=QWEN3_MOE_HANDLER,
        adapter_config=_config("Qwen/Qwen3-30B-A3B", rank=rank),
    )

    converted = load_file(tmp_path / "adapter_model.safetensors")
    _assert_tensors_equal(converted, expected)

    config_text = (tmp_path / "adapter_config.json").read_text()
    adapter_config = json.loads(config_text)
    assert "experts" in adapter_config["target_modules"]

    experts = (0, 1)
    loaded_modules = _assert_stock_vllm_loads(
        tmp_path,
        expected_modules={
            f"experts.{expert}.{proj}"
            for expert in experts
            for proj in ("gate_proj", "up_proj", "down_proj")
        },
    )
    # The expected list is ordered by expert, then by projection name.
    assert loaded_modules == [
        f"model.layers.0.mlp.experts.{expert}.{proj}"
        for expert in experts
        for proj in ("down_proj", "gate_proj", "up_proj")
    ]


def test_qwen35_and_qwen36_vllm_canonical_roundtrip_and_stock_loader(tmp_path: Path):
art_prefix = "base_model.model.model.layers.0"
original = _qwen35_moe_art_tensors(art_prefix)
Expand Down
Loading