NorMuon: opt-in nan_guard_fallback skips step on non-finite NS output #79

Open

JohnLangford wants to merge 2 commits into main from jcl/issue-76-nan-fallback

Conversation

@JohnLangford
Contributor

Summary

Re: #76. Defensive guard for the intermittent NaN bug in NorMuon +
gram-newton-schulz + quack-kernels 0.4.1: once a single rank's NS output
goes non-finite, the bad value poisons the parameter, all-reduces into
the next step, and the run is dead. With this PR, opting into
nan_guard_fallback=True lets all ranks agree (via a single
all_reduce(MAX) of one byte) to skip the entire post-ortho update on
detection.

  • V (variance buffer) and X (param) stay strictly unchanged on a
    skipped step. M (momentum) was updated from the clean gradient in
    pre-orthogonalize and is left alone.
  • The early-return semantics are deliberately "this batch never
    happened" rather than "zero out U and run normalization", because
    decaying V on zero updates contaminates future steps and applying
    weight decay alone is a half-update we don't want.

Companion to #78 (the env-gated capture wrapper that helps identify the
offending input). #78 is for diagnosis; this PR is for surviving the
issue in production until the upstream quack/gns regression is fixed.
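
A minimal sketch of the agreement step described above, assuming an illustrative helper name and signature (the PR's actual code lives inside the NorMuon step and may be shaped differently):

import torch
import torch.distributed as dist

def ns_output_should_skip(U_stacked: torch.Tensor, group) -> bool:
    # Local check: did this rank's Newton-Schulz output go non-finite?
    bad = (~torch.isfinite(U_stacked).all()).to(torch.uint8)
    if group is not None and dist.is_initialized():
        # One-byte MAX all-reduce: every rank sees 1 if ANY rank flagged.
        dist.all_reduce(bad, op=dist.ReduceOp.MAX, group=group)
    # The device->host .item() sync dominates the overhead, not the collective.
    return bool(bad.item())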

Cost

                  Off (default)     On
Per-step          1 Python branch   1-byte all_reduce(MAX) + .item() per shape group
Behavior change   None              Skip step on detection; emit RuntimeWarning on rank 0

The sync is required: if rank 0 takes the fallback and rank 1 doesn't,
DDP weights drift or the next megabatch alltoall mismatches and
deadlocks. Single-byte allreduce is microseconds; total overhead is
dominated by the device->host .item() sync.

Usage

opt = NorMuon(
    params,
    distributed_mesh=process_group,
    use_gram_newton_schulz=True,
    use_triton=True,
    nan_guard_fallback=True,  # opt-in defense for issue #76
)

Test plan

  • test_fallback_off_lets_nan_propagate_to_params — baselines that
    the bug actually exists in unit-test form so the positive test isn't
    trivially green
  • test_fallback_on_skips_step_when_ns_returns_nan — bit-exact
    param + zero variance buffer + RuntimeWarning on rank 0
  • test_fallback_on_does_not_block_normal_step — guard is inert
    when NS output is finite
  • Existing test_optimizers.py::TestNorMuon (5 tests) still pass

All run single-rank on CUDA without NCCL.
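
As a hedged sketch of the bit-exactness check in the second test (the helper, fixture shape, and patching details are assumptions; the real test in test_optimizers.py may be structured differently):

import warnings
import torch

def assert_step_skipped(opt, params_before):
    # Sketch only: the test is assumed to have already patched the NS kernel
    # so its output contains a NaN before opt.step() is called.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        opt.step()
    # "Bit-exact" means torch.equal, not allclose.
    for p, before in zip(opt.param_groups[0]["params"], params_before):
        assert torch.equal(p, before)
    # Rank 0 must have emitted the RuntimeWarning.
    assert any(issubclass(w.category, RuntimeWarning) for w in caught)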

Issue #76 reports intermittent NaN parameters with NorMuon +
gram-newton-schulz + quack-kernels 0.4.1 on multi-GPU DDP. Once a single
rank's NS output goes non-finite, the post-ortho update poisons the
parameter, the bad value gets all-reduced into the gradient/state on the
next step, and the run is dead.

Add an opt-in ``nan_guard_fallback`` arg to NorMuon. After the megabatch
orthogonalization completes, all ranks check ``isfinite(U_stacked).all()``
and exchange the result via a single-byte ``all_reduce(MAX)``. If any
rank flagged non-finite, every rank early-returns from
``normuon_update_megabatch_async``: V (variance buffer) and X (param)
stay strictly unchanged, so the run state is bit-identical to "this
batch never happened" for these params. M (momentum) was already
updated in pre-orthogonalize from the clean gradient and is left alone.

Why early-return rather than zero U + run normalization:
- normuon_normalization_stacked on zero U decays V toward 0, contaminating
  future steps' normalization.
- weight_decay applies in the post-ortho path; when we know the optimizer
  step is junk, it's cleaner to skip everything than to apply a
  half-update.
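
A sketch of where the early return sits, with the surrounding body elided (the real ``normuon_update_megabatch_async`` signature is assumed here, not quoted):

import warnings
import torch.distributed as dist

def normuon_update_megabatch_async(U_stacked, V, X, *, nan_guard_fallback, group):
    # ns_output_should_skip is the one-byte MAX all-reduce helper sketched above.
    if nan_guard_fallback and ns_output_should_skip(U_stacked, group):
        rank0 = (not dist.is_initialized()) or dist.get_rank(group) == 0
        if rank0:
            warnings.warn(
                "NorMuon: non-finite NS output detected; skipping step",
                RuntimeWarning,
            )
        return  # V and X untouched: "this batch never happened"
    # ... normuon_normalization_stacked(U_stacked, V), weight decay, X update ...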

Cost when enabled: one tiny allreduce + one device->host sync per
shape group per step. Cost when ``nan_guard_fallback=False`` (default):
zero: just one Python-level branch in the optimizer step.

Sync is required for correctness in the distributed case: if rank 0
takes the fallback and rank 1 doesn't, DDP weights drift or a future
collective deadlocks because rank 1 advances to the next step's
megabatch alltoall while rank 0 hasn't.

Tests (single-rank CUDA, no NCCL):
- ``test_fallback_off_lets_nan_propagate_to_params`` baselines that the
  bug exists when the guard is off, so the next test isn't trivially
  green.
- ``test_fallback_on_skips_step_when_ns_returns_nan`` verifies bit-exact
  param + zero variance buffer + RuntimeWarning emission on rank 0.
- ``test_fallback_on_does_not_block_normal_step`` verifies the guard is
  inert when NS output is finite (params change as usual).

Existing NorMuon tests in test_optimizers.py still pass.
…gabatch's

The head-split and FSDP2 batch-sharded paths in `_create_ortho_tasks`
deliberately override `process_group=None` for the megabatch because no
alltoall is needed there. Gating the nan-flag allreduce on that same
`process_group` silently disabled the sync in those cases:

  - Head-split + DDP: params replicated across ranks, but only some
    ranks may take the fallback. DDP weights drift -- the exact failure
    mode this guard was supposed to prevent.
  - Batch-sharded FSDP2: each rank owns a different shard of the same
    logical param. Divergent skips leave the logical tensor torn (some
    shards stepped, some not), violating the "this batch never happened"
    invariant.

Thread the optimizer's full `self._process_group` through as a separate
`nan_sync_process_group` argument so the nan-skip decision agrees across
all ranks regardless of the megabatch's local-vs-collective config.
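
Roughly, at the call site (only `_create_ortho_tasks`, `self._process_group`, and `nan_sync_process_group` come from the description above; the surrounding names are illustrative):

# Head-split / FSDP2 batch-sharded path: no alltoall is needed, so the
# megabatch group stays None, but the nan-skip vote must still span all ranks.
tasks = self._create_ortho_tasks(
    stacked_params,
    process_group=None,                          # local megabatch: no alltoall
    nan_sync_process_group=self._process_group,  # global agreement on the skip
)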