MatrixFSDP is an experimental FSDP2-style runtime for training large PyTorch
models with matrix-aware sharding. It keeps the public API close to PyTorch
FSDP2: choose a DP or HSDP DeviceMesh, call fully_shard(...), then train
with the normal PyTorch loop.
The main difference is the layout policy. MatrixFSDP can keep selected matrices whole on owner ranks for Muon-style optimizers, split dense weights by rows or blocks, and leave TP/EP-owned parameters to the upper-level parallelism stack. This makes the same runtime usable for AdamW baselines, matrix-owner Muon runs, and MoE models where FSDP should manage only the DP/HSDP parameter set.
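
    # Shard the model first; mesh and dp_mesh_dims describe the DP or HSDP layout.
    model = fully_shard(model, mesh=mesh, dp_mesh_dims=dp_mesh_dims)
    # Build the optimizer from the already-sharded model.
    optim = configure_optimizer(model, "adamw", lr=3e-4)

    # One step of the normal PyTorch training loop.
    loss.backward()
    optim.step()
    optim.zero_grad(set_to_none=True)

Create optimizers after sharding so they see the sharded parameter views. For
most users, fully_shard(...) is the only sharding entry point.
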
Measured on 8x A100-SXM4-80GB (NVLink), CUDA 12.8, BF16 params/compute,
transformer-block sharding with activation checkpointing, batch 1 per rank, Muon
optimizer. With optimizer_policy="mixed_muon_adamw" MatrixFSDP keeps each 2D
matrix whole on its owner rank, so Muon's Newton-Schulz runs without a
full-matrix all-gather; forward and backward stay at parity with FSDP2.
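
To see why the owner rank needs the whole matrix, it helps to look at what a Newton-Schulz orthogonalization does. The sketch below is a pure-Python illustration using the classical cubic iteration, chosen for readability; it is not MatrixFSDP's kernel, and the polynomial and step count a given Muon implementation uses may differ. The helper names and the example matrix are hypothetical.

```python
import math


def matmul(a, b):
    """Multiply two matrices stored as lists of rows."""
    cols_b = list(zip(*b))
    return [[sum(x * y for x, y in zip(row, col)) for col in cols_b] for row in a]


def transpose(a):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*a)]


def newton_schulz_orthogonalize(g, steps=30):
    """Approximate the orthogonal polar factor of a full-rank matrix g.

    Uses the classical cubic iteration X <- 1.5 X - 0.5 (X X^T) X after
    scaling g by its Frobenius norm so every singular value is at most 1.
    """
    norm = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / norm for v in row] for row in g]
    for _ in range(steps):
        # Every step multiplies by the full matrix, so all rows must be local.
        xxt_x = matmul(matmul(x, transpose(x)), x)
        x = [[1.5 * a - 0.5 * b for a, b in zip(row_x, row_p)]
             for row_x, row_p in zip(x, xxt_x)]
    return x


# Hypothetical 2 x 3 "gradient" matrix, used only for illustration.
grad = [[2.0, 0.5, -1.0], [0.0, 1.5, 0.3]]
ortho = newton_schulz_orthogonalize(grad)
gram = matmul(ortho, transpose(ortho))  # close to the 2 x 2 identity
```

Because each iteration forms products with the entire matrix, a row-sharded layout would have to gather the shards before the optimizer step, whereas an owner rank that already holds the full matrix can run the loop locally.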
Sequence-length sweep — 32 layers, hidden 4096, intermediate 16384, 32 heads (ms per step):
| seq | FSDP2 total | MatrixFSDP total | total speedup | optimizer step (FSDP2 -> MatrixFSDP) |
|---|---|---|---|---|
| 1024 | 1145 | 537 | 2.13x | 748 -> 175 ms (4.3x) |
| 2048 | 1390 | 791 | 1.76x | 747 -> 175 ms (4.3x) |
| 4096 | 1963 | 1371 | 1.43x | 747 -> 175 ms (4.3x) |
| 8192 | 3435 | 2844 | 1.21x | 752 -> 175 ms (4.3x) |
| 16384 | 7646 | 7057 | 1.08x | 747 -> 176 ms (4.2x) |
The owner-Muon optimizer-step saving is ~4.3x and essentially constant, because it depends on the parameter matrices rather than the sequence length. Forward and backward run at parity and scale with sequence length, so the total-step speedup is largest when the optimizer is a bigger share of the iteration (shorter sequences) and tapers toward parity as forward/backward dominate. In this sweep MatrixFSDP is never slower than FSDP2.
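
The relationship between optimizer share and total speedup can be checked directly from the sweep table. The sketch below recomputes the measured speedup and compares it with a simple estimate that assumes only the optimizer step changes between the two runtimes; that estimate is an illustrative assumption, not part of the benchmark methodology. The numbers are copied from the table; the function and variable names are made up for this example.

```python
# (seq, FSDP2 total, MatrixFSDP total, FSDP2 optimizer, MatrixFSDP optimizer), in ms
ROWS = [
    (1024, 1145, 537, 748, 175),
    (2048, 1390, 791, 747, 175),
    (4096, 1963, 1371, 747, 175),
    (8192, 3435, 2844, 752, 175),
    (16384, 7646, 7057, 747, 176),
]


def summarize(row):
    """Return (seq, optimizer share, estimated speedup, measured speedup)."""
    seq, total_f, total_m, opt_f, opt_m = row
    # Fraction of the FSDP2 step spent in the optimizer.
    share = opt_f / total_f
    # Estimate: swap only the optimizer time, keep everything else unchanged.
    estimate = total_f / (total_f - opt_f + opt_m)
    # Measured: ratio of the two reported totals.
    measured = total_f / total_m
    return seq, share, estimate, measured


results = [summarize(row) for row in ROWS]
for seq, share, estimate, measured in results:
    print(f"seq={seq:>5}  optimizer share={share:.0%}  "
          f"estimate={estimate:.2f}x  measured={measured:.2f}x")
```

The optimizer share of the FSDP2 step shrinks as the sequence grows, and both the estimate and the measured speedup shrink with it, which is the tapering described above.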
Larger model — 40 layers, hidden 5120, intermediate 20480, 40 heads, seq 4096:
| phase | FSDP2 | MatrixFSDP | result |
|---|---|---|---|
| forward | 586 ms | 588 ms | parity |
| backward | 1621 ms | 1590 ms | 1.02x faster |
| optimizer step | 1337 ms | 414 ms | 3.23x faster |
| total / step | 3545 ms | 2593 ms | 1.37x faster |
| peak reserved | 28934 MB | 26610 MB | 8% lower |
The advantage comes from matrix-owner sharding, which avoids the expensive full-matrix optimizer gather; forward/backward match FSDP2 at equal-or-lower peak memory.
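
As a rough illustration of matrix-owner sharding, the sketch below assigns each whole 2D matrix to a single rank with a greedy largest-first heuristic, so that no matrix is split across ranks. This is an assumed, simplified policy for explanation only; MatrixFSDP's actual planner may balance ownership differently. The parameter names and shapes are hypothetical, loosely modeled on one transformer block with hidden size 4096 and intermediate size 16384.

```python
import heapq


def assign_owners(matrix_shapes, world_size):
    """Assign every matrix to exactly one owner rank, largest matrices first.

    Each matrix goes to the rank that currently owns the fewest elements.
    """
    heap = [(0, rank) for rank in range(world_size)]  # (elements owned, rank)
    heapq.heapify(heap)
    owners = {}
    by_size = sorted(matrix_shapes.items(), key=lambda kv: -(kv[1][0] * kv[1][1]))
    for name, (rows, cols) in by_size:
        load, rank = heapq.heappop(heap)
        owners[name] = rank
        heapq.heappush(heap, (load + rows * cols, rank))
    return owners


def rank_loads(matrix_shapes, owners, world_size):
    """Return the number of matrix elements owned by each rank."""
    loads = [0] * world_size
    for name, rank in owners.items():
        rows, cols = matrix_shapes[name]
        loads[rank] += rows * cols
    return loads


# Hypothetical shapes for one transformer block.
shapes = {
    "attn.q": (4096, 4096),
    "attn.k": (4096, 4096),
    "attn.v": (4096, 4096),
    "attn.o": (4096, 4096),
    "mlp.up": (16384, 4096),
    "mlp.down": (4096, 16384),
}
owners = assign_owners(shapes, world_size=4)
loads = rank_loads(shapes, owners, world_size=4)
```

With whole matrices resident on their owners, the optimizer can update each one locally; the trade-off is that per-rank load is only as balanced as the matrix sizes allow.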

    pip install -r requirements.txt
    pip install -e .

- fully_shard(model, mesh=..., dp_mesh_dims=...): public FSDP2-like entry point.
- DataParallelMeshDims(shard="dp_shard"): selects the DeviceMesh dimension used for parameter sharding.
- DataParallelMeshDims(shard="dp_shard", replicate="dp_replicate"): enables HSDP-style replicate x shard layouts.
- optimizer_policy="mixed_muon_adamw": enables Muon-aware matrix-owner planning for fully_shard(...).
- runtime_trace_enabled=True: opt-in lifecycle/collective tracing for debugging; the default training path keeps it disabled.
- configure_optimizer(...): optimizer helper for AdamW, SGD, Muon, and mixed Muon/AdamW training.
- save_matrix_dcp(...) and load_matrix_dcp(...): DCP checkpoint helpers for MatrixFSDP shard metadata and optimizer state.

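
To make the replicate x shard layout concrete, the sketch below lists which ranks would share a shard group and which would share a replicate group. It assumes ranks are laid out row-major in a 2D mesh of shape (replicate, shard), which is how a default PyTorch DeviceMesh orders ranks; the helper itself is hypothetical and not part of MatrixFSDP.

```python
def hsdp_groups(world_size, shard_size):
    """Return (shard_groups, replicate_groups) for a row-major replicate x shard mesh.

    Ranks in one shard group each hold a different shard of the parameters.
    Ranks in one replicate group hold identical copies of the same shard.
    """
    if world_size % shard_size != 0:
        raise ValueError("world_size must be divisible by shard_size")
    replicate_size = world_size // shard_size
    # Row r of the mesh holds ranks r * shard_size .. (r + 1) * shard_size - 1.
    mesh = [[r * shard_size + s for s in range(shard_size)]
            for r in range(replicate_size)]
    shard_groups = [list(row) for row in mesh]          # rows of the mesh
    replicate_groups = [list(col) for col in zip(*mesh)]  # columns of the mesh
    return shard_groups, replicate_groups


# Example: 8 ranks arranged as 2 replicas x 4 shards.
shard_groups, replicate_groups = hsdp_groups(world_size=8, shard_size=4)
```

Under this assumption, parameters are sharded within each row of the mesh and replicated across rows, so gradient reduction spans both dimensions while parameter gathers stay inside one shard group.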
- Introduction: basic concepts, public APIs, fully_shard(...) arguments, DeviceMesh setup, optimizers, and MoE boundaries.
- Tutorial: activation checkpointing, DCP save/load, and resharded load.
- Architecture: package layering, training-step flow, and design invariants, with per-subsystem deep dives for core, planning, runtime, kernels, optim, and checkpoint.
- Multi-node testing with container machines: local cross-machine correctness validation (rendezvous, env consistency, fail-fast, shaped-network scheduling) using Apple container micro-VMs.

Maintained by findlamp.
- GitHub: @findlamp (findlamp/matrix-fsdp)
- Email: dujinshidai30@gmail.com
Questions, bug reports, and contributions are welcome — open an issue or pull request on the repository, or reach out by email.