findlamp/matrix-fsdp

MatrixFSDP

MatrixFSDP is an experimental FSDP2-style runtime for training large PyTorch models with matrix-aware sharding. It keeps the public API close to PyTorch FSDP2: choose a DP or HSDP DeviceMesh, call fully_shard(...), then train with the normal PyTorch loop.

The main difference is the layout policy. MatrixFSDP can keep selected matrices whole on owner ranks for Muon-style optimizers, split dense weights by rows or blocks, and leave TP/EP-owned parameters to the upper-level parallelism stack. This makes the same runtime usable for AdamW baselines, matrix-owner Muon runs, and MoE models where FSDP should manage only the DP/HSDP parameter set.
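To make the layout policy concrete, here is a minimal sketch of the two ideas named above: assigning whole matrices to owner ranks, and splitting a dense weight by rows. It is an illustration only, not MatrixFSDP's planner. The greedy balancing rule, the names `MatrixSpec`, `assign_owners`, and `row_shard_bounds`, and the example layer shapes are all assumptions made for this sketch.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class MatrixSpec:
    """Hypothetical description of one 2D parameter."""
    name: str
    rows: int
    cols: int

    @property
    def numel(self) -> int:
        return self.rows * self.cols


def assign_owners(matrices, world_size):
    """Greedy balance: place the largest matrix first, on the least-loaded rank.

    Returns (owners, load), where owners maps matrix name -> owner rank and
    load[r] is the number of elements owned by rank r.
    """
    load = [0] * world_size
    owners = {}
    for m in sorted(matrices, key=lambda m: (-m.numel, m.name)):
        rank = min(range(world_size), key=lambda r: (load[r], r))
        owners[m.name] = rank
        load[rank] += m.numel
    return owners, load


def row_shard_bounds(rows, world_size, rank):
    """Half-open row range [start, stop) held by `rank` under even row sharding."""
    per_rank = -(-rows // world_size)  # ceiling division
    start = min(rank * per_rank, rows)
    stop = min(start + per_rank, rows)
    return start, stop


# Illustrative layer with hidden 4096 and intermediate 16384 (shapes assumed).
layer = [
    MatrixSpec("attn.qkv", 12288, 4096),
    MatrixSpec("attn.out", 4096, 4096),
    MatrixSpec("mlp.up", 16384, 4096),
    MatrixSpec("mlp.down", 4096, 16384),
]
owners, load = assign_owners(layer, world_size=4)
print(owners)                          # which rank keeps each matrix whole
print(row_shard_bounds(16384, 4, 1))   # rows of mlp.up held by rank 1 if row-sharded
```

Owner placement keeps each matrix intact on one rank, which is what a matrix-level optimizer update needs, while row sharding spreads every matrix evenly, which suits elementwise optimizers such as AdamW.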

model = fully_shard(model, mesh=mesh, dp_mesh_dims=dp_mesh_dims)
optim = configure_optimizer(model, "adamw", lr=3e-4)

loss.backward()
optim.step()
optim.zero_grad(set_to_none=True)

Create optimizers after sharding so they see the sharded parameter views. For most users, fully_shard(...) is the only sharding entry point.
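The ordering rule can be illustrated with a toy model that uses no PyTorch at all. The sketch below assumes that sharding replaces a module's parameter objects with new sharded ones, as FSDP2-style runtimes generally do; `ToyParam`, `ToyModel`, `ToySGD`, and `toy_shard` are hypothetical stand-ins, not MatrixFSDP or PyTorch APIs.

```python
class ToyParam:
    def __init__(self, values):
        self.values = list(values)
        self.grad = None


class ToyModel:
    def __init__(self):
        self.weight = ToyParam([1.0, 2.0, 3.0, 4.0])

    def parameters(self):
        return [self.weight]


class ToySGD:
    def __init__(self, params, lr):
        self.params = list(params)  # keeps references to the parameter objects
        self.lr = lr

    def step(self):
        for p in self.params:
            if p.grad is not None:
                p.values = [v - self.lr * g for v, g in zip(p.values, p.grad)]


def toy_shard(model, rank, world_size):
    """Replace the full parameter with this rank's slice (a new object)."""
    full = model.weight.values
    per_rank = len(full) // world_size
    model.weight = ToyParam(full[rank * per_rank:(rank + 1) * per_rank])
    return model


def train_one_step(create_optimizer_first):
    model = ToyModel()
    if create_optimizer_first:
        optim = ToySGD(model.parameters(), lr=0.5)  # captures the unsharded parameter
        toy_shard(model, rank=0, world_size=2)
    else:
        toy_shard(model, rank=0, world_size=2)
        optim = ToySGD(model.parameters(), lr=0.5)  # captures the shard
    model.weight.grad = [1.0, 1.0]  # gradient lands on the live (sharded) parameter
    optim.step()
    return model.weight.values


print(train_one_step(create_optimizer_first=True))   # shard is never updated
print(train_one_step(create_optimizer_first=False))  # shard is updated
```

In the first call the optimizer still points at the stale unsharded parameter, which never receives a gradient, so the step has no effect on the weights the model actually uses.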

Performance Highlights

Measured on 8x A100-SXM4-80GB (NVLink), CUDA 12.8, BF16 params/compute, transformer-block sharding with activation checkpointing, batch 1 per rank, Muon optimizer. With optimizer_policy="mixed_muon_adamw" MatrixFSDP keeps each 2D matrix whole on its owner rank, so Muon's Newton-Schulz runs without a full-matrix all-gather; forward and backward stay at parity with FSDP2.
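For readers unfamiliar with why Newton-Schulz wants the whole matrix, the sketch below runs the textbook cubic iteration X <- 1.5 X - 0.5 (X X^T) X in pure Python. Muon implementations typically use a tuned higher-order variant on GPU tensors; the cubic form is swapped in here for clarity, and the helper names are assumptions for this example. The input is assumed to be nonzero and full rank.

```python
import math


def matmul(a, b):
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def newton_schulz_cubic(g, steps=15):
    """Approximate the orthogonal polar factor of g (nonzero, full rank)."""
    fro = math.sqrt(sum(v * v for row in g for v in row))
    x = [[v / fro for v in row] for row in g]  # singular values now lie in (0, 1]
    for _ in range(steps):
        gram = matmul(x, transpose(x))   # uses every column of X: a full-matrix product
        correction = matmul(gram, x)
        x = [
            [1.5 * xv - 0.5 * cv for xv, cv in zip(x_row, c_row)]
            for x_row, c_row in zip(x, correction)
        ]
    return x


update = newton_schulz_cubic([[3.0, 0.0, 4.0], [1.0, 2.0, 0.0]])
print(matmul(update, transpose(update)))  # close to the 2x2 identity
```

Every step multiplies X by its own transpose, so a rank that holds only a row or block slice cannot run the iteration locally. Keeping the matrix whole on an owner rank removes the need to gather it first.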

Sequence-length sweep — 32 layers, hidden 4096, intermediate 16384, 32 heads (ms per step):

| seq | FSDP2 total (ms) | MatrixFSDP total (ms) | total speedup | optimizer step, FSDP2 -> MatrixFSDP (ms) |
|---|---|---|---|---|
| 1024 | 1145 | 537 | 2.13x | 748 -> 175 (4.3x) |
| 2048 | 1390 | 791 | 1.76x | 747 -> 175 (4.3x) |
| 4096 | 1963 | 1371 | 1.43x | 747 -> 175 (4.3x) |
| 8192 | 3435 | 2844 | 1.21x | 752 -> 175 (4.3x) |
| 16384 | 7646 | 7057 | 1.08x | 747 -> 176 (4.2x) |

The owner-Muon optimizer-step saving is roughly 4.3x and essentially constant across the sweep, because it depends on the parameter matrices rather than the sequence length. Forward and backward run at parity and grow with sequence length, so the total-step speedup is largest when the optimizer is a bigger share of the iteration (shorter sequences) and tapers toward parity as forward/backward dominate. At no sequence length in this sweep is MatrixFSDP slower than FSDP2.
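The relationship between optimizer share and total speedup can be checked directly from the sweep numbers. The short script below recomputes the speedups and adds one derived column, the optimizer step's share of the FSDP2 iteration; the figures are copied from the sweep above, and the function name `summarize` is only for this example.

```python
# (seq, FSDP2 total, MatrixFSDP total, FSDP2 optimizer step, MatrixFSDP optimizer step), ms
SWEEP = [
    (1024, 1145, 537, 748, 175),
    (2048, 1390, 791, 747, 175),
    (4096, 1963, 1371, 747, 175),
    (8192, 3435, 2844, 752, 175),
    (16384, 7646, 7057, 747, 176),
]


def summarize(row):
    seq, base_total, matrix_total, base_opt, matrix_opt = row
    return {
        "seq": seq,
        "total_speedup": round(base_total / matrix_total, 2),
        "optimizer_speedup": round(base_opt / matrix_opt, 1),
        "optimizer_share_of_fsdp2_step": round(base_opt / base_total, 2),
    }


for row in SWEEP:
    print(summarize(row))
```

The optimizer's share of the FSDP2 step falls from about 0.65 at seq 1024 to about 0.10 at seq 16384, which tracks the drop in total speedup from 2.13x to 1.08x.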

Larger model — 40 layers, hidden 5120, intermediate 20480, 40 heads, seq 4096:

| phase | FSDP2 | MatrixFSDP | result |
|---|---|---|---|
| forward | 586 ms | 588 ms | parity |
| backward | 1621 ms | 1590 ms | 1.02x faster |
| optimizer step | 1337 ms | 414 ms | 3.23x faster |
| total / step | 3545 ms | 2593 ms | 1.37x faster |
| peak reserved | 28934 MB | 26610 MB | lower |

The advantage comes from matrix-owner sharding, which avoids the expensive full-matrix optimizer gather; forward/backward match FSDP2 at equal-or-lower peak memory.
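As a quick arithmetic check on where the time goes, the per-phase differences for the larger model can be tallied from the reported figures. The helper names are illustrative only.

```python
# phase -> (FSDP2 ms, MatrixFSDP ms), copied from the larger-model results
PHASES_MS = {
    "forward": (586, 588),
    "backward": (1621, 1590),
    "optimizer step": (1337, 414),
}


def saved_ms_by_phase(phases):
    """Milliseconds saved per phase (negative means MatrixFSDP was slower)."""
    return {name: base - matrix for name, (base, matrix) in phases.items()}


def optimizer_fraction_of_savings(phases):
    saved = saved_ms_by_phase(phases)
    return saved["optimizer step"] / sum(saved.values())


print(saved_ms_by_phase(PHASES_MS))
print(round(optimizer_fraction_of_savings(PHASES_MS), 3))
```

The phase differences sum to 952 ms, matching the gap between the reported totals (3545 ms vs 2593 ms), and the optimizer step accounts for about 97% of it.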

Install

pip install -r requirements.txt
pip install -e .

Main APIs

  • fully_shard(model, mesh=..., dp_mesh_dims=...): public FSDP2-like entry point.
  • DataParallelMeshDims(shard="dp_shard"): selects the DeviceMesh dimension used for parameter sharding.
  • DataParallelMeshDims(shard="dp_shard", replicate="dp_replicate"): enables HSDP-style replicate x shard layouts.
  • optimizer_policy="mixed_muon_adamw": enables Muon-aware matrix-owner planning for fully_shard(...).
  • runtime_trace_enabled=True: opt-in lifecycle/collective tracing for debugging; the default training path keeps it disabled.
  • configure_optimizer(...): optimizer helper for AdamW, SGD, Muon, and mixed Muon/AdamW training.
  • save_matrix_dcp(...) and load_matrix_dcp(...): DCP checkpoint helpers for MatrixFSDP shard metadata and optimizer state.
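To show what a replicate x shard layout means for individual ranks, the sketch below maps a global rank to its position in a 2D mesh and lists its two process groups. It assumes a row-major mesh built from consecutive ranks with dimensions ordered (dp_replicate, dp_shard), and a 2 x 4 shape for 8 GPUs. Both are assumptions for this example, and the helper names are not part of MatrixFSDP or PyTorch.

```python
def mesh_coords(rank, replicate_size, shard_size):
    """Row-major (replicate_idx, shard_idx) of `rank` in a replicate x shard mesh."""
    if not 0 <= rank < replicate_size * shard_size:
        raise ValueError("rank outside mesh")
    return divmod(rank, shard_size)


def shard_group(rank, replicate_size, shard_size):
    """Ranks that together hold one full copy of the parameters."""
    rep, _ = mesh_coords(rank, replicate_size, shard_size)
    return [rep * shard_size + s for s in range(shard_size)]


def replicate_group(rank, replicate_size, shard_size):
    """Ranks that hold the same shard in different replicas."""
    _, shard = mesh_coords(rank, replicate_size, shard_size)
    return [r * shard_size + shard for r in range(replicate_size)]


print(mesh_coords(5, 2, 4))       # position of rank 5
print(shard_group(5, 2, 4))       # parameters are sharded across these ranks
print(replicate_group(5, 2, 4))   # gradients are averaged across these ranks
```

With only a shard dimension, every rank belongs to a single shard group of the full world size and there is no replicate group.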

Documentation

Contact

Maintained by findlamp.

Questions, bug reports, and contributions are welcome — open an issue or pull request on the repository, or reach out by email.
