JAX-native thermal sampling for discrete energy-based models.
Hamon is a JAX library for sampling from discrete probabilistic graphical models. It provides GPU-accelerated block Gibbs sampling, non-reversible parallel tempering with adaptive schedule optimization, and tools for building and training Ising models, RBMs, and other discrete energy-based models.
Built on Extropic AI's thrml foundation, Hamon diverges as an independent library with original algorithmic contributions and performance optimizations.
In Japanese swordsmithing, the hamon (刃文, "blade pattern") is the visible wave that appears along the edge of a katana after differential hardening. The smith coats the blade in clay — thin along the cutting edge, thick along the spine — then heats the steel to critical temperature and quenches it in water. The edge cools fast into hard martensite; the spine cools slowly into tough pearlite. The boundary between these two phases is the hamon: a pattern born entirely from a thermal process, where controlled temperature gradients reveal structure hidden in disordered steel.
The parallel to this library is direct. Hamon explores discrete energy landscapes by running chains at different temperatures and exchanging information across the thermal gradient. Structure emerges at the boundary between mixing regimes — hot chains explore freely, cold chains resolve fine detail, and the communication between them is what makes sampling work. The hamon on a blade is proof that a thermal process found the right boundary. The diagnostics in this library measure the same thing.
pip install hamonFor development:
git clone https://github.com/dek3rr/hamon.git
cd hamon
pip install -e ".[development,testing,examples]"Requires Python ≥ 3.11 and a JAX installation (GPU setup guide).
With CUDA jax installed, JAX places everything on the GPU — including the
small, dispatch-bound programs where a CPU finishes several times faster.
hamon's entry points (nrpt, nrpt_adaptive, discover_chain_count,
ising_sample, sample_states, sample_with_observation, …) therefore take
a device argument:
"auto"(default) — with no accelerator visible, placement is untouched. Otherwise the work scoren_chains × free nodesdecides: small workloads run on the CPU, large ones on the accelerator. The default threshold (4096, the steady-state crossover measured on an RTX 5080) can be overridden withHAMON_DEVICE_THRESHOLD(calibrate yours withpython benchmarks/device_crossover.py);HAMON_DEVICE=cpu|gpu|noneforces a choice without code changes. Very short one-shot flows are compile-dominated and can favor the CPU regardless of size — passdevice="cpu"for those, or setJAX_COMPILATION_CACHE_DIRso repeated runs skip GPU compilation entirely."cpu"/"gpu"— that platform, raising if it is not visible.- a concrete
jax.Device— used as-is. None— hamon never touches placement.
Routing re-commits the entry arrays (program tensors, states, β ladder) to
the chosen device and returns outputs committed there; pass device=None to
keep full manual control of placement. Orchestrators resolve the device once
and reuse it across all tuning phases, so jit caches stay warm.
import jax
import jax.numpy as jnp
from hamon import SpinNode, Block, SamplingSchedule, sample_states
from hamon.models import IsingEBM, IsingSamplingProgram, hinton_init
nodes = [SpinNode() for _ in range(5)]
edges = [(nodes[i], nodes[i + 1]) for i in range(4)]
model = IsingEBM(nodes, edges, jnp.zeros(5), jnp.ones(4) * 0.5, jnp.array(1.0))
free_blocks = [Block(nodes[::2]), Block(nodes[1::2])]
program = IsingSamplingProgram(model, free_blocks, clamped_blocks=[])
key = jax.random.key(0)
k_init, k_samp = jax.random.split(key, 2)
init_state = hinton_init(k_init, model, free_blocks, ())
schedule = SamplingSchedule(n_warmup=100, n_samples=1000, steps_per_sample=2)
samples = sample_states(k_samp, program, schedule, init_state, [], [Block(nodes)])Hamon implements adaptive NRPT based on Syed et al. (2021), with vectorized swaps that exploit the temperature-linearity of Ising energies:
from hamon.nrpt import nrpt_adaptive
ebm = IsingEBM(nodes, edges, biases, weights, jnp.array(1.0))
program = IsingSamplingProgram(ebm, free_blocks, [])
states, _, stats = nrpt_adaptive(
jax.random.key(42),
init_states=[init_state] * 8,
clamp_state=[],
n_rounds=500,
gibbs_steps_per_round=5,
initial_betas=jnp.linspace(0.1, 2.0, 8),
n_tune=5,
rounds_per_tune=200,
ebm=ebm,
program=program,
)
print(f"Final Λ: {stats['round_trip_diagnostics']['Lambda']:.3f}")
print(f"Round trip rate: {stats['round_trip_diagnostics']['tau_observed']:.4f}")Key features of the NRPT implementation:
- Vectorized swaps: 1 energy evaluation per chain (not 4 per pair), all non-overlapping swaps execute simultaneously via permutation indexing
- Temperature-linear mode: one β = 1 base program serves every chain; interactions are scaled by each chain's β inside the kernel, so no per-chain program construction and n_chains× less interaction memory
- Compile-once adaptive tuning: the round loop is jitted at module level and the β schedule is traced data, so all tuning phases reuse one compiled executable
- Adaptive scheduling: iteratively tunes β spacing to equalize rejection rates, minimizing the global communication barrier Λ
- Round trip tracking: monitors the index process per machine, estimates Λ and predicted optimal rate τ̄ = 1/(2+2Λ)
- Chain count discovery: iteratively probes to find the right number of chains for a target acceptance rate
All chains run in one kernel. Parallel tempering uses jax.vmap over chains
instead of a Python loop. Compile time is constant regardless of chain count.
No redundant work in the sampler loop. Global state is threaded through
lax.scan as a carry. Block updates write back via contiguous slice updates
with static offsets (scatters only as a fallback for non-contiguous layouts)
instead of rebuilding the full state tensor each iteration.
Energy evaluation skips unnecessary work. Pre-built BlockSpec objects are
passed through directly — no reconstruction on every energy() call. Padded
interaction entries are pre-zeroed at program construction, so samplers skip
the per-step active-mask multiply.
Accumulator dtypes are explicit. The moment accumulator pins its dtype at construction, and conditional samplers accumulate in the weights' dtype, avoiding silent promotion on GPU and seeding float64 sums with float32 zeros.
If you use Hamon in your research, please cite:
@software{kerr2026hamon,
author = {Kerr, Douglas E. Jr.},
title = {Hamon: JAX-Native Thermal Sampling for Discrete Energy-Based Models},
year = {2026},
url = {https://github.com/dek3rr/hamon},
version = {0.4.0},
license = {Apache-2.0},
}Hamon's block sampling and PGM infrastructure is derived from thrml (v0.1.3) by Extropic AI, licensed under Apache 2.0. See NOTICE for full attribution. If you use the underlying block Gibbs framework, please also cite:
@misc{jelincic2025efficient,
title = {An efficient probabilistic hardware architecture for diffusion-like models},
author = {Andraž Jelinčič and Owen Lockwood and Akhil Garlapati and Guillaume Verdon and Trevor McCourt},
year = {2025},
eprint = {2510.23972},
archivePrefix= {arXiv},
primaryClass = {cs.LG},
}The non-reversible parallel tempering implementation is based on:
Syed, S., Bouchard-Côté, A., Deligiannidis, G., & Doucet, A. (2021). Non-Reversible Parallel Tempering: a Scalable Highly Parallel MCMC Scheme. arXiv:1905.02939
Apache 2.0. See LICENSE.