fix(base): pass sampler choice through to bambi verbatim#963
Open
AlexanderFengler wants to merge 1 commit into
Open
fix(base): pass sampler choice through to bambi verbatim#963AlexanderFengler wants to merge 1 commit into
AlexanderFengler wants to merge 1 commit into
Conversation
HSSMBase.sample collapsed sampler="numpyro" (and blackjax, nutpie) to inference_method="pymc" before handing off to bambi, which then dispatched to pm.sample(nuts_sampler="pymc"). The user's sampler choice was silently downgraded to PyMC NUTS regardless of what they asked for. Bambi natively accepts inference_method values "pymc", "numpyro", "blackjax", "nutpie" (and "vi"/"laplace") and routes each to the matching nuts_sampler. The collapse conditional negated this. Regression archaeology: - Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi (#516)"): introduced the working pattern -- inference_method="mcmc" (generic NUTS marker) + kwargs["nuts_sampler"]="numpyro"/"blackjax"/etc. injected separately. Bambi's old "mcmc" inference_method was generic and read nuts_sampler from kwargs. Correct under old bambi semantics. - Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to be consistent with bambi's"): bambi had renamed its inference_method values (mcmc -> pymc, nuts_numpyro -> numpyro, etc.). This commit mechanically updated the string list, but ALSO deleted the kwargs["nuts_sampler"] = ... injection block. The flatten-conditional was left in place. After this commit, all four NUTS samplers route to inference_method="pymc" -> nuts_sampler="pymc" with no recourse. - The bug has been live since Dec 17, 2025 (about 5 months). Fix: replace the conditional with inference_method=sampler. Bambi handles each NUTS variant directly under the new API. The injection block deleted in commit 20c100b is correctly absent now -- bambi passes nuts_sampler=sampler_backend to pm.sample explicitly, so injecting it via kwargs would conflict. Side effects: - sampler="numpyro" now actually invokes numpyro NUTS, which runs inside JAX with internal parallelism and does NOT fork worker processes. This avoids the cloudpickle path that breaks on unpicklable ONNX ModelProto closures (surfaced by the sbi NLE tutorial). - sampler="blackjax" and sampler="nutpie" similarly now invoke their respective backends instead of PyMC NUTS. - sampler="pymc" behavior is unchanged (still routes to PyMC NUTS). Other gates that read the user's `sampler` argument (parallel-sampling warning at base.py:621, init default at base.py:636, jitter handling at base.py:644, step-sampler check at base.py:657) all check `sampler` directly, not the post-normalization inference_method value. None are affected by this change. Tests pinning the old behavior: - tests/test_rlssm.py:300, tests/test_save_load.py:39, tests/slow/test_mcmc.py:100-101 use sampler="numpyro" and were silently exercising PyMC NUTS. After this fix they exercise the actual numpyro path. Worth re-running as part of PR review. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
3 tasks
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Summary
Single-line fix to
HSSMBase.sample: pass the user'ssamplerargumentthrough to bambi as
inference_method=samplerinstead of collapsing allNUTS variants to
inference_method="pymc".The bug
HSSMBase.sample(sampler="numpyro")was silently routing to PyMC NUTS,not numpyro NUTS. Same for
sampler="blackjax"andsampler="nutpie".Only
sampler="pymc"worked as advertised.Archaeology
aef3f9b, "Fix compatibility with Bambi Fix compatibility with Bambi #516"):introduced a working pattern —
inference_method="mcmc"(genericbambi NUTS marker) plus
kwargs["nuts_sampler"]="numpyro"/"blackjax"injected separately. Bambi's old
inference_method="mcmc"was ageneric switch that read
nuts_samplerfrom kwargs. Correct underthe old bambi API.
20c100b, "fix: update model.sample api tobe consistent with bambi's"): bambi renamed its
inference_methodvalues (mcmc → pymc, nuts_numpyro → numpyro, etc.). This commit
mechanically updated the string list, but also deleted the
kwargs["nuts_sampler"]injection block. The flatten-conditionalwas left in place. After this commit, all four NUTS samplers route
to
inference_method="pymc"→nuts_sampler="pymc", with norecourse.
The bug has been live since Dec 17, 2025 (~5 months).
The fix
Bambi natively understands each NUTS variant via
_SUPPORTED_METHODS = {"pymc", "numpyro", "blackjax", "nutpie", "vi", "laplace"}and routes each to a distinct
pm.sample(nuts_sampler=...)call. Thecollapse conditional has no purpose under the new API.
Other gates in
HSSMBase.samplethat read the user'ssamplerarg(parallel-sampling warning, init default, jitter handling,
step-sampler check) all check
samplerdirectly, not the post-normalization
inference_methodvalue. None are affected.Test plan
(
pytest tests/distribution_utils/test_onnx.py tests/distribution_utils/test_onnx_model.py).sampler="numpyro"should be re-run as part ofreview — they were silently exercising PyMC NUTS pre-fix and will
exercise the actual numpyro path post-fix. Specifically:
tests/test_rlssm.py:300tests/test_save_load.py:39tests/slow/test_mcmc.py:100-101sampler="blackjax"andsampler="nutpie"nowinvoke their respective NUTS backends in the user's runtime if
those backends are installed.
Surfaced by
The sbi NLE tutorial work on the
sbi-integrationbranch — exportedONNX graphs from sbi-trained classifiers tripped a cloudpickle bug
in PyMC's multiprocess sampler because
sampler="numpyro"wasn'tactually invoking numpyro (numpyro doesn't fork processes). Once
fixed, the sbi tutorial runs
pm.sample(nuts_sampler="numpyro")andsidesteps the cloudpickle path entirely.
🤖 Generated with Claude Code