
X-Square-Robot/dmuon


DMuon

Drop-in distributed Muon optimizer at near-AdamW cost



DMuon is a high-performance distributed implementation of the Muon optimizer that drops into an existing training pipeline with just 3 lines of code. Through fine-grained kernel tuning, load-balanced work scheduling, and a redesigned distributed communication path, DMuon delivers near-AdamW step time while keeping Muon's optimization benefits. It is fully plug-and-play and requires no changes to your model.

Install

git clone git@github.com:X-Square-Robot/dmuon.git && cd dmuon
pip install -e .

3-Line Integration

import dmuon  # importing dmuon auto-patches FSDP2
from torch.distributed.fsdp import fully_shard  # standard FSDP2 API, already used by FSDP2 code

# model and dp_mesh come from your existing FSDP2 setup.
# Mark which params get dedicated ownership (auto-balanced across ranks)
dmuon.dedicate_params(model, dp_mesh, predicate=lambda n, p: "proj" in n and p.ndim == 2)

# Use FSDP2 as usual; dedicated params are handled automatically
for layer in model.layers:
    fully_shard(layer, mesh=dp_mesh)
fully_shard(model, mesh=dp_mesh)

Hooks installed by DMuon handle the rest: the forward broadcast of dedicated parameters, the backward gradient reduction, and running the optimizer only on each parameter's owner rank.
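The README says dedicated parameters are auto-balanced across ranks but does not describe the balancing policy. Each owner rank runs the full Muon update for its parameters alone, so a reasonable goal is to even out the per-rank Newton-Schulz work. The following is a minimal sketch of one standard way to do that, greedy longest-processing-time scheduling. The cost model `ns_cost` and the helpers `assign_owners` and `rank_loads` are hypothetical and are not part of DMuon's API.

```python
import heapq
from typing import Dict, List, Tuple


def ns_cost(shape: Tuple[int, int]) -> int:
    """Hypothetical cost model: FLOPs of one X @ X.T product, which dominates
    a Newton-Schulz step on an (m, n) matrix with m <= n."""
    m, n = sorted(shape)
    return m * m * n


def assign_owners(shapes: Dict[str, Tuple[int, int]], world_size: int) -> Dict[str, int]:
    """Greedy longest-processing-time assignment: visit parameters from most to
    least expensive and give each one to the currently least-loaded rank."""
    heap = [(0, rank) for rank in range(world_size)]  # (accumulated cost, rank)
    heapq.heapify(heap)
    owners: Dict[str, int] = {}
    # sorted() is stable, so equal-cost params keep their original order.
    for name in sorted(shapes, key=lambda k: ns_cost(shapes[k]), reverse=True):
        load, rank = heapq.heappop(heap)  # least-loaded rank (lowest id on ties)
        owners[name] = rank
        heapq.heappush(heap, (load + ns_cost(shapes[name]), rank))
    return owners


def rank_loads(shapes: Dict[str, Tuple[int, int]], owners: Dict[str, int],
               world_size: int) -> List[int]:
    """Total estimated cost owned by each rank."""
    loads = [0] * world_size
    for name, rank in owners.items():
        loads[rank] += ns_cost(shapes[name])
    return loads


if __name__ == "__main__":
    # Hypothetical projection shapes for one transformer block.
    shapes = {
        "q_proj": (1024, 1024),
        "o_proj": (1024, 1024),
        "up_proj": (4096, 1024),
        "down_proj": (1024, 4096),
    }
    owners = assign_owners(shapes, world_size=2)
    print(owners)
    print(rank_loads(shapes, owners, world_size=2))
```

With these shapes and two ranks, each rank owns one MLP matrix and one attention matrix, so the estimated loads come out equal. Placing the largest matrices first matters because it leaves the small ones to fill the remaining gaps.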

Full Training Example

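import os

import torch
from torch.distributed.fsdp import fully_shard
from torch.distributed.device_mesh import init_device_mesh
import dmuon

# Setup (launch with torchrun, which sets LOCAL_RANK and WORLD_SIZE)
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))  # creates the default process group if needed
model = MyModel().cuda()  # your model; this example assumes it exposes .layers

# Apply DMuon + FSDP2
dmuon.dedicate_params(model, mesh, predicate=lambda n, p: "proj" in n and p.ndim == 2)
for layer in model.layers:
    fully_shard(layer, mesh=mesh)
fully_shard(model, mesh=mesh)

# Muon for dedicated matrix params, AdamW for the rest (selected automatically)
optimizer = dmuon.Muon(model, lr=0.02, ns_steps=5, adamw_lr=1e-3)

# Training loop (dataloader is your own; model(batch) is assumed to return an object with .loss)
for batch in dataloader:
    optimizer.zero_grad()
    loss = model(batch).loss
    loss.backward()
    optimizer.step()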
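`dmuon.Muon` in the example above hides the per-matrix update rule. For orientation only, here is a NumPy sketch of a single Muon step on one 2-D weight. It assumes the quintic Newton-Schulz coefficients from Keller Jordan's reference Muon, Nesterov-style momentum, and the Moonlight-style adjustments mentioned under Acknowledgments: the orthogonalized update is scaled by 0.2·sqrt(max(m, n)), and weight decay is decoupled. The helper names and the `weight_decay` default are our assumptions. This is not DMuon's implementation or its tuned kernels.

```python
import numpy as np

# Quintic Newton-Schulz coefficients from Keller Jordan's reference Muon (assumed).
NS_A, NS_B, NS_C = 3.4445, -4.7750, 2.0315


def newton_schulz(g: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Approximately orthogonalize g: push its singular values into a band near 1."""
    x = g / (np.linalg.norm(g) + eps)  # Frobenius normalization bounds the spectral norm by 1
    transposed = x.shape[0] > x.shape[1]
    if transposed:
        x = x.T  # work with a wide matrix so x @ x.T is the smaller Gram matrix
    for _ in range(steps):
        gram = x @ x.T
        x = NS_A * x + (NS_B * gram + NS_C * gram @ gram) @ x
    return x.T if transposed else x


def muon_step(w, grad, buf, lr=0.02, momentum=0.95, weight_decay=0.1, ns_steps=5):
    """Apply one Muon update to the 2-D weight w in place and return the scaled update.

    buf is the momentum buffer (updated in place). Nesterov momentum, the
    0.2 * sqrt(max(m, n)) scale and decoupled weight decay are assumptions
    following the reference Muon and Moonlight write-ups.
    """
    buf *= momentum
    buf += grad                            # classical momentum: buf = mu * buf + grad
    nesterov = grad + momentum * buf       # Nesterov look-ahead direction
    update = newton_schulz(nesterov, steps=ns_steps)
    update *= 0.2 * np.sqrt(max(w.shape))  # bring the update RMS to roughly 0.2
    w *= 1.0 - lr * weight_decay           # decoupled weight decay
    w -= lr * update
    return update


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    w = rng.standard_normal((32, 64))
    buf = np.zeros_like(w)
    upd = muon_step(w, rng.standard_normal((32, 64)), buf)
    print("update RMS:", np.sqrt(np.mean(upd ** 2)))
```

Here `ns_steps` presumably plays the same role as the `ns_steps=5` argument in the training example. An orthogonal m × n matrix has RMS 1/sqrt(max(m, n)), so the 0.2·sqrt(max(m, n)) factor puts the update's RMS near 0.2. Moonlight's stated rationale for this factor is to match AdamW's typical update size so that AdamW-like hyperparameters carry over.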

For multi-node training, build a 2D (replicate, shard) device mesh, here named hsdp, and pass it to dedicate_params(..., replicate_mesh=...) and fully_shard(..., mesh=hsdp); everything else is identical.
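It helps to picture how ranks are arranged in the 2D mesh. `init_device_mesh("cuda", (num_nodes, gpus_per_node), mesh_dim_names=("replicate", "shard"))` lays rank ids out row-major. With torchrun's usual consecutive ranks per node, each row (one node) therefore forms a shard group and each column forms a replicate group. The pure-Python sketch below reproduces that layout for a hypothetical job with 2 nodes and 4 GPUs per node. It does not call DMuon, and the helper `hsdp_groups` is illustrative only.

```python
from typing import List, Tuple


def hsdp_groups(num_nodes: int, gpus_per_node: int) -> Tuple[List[List[int]], List[List[int]]]:
    """Return (shard_groups, replicate_groups) for a (replicate, shard) mesh.

    Mirrors init_device_mesh's layout: rank ids arranged row-major in a
    num_nodes x gpus_per_node grid.
    """
    mesh = [[node * gpus_per_node + local for local in range(gpus_per_node)]
            for node in range(num_nodes)]
    shard_groups = [row[:] for row in mesh]               # rows: ranks on the same node
    replicate_groups = [list(col) for col in zip(*mesh)]  # columns: same local index on every node
    return shard_groups, replicate_groups


if __name__ == "__main__":
    shard, replicate = hsdp_groups(num_nodes=2, gpus_per_node=4)
    print("shard groups:", shard)
    print("replicate groups:", replicate)
```

In standard FSDP2 HSDP, parameters are sharded within each shard group over fast intra-node links. Gradients are then all-reduced across each replicate group.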

Benchmark

Per-step time measured on a 16-node cluster. With DMuon running Muon on the matrix parameters, step time stays within 3.4% of AdamW on all three models.

Model      AdamW step   DMuon step   Δ vs AdamW
WallX      1259 ms      1285 ms      +2.1%
Pi0        1617 ms      1645 ms      +1.7%
Wall-WM    3309 ms      3424 ms      +3.4%
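We read the Δ column as the relative increase in per-step time over AdamW. For WallX, for example:

```latex
\Delta = \frac{t_{\mathrm{DMuon}} - t_{\mathrm{AdamW}}}{t_{\mathrm{AdamW}}} \times 100\%,
\qquad
\Delta_{\mathrm{WallX}} = \frac{1285 - 1259}{1259} \times 100\% \approx 2.1\%
```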

Acknowledgments

DMuon builds upon the ideas and engineering of several excellent prior works:

  • Muon by Keller Jordan et al. — the original Muon optimizer, which orthogonalizes momentum updates via Newton-Schulz iteration.
  • Moonlight (Muon is Scalable for LLM Training) by the Moonshot AI (Kimi) team — which demonstrated Muon's scalability to large-scale LLM training and introduced the weight-decay and update-scale adjustments that make it practical as a drop-in AdamW replacement. Our distributed design is heavily inspired by their ZeRO-1 Muon implementation.
  • Gram Newton-Schulz (blog post) by Jack Zhang, Noah Amsel, Berlin Chen, and Tri Dao — a hardware-aware reformulation of Newton-Schulz that iterates on the Gram matrix and exploits its symmetry with dedicated CuTeDSL GEMM kernels. Our SYRK-based kernel design follows this line of work.

We thank the authors of these projects for open-sourcing their code and insights.
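The Gram Newton-Schulz entry above summarizes the idea in one line. The algebra behind it works as follows. In the Newton-Schulz step X ← p(X Xᵀ) X with p(S) = aI + bS + cS², every factor p(·) is a polynomial in the same symmetric Gram matrix. One can therefore iterate on G = X Xᵀ and an accumulated left factor Q (with X = Q X₀), and apply Q to X₀ once at the end. When X₀ is m × n with m ≪ n, the loop then multiplies only m × m matrices. The NumPy sketch below shows this identity in exact-arithmetic form. It is our own illustration, not the authors' CuTeDSL kernels or DMuon's SYRK code, and it borrows the quintic coefficients from the reference Muon.

```python
import numpy as np

# Quintic coefficients borrowed from the reference Muon implementation (assumed).
GNS_A, GNS_B, GNS_C = 3.4445, -4.7750, 2.0315


def gram_newton_schulz(g: np.ndarray, steps: int = 5, eps: float = 1e-7) -> np.ndarray:
    """Newton-Schulz orthogonalization driven by the Gram matrix.

    Direct form: X <- p(X X^T) X with p(S) = A*I + B*S + C*S^2.
    Gram form:   G <- p(G) G p(G),  Q <- p(G) Q,  and finally X = Q X0.
    Both give the same X in exact arithmetic because every p(G) is a
    symmetric polynomial in the same matrix, so all factors commute.
    """
    x0 = g / (np.linalg.norm(g) + eps)  # Frobenius normalization: singular values <= 1
    transposed = x0.shape[0] > x0.shape[1]
    if transposed:
        x0 = x0.T  # make X0 wide so G is the smaller m x m matrix
    m = x0.shape[0]
    eye = np.eye(m)
    gram = x0 @ x0.T  # symmetric m x m product (a SYRK in BLAS terms)
    q = np.eye(m)
    for _ in range(steps):
        p = GNS_A * eye + GNS_B * gram + GNS_C * gram @ gram  # symmetric, commutes with gram
        gram = p @ gram @ p  # new X X^T, computed from the old gram
        q = p @ q            # accumulate the left factor
    x = q @ x0  # the only product involving the long dimension n
    return x.T if transposed else x


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    out = gram_newton_schulz(rng.standard_normal((32, 64)))
    print(np.linalg.svd(out, compute_uv=False).round(2))
```

The printed singular values should fall in a band around 1 rather than exactly at 1. This is the usual behaviour of the quintic iteration, which trades exact orthogonality for fewer steps.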
