Skip to content

feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx)#80

Open
AlexanderFengler wants to merge 1 commit into
sbi-connectorfrom
bayesflow-connector
Open

feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx)#80
AlexanderFengler wants to merge 1 commit into
sbi-connectorfrom
bayesflow-connector

Conversation

@AlexanderFengler
Copy link
Copy Markdown
Member

Summary

Adds lanfactory.onnx.transform_bayesflow_to_onnx — the bayesflow sibling of transform_sbi_to_onnx from #79. Wraps a trained bayesflow.ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable" path.

The user gesture becomes the same regardless of training framework: hssm.HSSM(loglik="model.onnx", loglik_kind="approx_differentiable") works for LAN, sbi, and now bayesflow exports through the identical HSSM-side code path (no HSSM changes required).

Branch relationship

This branch is stacked on sbi-connector (#79). The diff in this PR is exactly the bayesflow-specific additions; #79's commits are not duplicated. When #79 merges to main, GitHub will offer to auto-retarget this PR's base to main.

What's in this PR

File Action Purpose
src/lanfactory/onnx/bayesflow.py new Exporter module (NLE + NRE wrappers + transform_bayesflow_to_onnx)
src/lanfactory/onnx/__init__.py edit Export the new function
pyproject.toml edit New [bayesflow] optional extra; refactor existing sbi+nflows pair into a new [sbi] extra (symmetric with [bayesflow]); both added to [all] and [dev]
tests/test_bayesflow_nle_export.py new 6 tests: three-way numerical agreement (atol=1e-5), gradient agreement (atol=1e-4), log-prob ordering sanity, three guard tests
tests/test_bayesflow_nre_export.py new 4 tests: same shape for the NRE path
tests/test_bayesflow_hssm_integration.py new End-to-end DDM smoke (pytest.importorskip("hssm"), mirrors test_sbi_hssm_integration.py)
docs/exporting_bayesflow_models.md new Sibling of exporting_sbi_models.md

Architectural contract

Same I/O contract as the sbi exporter:

  • Input: rank-1 tensor of shape (theta_dim + x_dim,), parameters first then observations.
  • Output: rank-0 scalar log-likelihood.
  • Opset: pinned to 17 for jaxonnxruntime reproducibility.

Key implementation choice: the wrapper bakes the bayesflow Standardize layer's accumulated moving_mean / moving_std as torch buffer constants at construction time. This sidesteps the dynamic-shape ops (If, Size, Tile) that the live Keras layer would emit at trace time — jaxonnxruntime doesn't have a Size handler. The constants are correct because training is complete by export time.

v1 constraints (documented, enforced where introspectable)

User must train with:

  • permutation=None (FixedPermutation → aten::ravel, unsupported in opset 17/20)
  • use_actnorm=False (untested in v1)
  • transform=AffineTransform(clamp=False) as an explicit instance (find_transform("affine") silently drops transform_kwargs — bayesflow upstream bug, catalogued in companion HSSMSpine PR)
  • subnet_kwargs.activation="silu" (default hard_silu exports as the fused ONNX op HardSwish, no jaxonnxruntime handler; silu decomposes to Sigmoid + Mul)
  • Identity Adapter (numpy-only Adapter ops can't be baked into ONNX)

Each violation produces an actionable error message at export time.

Test status

  • bayesflow NLE: 6/6 passing
  • bayesflow NRE: 4/4 passing
  • bayesflow ↔ HSSM integration: skip-on-missing-HSSM in this env (designed for the coordinated cross-repo CI matrix)
  • sbi regression check: existing sbi tests still pass (no regression). Note: the bayesflow test modules call torch.set_grad_enabled(True) after importing bayesflow to undo the global autograd disable that bayesflow's torch backend does at import time. This is a known upstream issue documented in the companion HSSMSpine PR's upstream-bugs catalog.

Companion PRs

  • HSSM: bayesflow-integration branch (new), adds docs/tutorials/bayesflow_nle_onnx_integration.ipynb. Sibling, not child, of sbi-integration (#964) — works against HSSM main standalone (Part 1 includes the manual jaxort_only_allow_initializers_as_static_args=False workaround that #964 plans to auto-handle inside onnx2jax).
  • HSSMSpine: bayesflow-onnx-plans branch (stacked on NameError when wandb is not found #9 for the cross-reference to sbi-onnx-integration.md). Adds the design doc and an upstream-bugs catalog covering the five real upstream defects surfaced during this work.

Test plan

  • bayesflow NLE tests pass (6/6)
  • bayesflow NRE tests pass (4/4)
  • No regression on existing sbi tests
  • Cross-repo HSSM integration test (pytest.importorskip("hssm")) once both packages can be installed in the same env

🤖 Generated with Claude Code

Adds lanfactory.onnx.transform_bayesflow_to_onnx, the bayesflow sibling of
transform_sbi_to_onnx (PR #79). Wraps a trained bayesflow
ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a
single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable"
path. Same I/O contract as the sbi exporter (rank-1 input
[theta..., x...], rank-0 scalar log-likelihood, opset 17) so HSSM ingests
both via the same loglik="*.onnx" gesture with zero HSSM-side changes.

What's in this commit

- src/lanfactory/onnx/bayesflow.py: exporter module mirroring sbi.py.
  Contains _BayesflowNLELogProbWrapper and _BayesflowNRELogRatioWrapper.
  Pre-evaluates the bayesflow Standardize layer's moving mean/std to
  torch buffer constants at wrapper construction time so the ONNX trace
  is fully static (avoids If, Size, Tile dynamic-shape ops that
  jaxonnxruntime can't run). Guards on KERAS_BACKEND=torch and identity
  Adapter; both raise actionable errors with concrete fix hints.

- src/lanfactory/onnx/__init__.py: export the new function.

- pyproject.toml: add [bayesflow] optional extra
  (bayesflow>=2.0.8, keras>=3.12), add to [all] and [dev]. Also
  refactors the existing sbi+nflows pair into its own [sbi] extra
  (mirroring the new [bayesflow]) while keeping them in [all].

- tests/test_bayesflow_nle_export.py: 6 tests. Three-way numerical
  agreement (torch reference wrapper <-> onnxruntime <-> jaxonnxruntime)
  at atol=1e-5, gradient agreement at atol=1e-4, log-prob ordering
  sanity, and three guard tests (wrong backend, non-identity adapter,
  wrong mode).

- tests/test_bayesflow_nre_export.py: 4 tests. Same shape for the NRE
  path on a RatioApproximator.

- tests/test_bayesflow_hssm_integration.py: end-to-end DDM smoke
  (pytest.importorskip("hssm")). Mirrors test_sbi_hssm_integration.py.

- docs/exporting_bayesflow_models.md: full constraint catalog
  (KERAS_BACKEND, CouplingFlow knobs, silu vs hard_silu activation
  choice, identity-adapter requirement, JAX x64). Quick-starts for NLE
  and NRE. "Two paths into HSSM" framing alongside the JAX-callable
  path used in bayesflow_lre_integration.ipynb.

v1 constraints (documented, enforced where introspectable)

User must train with:
- permutation=None (FixedPermutation -> aten::ravel, unsupported)
- use_actnorm=False (untested in v1)
- transform=AffineTransform(clamp=False) explicit instance
  (find_transform("affine") drops kwargs - bayesflow upstream bug)
- subnet_kwargs.activation="silu" or another smooth activation
  (default hard_silu emits HardSwish, no jaxonnxruntime handler)
- identity Adapter (numpy-only adapter ops cannot be baked into ONNX)

Bayesflow continuous observations only. MNLE-style discrete + continuous
deferred until upstream MNLE support lands.

Numerical guarantees

19 passing tests across both bayesflow and sbi tracks; no regressions on
the existing sbi exporter. Each export is verified for three-way numerical
agreement at 1e-5 and gradient agreement at 1e-4.

Companion PRs

- HSSM: docs(tutorials): add bayesflow_nle_onnx_integration.ipynb
  on a fresh bayesflow-integration branch off main (sibling, not
  child, of the sbi-integration branch in PR #964).
- HSSMSpine: bayesflow-onnx-integration.md design doc +
  upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream
  defects surfaced during this work (jaxonnxruntime missing
  HardSwish/Size handlers; bayesflow find_transform kwarg-drop bug;
  bayesflow global torch.set_grad_enabled(False) cross-library leak;
  torch.onnx missing aten::ravel/asinh symbolic registrations).

This branch is stacked on sbi-connector (PR #79). When #79 merges,
this PR's base auto-retargets to main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant