feat: automatic CPU/GPU device routing#25
Merged
Merged
Conversation
With CUDA jax installed, JAX places everything on the GPU, including the tiny dispatch-bound programs where a CPU finishes several times faster. Public entry points now take device=auto: no accelerator visible means placement is untouched; otherwise the work score n_chains x free_nodes routes small workloads to the CPU and large ones to the accelerator. Explicit cpu/gpu, concrete jax.Device, and None (full opt-out) are supported; HAMON_DEVICE_THRESHOLD and HAMON_DEVICE override the heuristic. Orchestrators resolve the device once and pass it down, so it never flips between tuning phases and the jit-once round loop is preserved (trace-count regression tests cover the routed path). The test suite pins the default device to CPU, with a gpu marker smoke subset for real-hardware coverage and HAMON_TEST_DEVICE=gpu to unpin. benchmarks/device_crossover.py measures a machine''s crossover and recommends a threshold. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sweep on the target hardware (benchmarks/device_crossover.py, jax 0.10.1): every point at score <= 2048 ran faster on CPU, every point at score >= 4096 ran 2-11x faster on GPU. Short one-shot flows remain compile-dominated; documented the device=cpu and JAX_COMPILATION_CACHE_DIR escape hatches. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Integrates device routing with the _ChainSource dispatch from #23 and the usability features from #24: the device is resolved once in nrpt_adaptive via _ChainSource.metadata_free_nodes, the beta=1 template pair is committed to the device once outside the phase loop (device_put_template), and discover_chain_count forwards its concrete device through the unified probe call. nrpt applies routing after the new beta-ladder validation and supports stacked init states inside the device context. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
jax.Device is a runtime alias into the pybind11 extension that type checkers cannot use in type expressions; annotations now go through a TYPE_CHECKING JaxDevice alias (Any when checking, jax.Device at runtime). jax.core needs an explicit submodule import for the Tracer reference to resolve. tree_device_put is typed generically (T -> T), which also restores Optional narrowing for betas after nrpt''s device_put unpack. Moved test_device''s importlib workaround below the imports for E402. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Formatting-only; brings the 33 files that predated the formatter into compliance so ruff format --check passes repo-wide. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
tests/test_nrpt.py and tests/test_diagnostics.py arrived from upstream with cp1252-mangled UTF-8 (Lambda as Λ, arrows as â†'; test_diagnostics.py was double-encoded). Restored via sloppy-cp1252 round trips; text content only, no code changes. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The previous format pass used a locally installed ruff 0.15.2, whose output differs from the 0.15.16 pinned in .pre-commit-config.yaml, so the hook kept reformatting on the PR. This applies the pinned version''s formatting, including notebook code cells, which 0.15.16 formats natively. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
ruff format on notebooks is not idempotent in a single pass; the pre-commit hook converges on the second. Cell sources only. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
What
Automatic CPU/GPU device routing: every public entry point (
nrpt,nrpt_adaptive,discover_chain_count,ising_sample,sample_states,sample_with_observation,estimate_moments,estimate_kl_grad) now takesdevice="auto". With no accelerator visible, placement is untouched (byte-for-byte legacy behavior — CPU-only CI sees no change). With a GPU present, the work scoren_chains × free_nodesroutes small workloads to the CPU and large ones to the GPU.hamon/device.pyowns all policy:resolve_device,tree_device_put(identity fast path keeps jit caches warm), tracer guard (routing is a no-op inside jit/vmap/grad),HAMON_DEVICE/HAMON_DEVICE_THRESHOLDenv overrides,device="cpu"|"gpu"|jax.Device|Noneescape hatches.tests/conftest.py); a newgpumarker runs a smoke subset on real hardware (auto-skipped when absent);HAMON_TEST_DEVICE=gpuunpins the suite.benchmarks/device_crossover.pymeasures a machine's crossover and recommends a threshold.Why
With CUDA jax installed, JAX defaults everything to the GPU — including tiny dispatch-bound programs where the CPU is several times faster. Measured on an RTX 5080 (WSL, jax 0.10.1): the test suite ran 837s on GPU vs 172s CPU-forced, while large sampling is 2–11× faster on GPU. Goal: having GPU jax installed should never make a workload slower than CPU-only jax.
Measured results (RTX 5080)
cpu:0, score 32768 →cuda:0The default threshold (4096) is the measured steady-state crossover. Known simplification: rounds are excluded from the score, so very short one-shot flows remain compile-dominated and can favor CPU — documented escape hatches (
device="cpu",JAX_COMPILATION_CACHE_DIR).Reviewer notes
device=Noneis the contract for "hamon never touches placement".tree_device_putmaps onlyjax.Arrayleaves; leaf-level statics (spec/blocks/nodes) keep object identity and reconstructed container Modules compare equal, so equinox caches hit — proven bytest_adaptive_phases_compile_once_with_explicit_cpu.energy_delta_fncannot be moved by routing (documented in thenrptdocstring).-m gpusmoke subset, WSL withJAX_PLATFORMS=cpu.mainlocally but weren't pushed yet; oncemainis pushed this PR reduces to the two routing commits.🤖 Generated with Claude Code