Skip to content

feat: automatic CPU/GPU device routing#25

Merged
dek3rr merged 9 commits into
mainfrom
feat/device-routing
Jun 12, 2026
Merged

feat: automatic CPU/GPU device routing#25
dek3rr merged 9 commits into
mainfrom
feat/device-routing

Conversation

@dek3rr

@dek3rr dek3rr commented Jun 12, 2026

Copy link
Copy Markdown
Owner

What

Automatic CPU/GPU device routing: every public entry point (nrpt, nrpt_adaptive, discover_chain_count, ising_sample, sample_states, sample_with_observation, estimate_moments, estimate_kl_grad) now takes device="auto". With no accelerator visible, placement is untouched (byte-for-byte legacy behavior — CPU-only CI sees no change). With a GPU present, the work score n_chains × free_nodes routes small workloads to the CPU and large ones to the GPU.

  • hamon/device.py owns all policy: resolve_device, tree_device_put (identity fast path keeps jit caches warm), tracer guard (routing is a no-op inside jit/vmap/grad), HAMON_DEVICE / HAMON_DEVICE_THRESHOLD env overrides, device="cpu"|"gpu"|jax.Device|None escape hatches.
  • Orchestrators resolve the device once and pass the concrete device down, so it never flips between tuning phases; the jit-once round loop is preserved (new trace-count regression test for the routed path).
  • Test suite pins the default device to CPU (tests/conftest.py); a new gpu marker runs a smoke subset on real hardware (auto-skipped when absent); HAMON_TEST_DEVICE=gpu unpins the suite.
  • benchmarks/device_crossover.py measures a machine's crossover and recommends a threshold.

Why

With CUDA jax installed, JAX defaults everything to the GPU — including tiny dispatch-bound programs where the CPU is several times faster. Measured on an RTX 5080 (WSL, jax 0.10.1): the test suite ran 837s on GPU vs 172s CPU-forced, while large sampling is 2–11× faster on GPU. Goal: having GPU jax installed should never make a workload slower than CPU-only jax.

Measured results (RTX 5080)

Workload Before After
Full test suite on the GPU machine 837s 159s
Crossover sweep score ≤ 2048 → CPU wins; score ≥ 4096 → GPU wins 2–11×
Auto-routing sanity score 256 → cpu:0, score 32768 → cuda:0

The default threshold (4096) is the measured steady-state crossover. Known simplification: rounds are excluded from the score, so very short one-shot flows remain compile-dominated and can favor CPU — documented escape hatches (device="cpu", JAX_COMPILATION_CACHE_DIR).

Reviewer notes

  • Routing re-commits entry arrays to the resolved device; outputs come back committed there. device=None is the contract for "hamon never touches placement".
  • tree_device_put maps only jax.Array leaves; leaf-level statics (spec/blocks/nodes) keep object identity and reconstructed container Modules compare equal, so equinox caches hit — proven by test_adaptive_phases_compile_once_with_explicit_cpu.
  • Arrays closed over by energy_delta_fn cannot be moved by routing (documented in the nrpt docstring).
  • Verified green: Windows CPU-only (296 passed, 3 gpu-skipped), WSL+5080 full suite, WSL -m gpu smoke subset, WSL with JAX_PLATFORMS=cpu.
  • The first two commits (float32-under-x64 fix, test compile cache) target main locally but weren't pushed yet; once main is pushed this PR reduces to the two routing commits.

🤖 Generated with Claude Code

dek3rr and others added 9 commits June 12, 2026 01:52
With CUDA jax installed, JAX places everything on the GPU, including
the tiny dispatch-bound programs where a CPU finishes several times
faster. Public entry points now take device=auto: no accelerator
visible means placement is untouched; otherwise the work score
n_chains x free_nodes routes small workloads to the CPU and large
ones to the accelerator. Explicit cpu/gpu, concrete jax.Device,
and None (full opt-out) are supported; HAMON_DEVICE_THRESHOLD and
HAMON_DEVICE override the heuristic.

Orchestrators resolve the device once and pass it down, so it never
flips between tuning phases and the jit-once round loop is preserved
(trace-count regression tests cover the routed path). The test suite
pins the default device to CPU, with a gpu marker smoke subset for
real-hardware coverage and HAMON_TEST_DEVICE=gpu to unpin.

benchmarks/device_crossover.py measures a machine''s crossover and
recommends a threshold.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Sweep on the target hardware (benchmarks/device_crossover.py,
jax 0.10.1): every point at score <= 2048 ran faster on CPU, every
point at score >= 4096 ran 2-11x faster on GPU. Short one-shot flows
remain compile-dominated; documented the device=cpu and
JAX_COMPILATION_CACHE_DIR escape hatches.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Integrates device routing with the _ChainSource dispatch from #23 and
the usability features from #24: the device is resolved once in
nrpt_adaptive via _ChainSource.metadata_free_nodes, the beta=1 template
pair is committed to the device once outside the phase loop
(device_put_template), and discover_chain_count forwards its concrete
device through the unified probe call. nrpt applies routing after the
new beta-ladder validation and supports stacked init states inside the
device context.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
jax.Device is a runtime alias into the pybind11 extension that type
checkers cannot use in type expressions; annotations now go through a
TYPE_CHECKING JaxDevice alias (Any when checking, jax.Device at
runtime). jax.core needs an explicit submodule import for the Tracer
reference to resolve. tree_device_put is typed generically (T -> T),
which also restores Optional narrowing for betas after nrpt''s
device_put unpack. Moved test_device''s importlib workaround below the
imports for E402.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
Formatting-only; brings the 33 files that predated the formatter into
compliance so ruff format --check passes repo-wide.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
tests/test_nrpt.py and tests/test_diagnostics.py arrived from upstream
with cp1252-mangled UTF-8 (Lambda as Λ, arrows as â†';
test_diagnostics.py was double-encoded). Restored via sloppy-cp1252
round trips; text content only, no code changes.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
The previous format pass used a locally installed ruff 0.15.2, whose
output differs from the 0.15.16 pinned in .pre-commit-config.yaml,
so the hook kept reformatting on the PR. This applies the pinned
version''s formatting, including notebook code cells, which 0.15.16
formats natively.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
ruff format on notebooks is not idempotent in a single pass; the
pre-commit hook converges on the second. Cell sources only.

Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
@dek3rr dek3rr merged commit 62d5c1b into main Jun 12, 2026
9 checks passed
@dek3rr dek3rr deleted the feat/device-routing branch June 12, 2026 15:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant