Skip to content

fix(base): pass sampler choice through to bambi verbatim#963

Open
AlexanderFengler wants to merge 1 commit into
mainfrom
fix/sampler-routing
Open

fix(base): pass sampler choice through to bambi verbatim#963
AlexanderFengler wants to merge 1 commit into
mainfrom
fix/sampler-routing

Conversation

@AlexanderFengler
Copy link
Copy Markdown
Member

Summary

Single-line fix to HSSMBase.sample: pass the user's sampler argument
through to bambi as inference_method=sampler instead of collapsing all
NUTS variants to inference_method="pymc".

The bug

HSSMBase.sample(sampler="numpyro") was silently routing to PyMC NUTS,
not numpyro NUTS. Same for sampler="blackjax" and sampler="nutpie".
Only sampler="pymc" worked as advertised.

Archaeology

  • Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi Fix compatibility with Bambi #516"):
    introduced a working pattern — inference_method="mcmc" (generic
    bambi NUTS marker) plus kwargs["nuts_sampler"]="numpyro"/"blackjax"
    injected separately. Bambi's old inference_method="mcmc" was a
    generic switch that read nuts_sampler from kwargs. Correct under
    the old bambi API.
  • Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to
    be consistent with bambi's"): bambi renamed its inference_method
    values (mcmc → pymc, nuts_numpyro → numpyro, etc.). This commit
    mechanically updated the string list, but also deleted the
    kwargs["nuts_sampler"] injection block
    . The flatten-conditional
    was left in place. After this commit, all four NUTS samplers route
    to inference_method="pymc"nuts_sampler="pymc", with no
    recourse.

The bug has been live since Dec 17, 2025 (~5 months).

The fix

self._inference_obj = self.model.fit(
    inference_method=sampler,   # bambi accepts pymc/numpyro/blackjax/nutpie directly
    ...
)

Bambi natively understands each NUTS variant via
_SUPPORTED_METHODS = {"pymc", "numpyro", "blackjax", "nutpie", "vi", "laplace"}
and routes each to a distinct pm.sample(nuts_sampler=...) call. The
collapse conditional has no purpose under the new API.

Other gates in HSSMBase.sample that read the user's sampler arg
(parallel-sampling warning, init default, jitter handling,
step-sampler check) all check sampler directly, not the post-
normalization inference_method value. None are affected.

Test plan

  • Existing HSSM ONNX test suite passes
    (pytest tests/distribution_utils/test_onnx.py tests/distribution_utils/test_onnx_model.py).
  • Tests using sampler="numpyro" should be re-run as part of
    review — they were silently exercising PyMC NUTS pre-fix and will
    exercise the actual numpyro path post-fix. Specifically:
    • tests/test_rlssm.py:300
    • tests/test_save_load.py:39
    • tests/slow/test_mcmc.py:100-101
  • Verify that sampler="blackjax" and sampler="nutpie" now
    invoke their respective NUTS backends in the user's runtime if
    those backends are installed.

Surfaced by

The sbi NLE tutorial work on the sbi-integration branch — exported
ONNX graphs from sbi-trained classifiers tripped a cloudpickle bug
in PyMC's multiprocess sampler because sampler="numpyro" wasn't
actually invoking numpyro (numpyro doesn't fork processes). Once
fixed, the sbi tutorial runs pm.sample(nuts_sampler="numpyro") and
sidesteps the cloudpickle path entirely.

🤖 Generated with Claude Code

HSSMBase.sample collapsed sampler="numpyro" (and blackjax, nutpie) to
inference_method="pymc" before handing off to bambi, which then dispatched
to pm.sample(nuts_sampler="pymc"). The user's sampler choice was
silently downgraded to PyMC NUTS regardless of what they asked for.

Bambi natively accepts inference_method values "pymc", "numpyro",
"blackjax", "nutpie" (and "vi"/"laplace") and routes each to the
matching nuts_sampler. The collapse conditional negated this.

Regression archaeology:
  - Aug 5, 2024 (commit aef3f9b, "Fix compatibility with Bambi (#516)"):
    introduced the working pattern -- inference_method="mcmc" (generic
    NUTS marker) + kwargs["nuts_sampler"]="numpyro"/"blackjax"/etc.
    injected separately. Bambi's old "mcmc" inference_method was generic
    and read nuts_sampler from kwargs. Correct under old bambi semantics.
  - Dec 17, 2025 (commit 20c100b, "fix: update model.sample api to be
    consistent with bambi's"): bambi had renamed its inference_method
    values (mcmc -> pymc, nuts_numpyro -> numpyro, etc.). This commit
    mechanically updated the string list, but ALSO deleted the
    kwargs["nuts_sampler"] = ... injection block. The flatten-conditional
    was left in place. After this commit, all four NUTS samplers route
    to inference_method="pymc" -> nuts_sampler="pymc" with no recourse.
  - The bug has been live since Dec 17, 2025 (about 5 months).

Fix: replace the conditional with inference_method=sampler. Bambi handles
each NUTS variant directly under the new API. The injection block deleted
in commit 20c100b is correctly absent now -- bambi passes
nuts_sampler=sampler_backend to pm.sample explicitly, so injecting it via
kwargs would conflict.

Side effects:
  - sampler="numpyro" now actually invokes numpyro NUTS, which runs
    inside JAX with internal parallelism and does NOT fork worker
    processes. This avoids the cloudpickle path that breaks on
    unpicklable ONNX ModelProto closures (surfaced by the sbi NLE
    tutorial).
  - sampler="blackjax" and sampler="nutpie" similarly now invoke their
    respective backends instead of PyMC NUTS.
  - sampler="pymc" behavior is unchanged (still routes to PyMC NUTS).

Other gates that read the user's `sampler` argument (parallel-sampling
warning at base.py:621, init default at base.py:636, jitter handling at
base.py:644, step-sampler check at base.py:657) all check `sampler`
directly, not the post-normalization inference_method value. None are
affected by this change.

Tests pinning the old behavior:
  - tests/test_rlssm.py:300, tests/test_save_load.py:39,
    tests/slow/test_mcmc.py:100-101 use sampler="numpyro" and were
    silently exercising PyMC NUTS. After this fix they exercise the
    actual numpyro path. Worth re-running as part of PR review.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copy link
Copy Markdown
Collaborator

@digicosmos86 digicosmos86 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants