megabatch: env-gated NaN capture wrapper around newton_schulz_func #78

Open
JohnLangford wants to merge 1 commit into main from jcl/issue-76-nan-capture

@JohnLangford
Contributor

Summary

Re: #76. The reporter and maintainers cannot consistently reproduce the
NorMuon + gram-newton-schulz + quack-kernels 0.4.1 NaN bug, so the next
debugging step is to capture the actual NS input that triggers the failure
on the affected hardware (6x Blackwell RTX PRO 6000 + DDP).

This PR wraps newton_schulz_func inside megabatch_orthogonalize_async
with an env-gated capture wrapper. When enabled, each rank that observes a
non-finite NS output writes its own .pt dump (input + output + rank +
shape + epsilon) and by default raises RuntimeError. The check is
local per rank — no extra collective — and it short-circuits at one
os.environ lookup when disabled.
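
The control flow described above can be sketched roughly as follows. This is a minimal, tensor-agnostic illustration, not the PR's actual code: `with_nan_capture`, the `(x, epsilon)` call signature, the `"1"` enable value, and the fallback directory are all assumptions, and the filename omits the shape component for brevity. With PyTorch, `clone` would typically be `lambda t: t.detach().clone()`, `is_finite` would be `lambda t: bool(torch.isfinite(t).all())`, and `save` would be `torch.save`; the stand-ins below use plain lists and `pickle` so the sketch runs without a GPU.

```python
import math
import os
import pickle
import time
from typing import Any, Callable


def with_nan_capture(
    ns_func: Callable[..., Any],
    *,
    clone: Callable[[Any], Any],
    is_finite: Callable[[Any], bool],
    save: Callable[[dict, str], None],
    rank: int = 0,
) -> Callable[..., Any]:
    """Hypothetical wrapper: dump the input and output when the output is non-finite."""

    def wrapped(x: Any, epsilon: float) -> Any:
        # Disabled path: one environment lookup, then the plain call.
        if os.environ.get("DION_NAN_CAPTURE") != "1":
            return ns_func(x, epsilon)

        # Enabled path: snapshot the input before a kernel can mutate it in place.
        x_in = clone(x)
        out = ns_func(x, epsilon)
        if is_finite(out):
            return out

        # Local, per-rank dump; no collective communication is involved.
        out_dir = os.environ.get("DION_NAN_CAPTURE_DIR", "./dion_nan_captures")
        os.makedirs(out_dir, exist_ok=True)
        ts_ms = int(time.time() * 1000)
        path = os.path.join(
            out_dir, f"dion_nan_capture_rank{rank}_pid{os.getpid()}_{ts_ms}.pt"
        )
        save({"input": x_in, "output": out, "rank": rank, "epsilon": epsilon}, path)

        # Raise unless the user explicitly opted out with DION_NAN_CAPTURE_RAISE=0.
        if os.environ.get("DION_NAN_CAPTURE_RAISE", "1") != "0":
            raise RuntimeError(f"non-finite Newton-Schulz output; dumped to {path}")
        return out

    return wrapped


def list_is_finite(values: list) -> bool:
    """Stand-in for a tensor finiteness check, operating on a list of floats."""
    return all(math.isfinite(v) for v in values)


def pickle_save(obj: dict, path: str) -> None:
    """Stand-in for torch.save, using pickle."""
    with open(path, "wb") as f:
        pickle.dump(obj, f)
```

Injecting `clone`, `is_finite`, and `save` keeps the sketch testable on CPU; a real implementation would more likely call the PyTorch functions directly.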

Usage

# DION_NAN_CAPTURE_RAISE=1 is the default; set it to 0 to keep training after a dump.
DION_NAN_CAPTURE=1 \
DION_NAN_CAPTURE_DIR=/some/scratch/dir \
DION_NAN_CAPTURE_RAISE=1 \
torchrun ... train.py

Per-rank filenames are dion_nan_capture_rank{rank}_shape{...}_pid{...}_{ts_ms}.pt
so concurrent writes from multiple ranks on a shared filesystem do not collide.
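
As an illustration of why these names do not collide, here is a small sketch of a filename builder. The helper name `capture_filename` is hypothetical, and the PR does not show how the shape is rendered inside `shape{...}`; joining the dimensions with `x` is only a placeholder choice.

```python
import os
import time
from typing import Sequence


def capture_filename(rank: int, shape: Sequence[int]) -> str:
    """Build a per-rank dump name from rank, shape, pid, and a millisecond timestamp."""
    # Assumed shape encoding: dimensions joined with "x", e.g. (4, 8) -> "4x8".
    shape_str = "x".join(str(dim) for dim in shape)
    ts_ms = int(time.time() * 1000)
    return f"dion_nan_capture_rank{rank}_shape{shape_str}_pid{os.getpid()}_{ts_ms}.pt"
```

Two ranks that dump the same shape in the same millisecond still get different names, because the rank differs (and on a single node the pid does too).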

Notes

  • Disabled by default. Cost when off: one os.environ lookup per call.
  • The input is .detach().clone()-d up front so we still have the
    offending tensor even if a kernel mutates X in-place.
  • Two new module-level imports (os, time).

Test plan

  • test_capture_disabled_by_default: no env, no dumps even when NS returns NaN
  • test_capture_dumps_on_non_finite_output: dump contains correct input/output/shape/rank/epsilon
  • test_capture_raises_by_default_when_enabled: RuntimeError after dump
  • test_capture_no_dump_when_finite: finite output -> no dump even with env on
  • test_capture_filename_includes_rank: non-zero rank reflected in filename

All five pass single-rank without GPU/NCCL.

Issue #76 reports intermittent NaN params with NorMuon + gram-newton-schulz
+ quack-kernels 0.4.1 on 6x Blackwell RTX PRO 6000 + DDP. The reporter and
maintainers cannot consistently reproduce, so the next step is to capture
the actual Newton-Schulz input that triggers the failure for offline replay.

Add an env-gated wrapper around newton_schulz_func inside
megabatch_orthogonalize_async:

  DION_NAN_CAPTURE=1                       # enable
  DION_NAN_CAPTURE_DIR=./dion_nan_captures # output directory
  DION_NAN_CAPTURE_RAISE=1                 # raise after dump (default)

When enabled, each rank that observes a non-finite NS output writes a
rank+shape+pid+timestamp-keyed .pt file containing the (cloned-up-front)
input, the offending output, the rank, and epsilon, then by default raises
RuntimeError. Multiple ranks dumping in parallel do not collide on a shared
filesystem because the filename includes rank + pid.

The check is local per rank: there is no extra collective. When the env
var is unset the wrapper short-circuits at one os.environ lookup per
megabatch ortho call (effectively free). Cloning the input only happens
on the enabled path.

Tests: CPU-only single-rank tests for the five behaviors -
disabled-by-default no-op, dump-on-NaN, raise-by-default,
no-dump-when-finite, and rank-keyed filenames. All five pass without
GPU/NCCL.