NorMuon: opt-in nan_guard_fallback skips step on non-finite NS output#79
Open
JohnLangford wants to merge 2 commits into
Conversation
Issue #76 reports intermittent NaN parameters with NorMuon + gram-newton-schulz + quack-kernels 0.4.1 on multi-GPU DDP. Once a single rank's NS output goes non-finite, the post-ortho update poisons the parameter, the bad value gets all-reduced into the gradient/state on the next step, and the run is dead.

Add an opt-in ``nan_guard_fallback`` arg to NorMuon. After the megabatch orthogonalization completes, all ranks check ``isfinite(U_stacked).all()`` and exchange the result via a single-byte ``all_reduce(MAX)``. If any rank flagged non-finite, every rank early-returns from ``normuon_update_megabatch_async``: V (variance buffer) and X (param) stay strictly unchanged, so the run state is bit-identical to "this batch never happened" for these params. M (momentum) was already updated in pre-orthogonalize from the clean gradient and is left alone.

Why early-return rather than zero U + run normalization:
- ``normuon_normalization_stacked`` on zero U decays V toward 0, contaminating future steps' normalization.
- weight_decay applies in the post-ortho path; when we know the optimizer step is junk, it's cleaner to skip everything than to apply a half-update.

Cost when triggered: one tiny allreduce + one device->host sync per shape group per step. Cost when ``nan_guard_fallback=False`` (default): zero, just one Python-level branch in the optimizer step.

The sync is required for correctness in the distributed case: if rank 0 takes the fallback and rank 1 doesn't, DDP weights drift or a future collective deadlocks because rank 1 advances to the next step's megabatch alltoall while rank 0 hasn't.

Tests (single-rank CUDA, no NCCL):
- ``test_fallback_off_lets_nan_propagate_to_params`` baselines that the bug exists when the guard is off, so the next test isn't trivially green.
- ``test_fallback_on_skips_step_when_ns_returns_nan`` verifies bit-exact param + zero variance buffer + RuntimeWarning emission on rank 0.
- ``test_fallback_on_does_not_block_normal_step`` verifies the guard is inert when NS output is finite (params change as usual).

Existing NorMuon tests in test_optimizers.py still pass.
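The flag-exchange protocol described above can be sketched in plain Python. This is an illustrative simulation of the agreement logic, not the PR's code: the real implementation uses ``torch.isfinite`` and ``torch.distributed.all_reduce`` with ``ReduceOp.MAX``; the function names and the two-rank loop here are invented for the sketch.

```python
import math

def local_nan_flag(ns_output):
    # Each rank computes a 0/1 flag: 1 if any entry of its NS output
    # is non-finite (mirrors isfinite(U_stacked).all() in the PR).
    return 0 if all(math.isfinite(x) for x in ns_output) else 1

def all_reduce_max(flags):
    # Stand-in for the single-byte all_reduce(MAX): every rank ends up
    # holding the maximum flag, so one bad rank forces a skip everywhere.
    m = max(flags)
    return [m] * len(flags)

def step(params, update, skip):
    # Early return leaves params (and V, not modeled here) untouched,
    # so a skipped step is bit-identical to "this batch never happened".
    if skip:
        return params
    return [p - u for p, u in zip(params, update)]

# Two simulated ranks; only rank 1's NS output went non-finite.
ns_outputs = [[0.1, -0.2], [float("nan"), 0.3]]
flags = all_reduce_max([local_nan_flag(o) for o in ns_outputs])
params = [[1.0, 1.0], [1.0, 1.0]]
stepped = [step(p, o, f) for p, o, f in zip(params, ns_outputs, flags)]
print(stepped)  # both ranks skipped: [[1.0, 1.0], [1.0, 1.0]]
```

Note that without the MAX exchange, rank 0 would have stepped while rank 1 skipped, which is exactly the divergence the sync prevents.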
…gabatch's
The head-split and FSDP2 batch-sharded paths in `_create_ortho_tasks`
deliberately override `process_group=None` for the megabatch because no
alltoall is needed there. Gating the nan-flag allreduce on that same
`process_group` silently disabled the sync in those cases:
- Head-split + DDP: params replicated across ranks, but only some
ranks may take the fallback. DDP weights drift -- the exact failure
mode this guard was supposed to prevent.
- Batch-sharded FSDP2: each rank owns a different shard of the same
logical param. Divergent skips leave the logical tensor torn (some
shards stepped, some not), violating the "this batch never happened"
invariant.
Thread the optimizer's full `self._process_group` through as a separate
`nan_sync_process_group` argument so the nan-skip decision agrees across
all ranks regardless of the megabatch's local-vs-collective config.
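The failure mode this commit fixes, and why a separate sync group repairs it, can be shown with a small pure-Python sketch. The function and its arguments are illustrative, not the PR's actual signatures; ``sync_group=None`` models the head-split/FSDP2 paths where the megabatch's ``process_group`` was overridden, and a non-``None`` value models threading the optimizer's full ``self._process_group`` through as ``nan_sync_process_group``.

```python
def nan_skip_decisions(local_flags, sync_group):
    # sync_group=None models gating the nan sync on the megabatch's
    # overridden process_group: no collective runs, each rank acts on
    # its own flag, and skip decisions can diverge across ranks.
    if sync_group is None:
        return list(local_flags)
    # Any real group models all_reduce(MAX) over the optimizer's full
    # process group: every rank arrives at the same decision.
    m = max(local_flags)
    return [m] * len(local_flags)

flags = [0, 1]  # only rank 1 saw a non-finite NS output

# Bug: gating on the megabatch's group -> divergent skips, DDP drift.
print(nan_skip_decisions(flags, sync_group=None))    # [0, 1]
# Fix: separate nan_sync_process_group -> all ranks agree to skip.
print(nan_skip_decisions(flags, sync_group="full"))  # [1, 1]
```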
Summary
Re: #76. Defensive guard for the intermittent NaN bug in NorMuon +
gram-newton-schulz + quack-kernels 0.4.1: once a single rank's NS output
goes non-finite, the bad value poisons the parameter, all-reduces into
the next step, and the run is dead. With this PR, opting into
`nan_guard_fallback=True` lets all ranks agree (via a single `all_reduce(MAX)` of one byte) to skip the entire post-ortho update on detection.
`V` (variance buffer) and `X` (param) stay strictly unchanged on a skipped step.
`M` (momentum) was updated from the clean gradient in pre-orthogonalize and is left alone.
The skip semantics are "this batch never happened" rather than "zero out U and run normalization", because decaying V on zero updates contaminates future steps and applying weight decay alone is a half-update we don't want.
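The "decaying V contaminates future steps" point is easy to demonstrate numerically. A minimal sketch, assuming V is an Adam-style EMA second-moment buffer with an illustrative decay of 0.95 (the PR does not state NorMuon's actual decay constant):

```python
def update_variance(v, u, beta2=0.95):
    # EMA second-moment buffer: feeding u=0 just decays v geometrically.
    return beta2 * v + (1 - beta2) * u * u

# The "zero out U and run normalization" option on 20 junk steps:
v = 1.0
for _ in range(20):
    v = update_variance(v, 0.0)
print(v)  # ~0.358: V decayed toward 0 though no real update happened

# The next genuine update is then divided by a too-small sqrt(V) and
# overshoots; early-returning instead leaves V at 1.0, untouched.
```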
Companion to #78 (the env-gated capture wrapper that helps identify the
offending input). #78 is for diagnosis; this PR is for surviving the
issue in production until the upstream quack/gns regression is fixed.
Cost
One `all_reduce(MAX)` + `.item()` per shape group.
The sync is required: if rank 0 takes the fallback and rank 1 doesn't,
DDP weights drift or the next megabatch alltoall mismatches and
deadlocks. Single-byte allreduce is microseconds; total overhead is
dominated by the device->host `.item()` sync.
Usage
Test plan
- `test_fallback_off_lets_nan_propagate_to_params`: baselines that the bug actually exists in unit-test form so the positive test isn't trivially green.
- `test_fallback_on_skips_step_when_ns_returns_nan`: bit-exact param + zero variance buffer + RuntimeWarning on rank 0.
- `test_fallback_on_does_not_block_normal_step`: guard is inert when NS output is finite.
- `test_optimizers.py::TestNorMuon` (5 tests) still pass.

All run single-rank on CUDA without NCCL.