
Add Aurora optimizer for non-square matrices #80

Open

JohnLangford wants to merge 4 commits into main from jcl/aurora-optimizer

Conversation

@JohnLangford
Contributor

Summary

  • Adds Aurora (blog, reference impl) as a new optimizer in dion/aurora.py, exported as dion.Aurora.
  • Aurora is a leverage-uniform polar update for non-square matrices: it iteratively row-preconditions the gradient before each polar (Newton-Schulz) call so that all left-singular rows of the orthogonalized update have comparable norms. For square matrices it reduces to standard Muon polar bit-for-bit.
  • Reuses Muon's pre/post-orthogonalize stages and the shared megabatch infrastructure; the new behavior lives entirely in a newton_schulz_func wrapper, so mega-batching, FSDP2 sharding, num_heads, and mixed param groups all work transparently.
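
A minimal usage sketch, assuming ``dion.Aurora`` follows the same constructor convention as the other dion optimizers (parameters plus keyword hyperparameters). The keyword names mirror the defaults listed under "What changed" below and may not match the final API exactly.

```python
import torch
import dion

model = torch.nn.Linear(128, 512)   # non-square (512, 128) weight exercises the Aurora path
opt = dion.Aurora(
    model.parameters(),
    lr=0.01,
    pp_iterations=2,   # rounds of row-norm preconditioning
    pp_beta=0.5,       # damping exponent on the row-norm correction
    nesterov=True,
)

for _ in range(10):
    loss = model(torch.randn(32, 128)).pow(2).mean()
    loss.backward()
    opt.step()
    opt.zero_grad()
```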

What changed

  • dion/aurora.py (new): Aurora class + make_aurora_polar(base_polar, pp_iterations, pp_beta) factory. Default hyperparameters match the reference (pp_iterations=2, pp_beta=0.5, nesterov=True, adjust_lr=None).
  • dion/__init__.py: export Aurora.
  • tests/test_optimizers.py: new TestAurora class (11 tests) covering basic operation, determinism, parameter update, Nesterov, cautious WD, pp_iterations sweep, megabatching, mixed shapes, validation. Two Aurora-specific assertions:
    • Square matrices: Aurora's polar output is bit-for-bit identical to the standard polar.
    • Non-square matrices: Aurora's row-norm max/min ratio is meaningfully tighter than standard polar (and below 1.15).

Algorithm

For non-square G with (m, n):

  1. Transpose to tall (m >= n).
  2. Initialize D = 1 / row_norm(G).
  3. For pp_iterations rounds: U = polar(D * G); between rounds update D *= (target_row_sq / row_sq(U))^pp_beta where target_row_sq = n / m.
  4. Transpose back if needed; multiply by max(1, m/n)^0.5 (Aurora's tall-aspect-ratio scaling, baked into the orthogonalization output rather than the LR).

For square matrices the loop is bypassed entirely.
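
For concreteness, here is a self-contained sketch of the steps above, using a generic Muon-style quintic Newton-Schulz as a stand-in base polar. Names and details are illustrative rather than the actual dion/aurora.py code, and the aspect-ratio scaling is shown baked into the output as described here (later commits move it into ``adjust_lr``).

```python
import torch


def newton_schulz_polar(G: torch.Tensor, steps: int = 5) -> torch.Tensor:
    """Generic Muon-style quintic Newton-Schulz polar (stand-in for the base polar)."""
    a, b, c = 3.4445, -4.7750, 2.0315
    X = G / (G.norm() + 1e-7)
    tall = X.shape[0] > X.shape[1]
    if tall:
        X = X.T                                    # iterate on the short side
    for _ in range(steps):
        A = X @ X.T
        X = a * X + (b * A + c * (A @ A)) @ X
    return X.T if tall else X


def aurora_polar(G: torch.Tensor, pp_iterations: int = 2, pp_beta: float = 0.5,
                 eps: float = 1e-8) -> torch.Tensor:
    """Sketch of the Aurora row-preconditioned polar described above."""
    m, n = G.shape
    if m == n:
        return newton_schulz_polar(G)              # square: bypass the loop entirely
    scale = max(1.0, m / n) ** 0.5                 # tall-only aspect scaling (wide left unscaled)
    X = G.T if m < n else G                        # step 1: transpose to tall
    rows, cols = X.shape
    D = 1.0 / (X.norm(dim=1, keepdim=True) + eps)  # step 2: inverse row norms
    target_row_sq = cols / rows
    for it in range(pp_iterations):                # step 3: precondition, polar, re-estimate
        U = newton_schulz_polar(D * X)
        if it + 1 < pp_iterations:
            row_sq = U.pow(2).sum(dim=1, keepdim=True) + eps
            D = D * (target_row_sq / row_sq).pow(pp_beta)
    U = U.T if m < n else U                        # step 4: transpose back if needed
    return U * scale
```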

Test plan

  • pytest tests/test_optimizers.py::TestAurora — 11/11 pass on CUDA.
  • pytest tests/test_optimizers.py tests/test_polar_express.py tests/test_newton_shulz.py — 92/92 pass (no regressions).
  • Verified empirically on a tall (512×128) matrix: Aurora row-norm std drops 6× (0.029 → 0.005) and max/min ratio improves from 1.40 to 1.06.
  • Verified Aurora == polar bit-for-bit on square (128×128) input.
  • FSDP2 / DTensor sharded smoke test (not run; the wrapper plugs into the existing megabatch path that NorMuon also uses, so coverage is via the shared infrastructure).
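
The two Aurora-specific checks correspond roughly to the sketch below, written against standalone polar callables like the ones sketched under "Algorithm"; the real TestAurora assertions in tests/test_optimizers.py may drive the full optimizer instead.

```python
import torch


def row_norm_ratio(U: torch.Tensor) -> float:
    norms = U.norm(dim=1)
    return (norms.max() / norms.min()).item()


def check_square_matches_base(aurora_polar, base_polar) -> None:
    # Square input: Aurora must be bit-for-bit identical to the base polar.
    G = torch.randn(128, 128)
    assert torch.equal(aurora_polar(G), base_polar(G))


def check_tall_rows_more_uniform(aurora_polar, base_polar) -> None:
    # Tall input: Aurora's row-norm spread should be meaningfully tighter.
    G = torch.randn(512, 128)
    assert row_norm_ratio(aurora_polar(G)) < 1.15
    assert row_norm_ratio(aurora_polar(G)) < row_norm_ratio(base_polar(G))
```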

John Langford added 4 commits May 10, 2026 10:29
Aurora (https://blog.tilderesearch.com/blog/aurora) approximates a
projection onto the intersection of the row oblique and Stiefel
manifolds via diagonal preconditioning, producing leverage-uniform
updates. Standard polar (Newton-Schulz) inherits the non-uniform
left-singular row norms of the gradient; Aurora iteratively rescales
rows so that all neurons receive comparably-sized updates.

For square matrices Aurora reduces bit-for-bit to standard polar.
For non-square matrices it transposes to tall, runs pp_iterations
rounds of row-norm preconditioning around the existing polar function,
and applies Aurora's max(1, m/n)**0.5 aspect-ratio scaling.

Implementation reuses Muon's pre/post-orthogonalize stages and the
shared megabatch infrastructure; the new logic lives entirely in a
newton_schulz_func wrapper, so all dion features (mega-batching,
FSDP2 sharding, num_heads, mixed param groups) work transparently.

- Switch Aurora's aspect-ratio scaling to dion's standard ``adjust_lr``
  pathway: ``adjust_lr="spectral_norm"`` is now the default (matching
  Muon/NorMuon), and the baked-in ``max(1, m/n)**0.5`` multiplication is
  removed from ``aurora_polar``. Note: this differs from the Aurora
  reference for wide matrices (dion uses ``sqrt(m/n)``, the reference
  uses ``max(1, m/n)**0.5``); on tall and square matrices the two agree.
- Add ``dion/aurora_reference.py``: single-file readable port of
  ``tilde-research/aurora-release`` (simple-quintic Newton-Schulz +
  diag-preconditioned polar + AdamW/Lion fallback), exported as
  ``dion.AuroraReference``. Mirrors the existing
  ``muon_reference.py`` / ``dion_reference.py`` pattern.
- README: list Aurora alongside Muon/Dion2/NorMuon in the optimizer
  table, the ``1D Sharding Configuration`` heading, the imports
  example, and the per-file descriptions.
- Tests: drop the now-unneeded aspect-ratio rescale in the row-norm
  uniformity assertion; add two AuroraReference tests.

The init-time wrapper closure captured ``pp_iterations`` and ``pp_beta``
as constants, so an LR scheduler or warmup that mutates the param
group's ``pp_iterations`` was silently ignored. Every other Aurora
hyperparameter (``lr``, ``mu``, ``weight_decay``, ``epsilon``,
``flatten``, ``adjust_lr``, ``cautious_wd``, ``nesterov``) is re-read
from the group each step in ``_create_ortho_tasks``; the new
hyperparameters now follow the same convention.

Stash the unwrapped base polar as ``self._aurora_base_polar`` and
rebuild the Aurora wrapper inside ``_create_ortho_tasks`` from the
group's current values. Validate at the same point so bad runtime
mutations fail fast.

Also drop unused ``import math`` from ``aurora_reference.py``.
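
A rough sketch of that convention follows; the helper name, the pass-through of ``make_aurora_polar``, and the validation bounds are assumptions, not the actual ``_create_ortho_tasks`` body.

```python
from typing import Callable
import torch

PolarFn = Callable[[torch.Tensor], torch.Tensor]


def rebuild_aurora_polar(base_polar: PolarFn, group: dict,
                         make_aurora_polar: Callable[..., PolarFn]) -> PolarFn:
    """Re-read the group's *current* hyperparameters each step and rebuild the
    wrapper, so schedulers/warmup that mutate the group take effect; validate
    here so bad runtime mutations fail fast. Bounds below are assumptions."""
    pp_iterations = group["pp_iterations"]
    pp_beta = group["pp_beta"]
    if not isinstance(pp_iterations, int) or pp_iterations < 1:
        raise ValueError(f"pp_iterations must be a positive int, got {pp_iterations!r}")
    if not 0.0 < pp_beta <= 1.0:
        raise ValueError(f"pp_beta must be in (0, 1], got {pp_beta!r}")
    return make_aurora_polar(base_polar, pp_iterations, pp_beta)
```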

Adds ``adjust_lr_aurora_aspect`` (``lr * sqrt(max(1, m/n))``) alongside
``spectral_norm`` and ``rms_norm``. The Aurora reference uses one-sided
``max(1, m/n)^0.5`` aspect-ratio scaling, which differs from dion's
default ``spectral_norm`` (= ``sqrt(m/n)``) on wide matrices: dion
shrinks wide-matrix LR by < 1 while the reference leaves it unscaled.

This lets ``dion.Aurora(adjust_lr='aurora_aspect')`` reproduce the
reference's LR conventions exactly while keeping the megabatch
all-to-all communication path. Previously, matching the reference's
wide-matrix behavior required falling back to ``dion.AuroraReference``,
which uses slower per-param ``full_tensor()`` + ``redistribute()``.
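
An illustrative comparison of the two adjustments; the mode strings follow this PR, while dion's internal implementation may differ.

```python
import math


def adjusted_lr(lr: float, m: int, n: int, mode: str) -> float:
    """LR adjustment for an (m, n) parameter matrix."""
    if mode == "spectral_norm":
        return lr * math.sqrt(m / n)             # dion default: shrinks wide (m < n) matrices
    if mode == "aurora_aspect":
        return lr * math.sqrt(max(1.0, m / n))   # one-sided: wide matrices left unscaled
    raise ValueError(f"unknown adjust_lr mode: {mode!r}")


# Tall 512x128: both give lr * 2.  Wide 128x512: spectral_norm gives lr * 0.5,
# aurora_aspect leaves lr unchanged, matching the Aurora reference.
```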
Comment thread dion/aurora.py

# Resolve the base polar function (the one wrapped by Aurora's
# diagonal preconditioning). Mirrors DistributedOrthoBase resolution.
if newton_schulz_func is not None:
Contributor


Aurora's polar function uses 12 iterations. Maybe too soon to decide, but should we add an option to use their polar function, or at least to do more than 5 iterations? (It will take a little care to integrate that with Gram Newton-Schulz, but it's doable.)

Comment thread dion/aurora.py
cautious_wd: bool = False,
) -> Generator[None, None, None]:
"""
Mega-batched Aurora update. Reuses Muon's pre/post-orthogonalize stages
Contributor


So what is the difference between this function and the generic update_megabatch_async?

Comment thread tests/test_optimizers.py
_run_steps(Aurora, params, dict(lr=0.01), n_steps=1)
assert not torch.equal(params[0].data, before)

def test_nesterov(self):
Contributor


What does this test? Just that there's no error? Same for test_cautious_wd and test_pp_iterations.

@NoahAmsel
Contributor

Questions

  1. Aurora is concerned about the up projections (tall/skinny) only, since that's where they observe the neuron-death phenomenon. Why do they even apply all this extra normalization to short/fat matrices? Why should we?
  2. I don't really understand why they design the damping parameter the way they do. It does, however, seem unnecessary to multiply D by target_row_sq, which is constant across rows, since polar is invariant to scaling by a constant anyway. Probably no harm, though.

@JohnLangford
Contributor Author

The first point seems valid in my experiments: I'm seeing the best results with spectral norm (no min). So far, the ideal learning rate seems to be more like 0.01, but there's evidence that this provides a modest advantage over NorMuon.

