Add Aurora optimizer for non-square matrices #80
Conversation
Aurora (https://blog.tilderesearch.com/blog/aurora) approximates a projection onto the intersection of the row-oblique and Stiefel manifolds via diagonal preconditioning, producing leverage-uniform updates. Standard polar (Newton-Schulz) inherits the non-uniform left-singular row norms of the gradient; Aurora iteratively rescales rows so that all neurons receive comparably sized updates. For square matrices Aurora reduces bit-for-bit to standard polar. For non-square matrices it transposes to tall, runs ``pp_iterations`` rounds of row-norm preconditioning around the existing polar function, and applies Aurora's ``max(1, m/n)**0.5`` aspect-ratio scaling. The implementation reuses Muon's pre/post-orthogonalize stages and the shared megabatch infrastructure; the new logic lives entirely in a ``newton_schulz_func`` wrapper, so all dion features (mega-batching, FSDP2 sharding, ``num_heads``, mixed param groups) work transparently.
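For orientation, a minimal usage sketch. The keyword names (``pp_iterations``, ``pp_beta``) and the default values come from the summary below; the model, shapes, and loss are purely illustrative:

```python
# Minimal usage sketch of dion.Aurora as a drop-in torch optimizer.
# pp_iterations / pp_beta names are taken from this PR; everything
# else (model, batch, loss) is illustrative only.
import torch
import torch.nn as nn
from dion import Aurora

model = nn.Linear(4096, 1024)  # non-square weight: where Aurora differs from polar
opt = Aurora(model.parameters(), lr=0.01, pp_iterations=2, pp_beta=0.5)

x = torch.randn(8, 4096)
loss = model(x).square().mean()
loss.backward()
opt.step()
opt.zero_grad()
```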
- Switch Aurora's aspect-ratio scaling to dion's standard ``adjust_lr`` pathway: ``adjust_lr="spectral_norm"`` is now the default (matching Muon/NorMuon), and the baked-in ``max(1, m/n)**0.5`` multiplication is removed from ``aurora_polar``. Note: this differs from the Aurora reference for wide matrices (dion uses ``sqrt(m/n)``, the reference uses ``max(1, m/n)**0.5``); on tall and square matrices the two agree.
- Add ``dion/aurora_reference.py``: single-file readable port of ``tilde-research/aurora-release`` (simple-quintic Newton-Schulz + diag-preconditioned polar + AdamW/Lion fallback), exported as ``dion.AuroraReference``. Mirrors the existing ``muon_reference.py`` / ``dion_reference.py`` pattern.
- README: list Aurora alongside Muon/Dion2/NorMuon in the optimizer table, the ``1D Sharding Configuration`` heading, the imports example, and the per-file descriptions.
- Tests: drop the now-unneeded aspect-ratio rescale in the row-norm uniformity assertion; add two AuroraReference tests.
The init-time wrapper closure captured ``pp_iterations`` and ``pp_beta`` as constants, so an LR scheduler or warmup that mutated the param group's ``pp_iterations`` was silently ignored. Every other Aurora hyperparameter (``lr``, ``mu``, ``weight_decay``, ``epsilon``, ``flatten``, ``adjust_lr``, ``cautious_wd``, ``nesterov``) is re-read from the group each step in ``_create_ortho_tasks``; the new hyperparameters now follow the same convention. Stash the unwrapped base polar as ``self._aurora_base_polar`` and rebuild the Aurora wrapper inside ``_create_ortho_tasks`` from the group's current values. Validate at the same point so bad runtime mutations fail fast. Also drop the unused ``import math`` from ``aurora_reference.py``.
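A sketch of that convention; ``_create_ortho_tasks``, ``self._aurora_base_polar``, ``make_aurora_polar``, and the group keys come from this PR, while the surrounding class structure and the exact validation bounds are assumptions:

```python
# Sketch: re-read pp_iterations / pp_beta from the param group every step,
# validate, and rebuild the Aurora wrapper around the stashed base polar.
# Surrounding class structure is hypothetical; validation bounds are assumed.
def _create_ortho_tasks(self, group):
    pp_iterations = group["pp_iterations"]
    pp_beta = group["pp_beta"]

    # Fail fast on bad runtime mutations (e.g. from a scheduler hook).
    if not isinstance(pp_iterations, int) or pp_iterations < 1:
        raise ValueError(f"pp_iterations must be a positive int, got {pp_iterations!r}")
    if not 0.0 < pp_beta <= 1.0:
        raise ValueError(f"pp_beta must be in (0, 1], got {pp_beta!r}")

    # Rebuild from current group values so scheduler mutations take effect
    # on the very next step, instead of being frozen at init time.
    ortho_fn = make_aurora_polar(self._aurora_base_polar, pp_iterations, pp_beta)
    # ... hand ortho_fn to the shared megabatch machinery as before ...
```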
Adds ``adjust_lr_aurora_aspect`` (``lr * sqrt(max(1, m/n))``) alongside ``spectral_norm`` and ``rms_norm``. The Aurora reference uses one-sided ``max(1, m/n)^0.5`` aspect-ratio scaling, which differs from dion's default ``spectral_norm`` (= ``sqrt(m/n)``) on wide matrices: dion shrinks wide-matrix LR by < 1 while the reference leaves it unscaled. This lets ``dion.Aurora(adjust_lr='aurora_aspect')`` reproduce the reference's LR conventions exactly while keeping the megabatch all-to-all communication path. Previously, matching the reference's wide-matrix behavior required falling back to ``dion.AuroraReference``, which uses slower per-param ``full_tensor()`` + ``redistribute()``.
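A small sketch contrasting the two scalings on concrete shapes; the ``scale_lr`` helper is hypothetical, but the formulas are the ones quoted above:

```python
import math

# Hypothetical helper contrasting the two LR scalings quoted above.
def scale_lr(lr: float, m: int, n: int, mode: str) -> float:
    if mode == "spectral_norm":   # dion default: sqrt(m/n)
        return lr * math.sqrt(m / n)
    if mode == "aurora_aspect":   # Aurora reference: sqrt(max(1, m/n))
        return lr * math.sqrt(max(1.0, m / n))
    raise ValueError(f"unknown mode: {mode}")

lr = 0.01
print(scale_lr(lr, 4096, 1024, "spectral_norm"))  # tall: 0.02  (the two agree)
print(scale_lr(lr, 4096, 1024, "aurora_aspect"))  # tall: 0.02
print(scale_lr(lr, 1024, 4096, "spectral_norm"))  # wide: 0.005 (shrinks by < 1)
print(scale_lr(lr, 1024, 4096, "aurora_aspect"))  # wide: 0.01  (left unscaled)
```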
    # Resolve the base polar function (the one wrapped by Aurora's
    # diagonal preconditioning). Mirrors DistributedOrthoBase resolution.
    if newton_schulz_func is not None:
Aurora's polar function uses 12 iterations. Maybe too soon to decide, but should we add an option to use their polar function, or at least to do more than 5 iterations? (It will take a little care to integrate that with Gram Newton-Schulz, but it's doable.)
        cautious_wd: bool = False,
    ) -> Generator[None, None, None]:
        """
        Mega-batched Aurora update. Reuses Muon's pre/post-orthogonalize stages
So what is the difference between this function and the generic ``update_megabatch_async``?
        _run_steps(Aurora, params, dict(lr=0.01), n_steps=1)
        assert not torch.equal(params[0].data, before)

    def test_nesterov(self):
What does this test? Just that there's no error? Same for ``test_cautious_wd`` and ``test_pp_iterations``.
Questions
The first comment seems valid in my experiments: I'm seeing the best results with spectral norm (no min). So far, it seems like the ideal learning rate is more like 0.01, but there's evidence that this provides a modest advantage over NorMuon.
Summary
Adds Aurora in ``dion/aurora.py``, exported as ``dion.Aurora``. The new logic lives entirely in a ``newton_schulz_func`` wrapper, so mega-batching, FSDP2 sharding, ``num_heads``, and mixed param groups all work transparently.

What changed
- ``dion/aurora.py`` (new): ``Aurora`` class + ``make_aurora_polar(base_polar, pp_iterations, pp_beta)`` factory. Default hyperparameters match the reference (``pp_iterations=2``, ``pp_beta=0.5``, ``nesterov=True``, ``adjust_lr=None``).
- ``dion/__init__.py``: export ``Aurora``.
- ``tests/test_optimizers.py``: new ``TestAurora`` class (11 tests) covering basic operation, determinism, parameter update, Nesterov, cautious WD, ``pp_iterations`` sweep, megabatching, mixed shapes, and validation. Two Aurora-specific assertions, including that the ``max/min`` row-norm ratio is meaningfully tighter than standard polar (and below 1.15).

Algorithm
For non-square ``G`` with shape ``(m, n)``:

- Transpose to tall (so ``m >= n``).
- Initialize ``D = 1 / row_norm(G)``.
- Run ``pp_iterations`` rounds of ``U = polar(D * G)``; between rounds update ``D *= (target_row_sq / row_sq(U))^pp_beta`` where ``target_row_sq = n / m``.
- Scale by ``max(1, m/n)^0.5`` (Aurora's tall-aspect-ratio scaling, baked into the orthogonalization output rather than the LR).

For square matrices the loop is bypassed entirely.
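A runnable sketch of these steps, using an SVD-based stand-in for the Newton-Schulz ``polar``. Per the changelog above, the final aspect-ratio scale was later moved out of ``aurora_polar`` into the ``adjust_lr`` pathway, but it is kept here to match the bullets:

```python
import torch

def polar(g: torch.Tensor) -> torch.Tensor:
    # SVD-based stand-in for the Newton-Schulz polar factor.
    u, _, vh = torch.linalg.svd(g, full_matrices=False)
    return u @ vh

def aurora_polar(g: torch.Tensor, pp_iterations: int = 2, pp_beta: float = 0.5) -> torch.Tensor:
    m, n = g.shape
    if m == n:
        return polar(g)          # square: loop bypassed, identical to standard polar
    transposed = m < n
    if transposed:               # work in the tall orientation, so m >= n
        g = g.T
        m, n = n, m
    eps = 1e-8
    d = 1.0 / (g.norm(dim=1, keepdim=True) + eps)   # D = 1 / row_norm(G)
    target_row_sq = n / m                           # leverage-uniform row target
    for i in range(pp_iterations):
        u = polar(d * g)
        if i < pp_iterations - 1:                   # update D between rounds only
            row_sq = u.square().sum(dim=1, keepdim=True)
            d = d * (target_row_sq / (row_sq + eps)) ** pp_beta
    u = u * max(1.0, m / n) ** 0.5                  # tall-aspect-ratio scaling
    return u.T if transposed else u
```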
Test plan
- ``pytest tests/test_optimizers.py::TestAurora``: 11/11 pass on CUDA.
- ``pytest tests/test_optimizers.py tests/test_polar_express.py tests/test_newton_shulz.py``: 92/92 pass (no regressions).
- Row-norm uniformity: ``std`` drops 6× (0.029 → 0.005) and the ``max/min`` ratio improves from 1.40 to 1.06 versus standard polar.
- ``Aurora == polar`` bit-for-bit on square (128×128) input.
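A sketch of how the row-norm comparison could be reproduced, reusing ``polar`` and ``aurora_polar`` from the algorithm sketch above; the shape and seed are arbitrary, and the specific numbers quoted in the test plan are the PR's reported results, not outputs of this snippet:

```python
import torch

torch.manual_seed(0)
g = torch.randn(4096, 1024)  # tall gradient-like matrix, illustrative shape

# Compare row-norm spread of the standard polar vs Aurora's preconditioned polar.
for name, fn in [("polar", polar), ("aurora", aurora_polar)]:
    row_norms = fn(g).norm(dim=1)
    print(f"{name}: std={row_norms.std().item():.3f} "
          f"max/min={(row_norms.max() / row_norms.min()).item():.2f}")
```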