feat(onnx): bayesflow → ONNX exporter (transform_bayesflow_to_onnx) #80
Open
AlexanderFengler wants to merge 1 commit into
Adds lanfactory.onnx.transform_bayesflow_to_onnx, the bayesflow sibling of transform_sbi_to_onnx (PR #79). Wraps a trained bayesflow ContinuousApproximator (NLE) or RatioApproximator (NRE) and writes a single-trial ONNX file consumable by HSSM's loglik_kind="approx_differentiable" path. Same I/O contract as the sbi exporter (rank-1 input [theta..., x...], rank-0 scalar log-likelihood, opset 17) so HSSM ingests both via the same loglik="*.onnx" gesture with zero HSSM-side changes.

What's in this commit

- src/lanfactory/onnx/bayesflow.py: exporter module mirroring sbi.py. Contains _BayesflowNLELogProbWrapper and _BayesflowNRELogRatioWrapper. Pre-evaluates the bayesflow Standardize layer's moving mean/std to torch buffer constants at wrapper construction time so the ONNX trace is fully static (avoids If, Size, Tile dynamic-shape ops that jaxonnxruntime can't run). Guards on KERAS_BACKEND=torch and identity Adapter; both raise actionable errors with concrete fix hints.
- src/lanfactory/onnx/__init__.py: export the new function.
- pyproject.toml: add [bayesflow] optional extra (bayesflow>=2.0.8, keras>=3.12), add to [all] and [dev]. Also refactors the existing sbi+nflows pair into its own [sbi] extra (mirroring the new [bayesflow]) while keeping them in [all].
- tests/test_bayesflow_nle_export.py: 6 tests. Three-way numerical agreement (torch reference wrapper <-> onnxruntime <-> jaxonnxruntime) at atol=1e-5, gradient agreement at atol=1e-4, log-prob ordering sanity, and three guard tests (wrong backend, non-identity adapter, wrong mode).
- tests/test_bayesflow_nre_export.py: 4 tests. Same shape for the NRE path on a RatioApproximator.
- tests/test_bayesflow_hssm_integration.py: end-to-end DDM smoke (pytest.importorskip("hssm")). Mirrors test_sbi_hssm_integration.py.
- docs/exporting_bayesflow_models.md: full constraint catalog (KERAS_BACKEND, CouplingFlow knobs, silu vs hard_silu activation choice, identity-adapter requirement, JAX x64). Quick-starts for NLE and NRE. "Two paths into HSSM" framing alongside the JAX-callable path used in bayesflow_lre_integration.ipynb.

v1 constraints (documented, enforced where introspectable)

User must train with:

- permutation=None (FixedPermutation -> aten::ravel, unsupported)
- use_actnorm=False (untested in v1)
- transform=AffineTransform(clamp=False) explicit instance (find_transform("affine") drops kwargs - bayesflow upstream bug)
- subnet_kwargs.activation="silu" or another smooth activation (default hard_silu emits HardSwish, no jaxonnxruntime handler)
- identity Adapter (numpy-only adapter ops cannot be baked into ONNX)

Bayesflow continuous observations only. MNLE-style discrete + continuous deferred until upstream MNLE support lands.

Numerical guarantees

19 passing tests across both bayesflow and sbi tracks; no regressions on the existing sbi exporter. Each export is verified for three-way numerical agreement at 1e-5 and gradient agreement at 1e-4.

Companion PRs

- HSSM: docs(tutorials): add bayesflow_nle_onnx_integration.ipynb on a fresh bayesflow-integration branch off main (sibling, not child, of the sbi-integration branch in PR #964).
- HSSMSpine: bayesflow-onnx-integration.md design doc + upstream-bugs-from-bayesflow-onnx-work.md catalog of upstream defects surfaced during this work (jaxonnxruntime missing HardSwish/Size handlers; bayesflow find_transform kwarg-drop bug; bayesflow global torch.set_grad_enabled(False) cross-library leak; torch.onnx missing aten::ravel/asinh symbolic registrations).

This branch is stacked on sbi-connector (PR #79). When #79 merges, this PR's base auto-retargets to main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
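The I/O contract described in the commit message (a rank-1 input laid out as [theta..., x...], a rank-0 scalar output, and Standardize statistics frozen into constants) can be illustrated without any framework. The following is only a minimal sketch in plain Python: `FrozenStandardize`, `split_flat_input`, and `toy_log_likelihood` are hypothetical names invented for illustration, and the Gaussian toy model stands in for the trained network. It is not the exporter's code, which bakes the constants into torch buffers.

```python
import math
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class FrozenStandardize:
    """Standardization whose mean/std are fixed at construction time (hypothetical)."""

    mean: Tuple[float, ...]
    std: Tuple[float, ...]

    def __call__(self, values: Sequence[float]) -> List[float]:
        if len(values) != len(self.mean):
            raise ValueError(f"expected {len(self.mean)} values, got {len(values)}")
        return [(v - m) / s for v, m, s in zip(values, self.mean, self.std)]


def split_flat_input(flat: Sequence[float], theta_dim: int) -> Tuple[List[float], List[float]]:
    """Split a rank-1 input laid out as [theta..., x...] into (theta, x)."""
    if not 0 < theta_dim < len(flat):
        raise ValueError("theta_dim must leave at least one parameter and one observation")
    return list(flat[:theta_dim]), list(flat[theta_dim:])


def toy_log_likelihood(flat: Sequence[float], theta_dim: int,
                       standardize_x: FrozenStandardize) -> float:
    """Hypothetical stand-in for the trained network: returns a single scalar.

    Each standardized observation is modelled as Normal(theta[0], 1); the
    log-Jacobian of the standardization is added so the result is a density over x.
    """
    theta, x = split_flat_input(flat, theta_dim)
    z = standardize_x(x)
    log_norm = -0.5 * math.log(2.0 * math.pi)
    log_density_z = sum(log_norm - 0.5 * (zi - theta[0]) ** 2 for zi in z)
    log_jacobian = -sum(math.log(s) for s in standardize_x.std)
    return log_density_z + log_jacobian


# Constants are supplied once, as they would be after training has finished.
standardize = FrozenStandardize(mean=(1.0, 1.0), std=(2.0, 2.0))
value = toy_log_likelihood([0.0, 1.0, 3.0], theta_dim=1, standardize_x=standardize)
```

Because the statistics are plain constants rather than state that is read from a live layer, nothing about the computation depends on runtime shape queries, which is the property the exporter relies on to keep the ONNX trace static.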
Summary
Adds `lanfactory.onnx.transform_bayesflow_to_onnx`, the bayesflow sibling of `transform_sbi_to_onnx` from #79. Wraps a trained `bayesflow.ContinuousApproximator` (NLE) or `RatioApproximator` (NRE) and writes a single-trial ONNX file consumable by HSSM's `loglik_kind="approx_differentiable"` path.

The user gesture becomes the same regardless of training framework: `hssm.HSSM(loglik="model.onnx", loglik_kind="approx_differentiable")` works for LAN, sbi, and now bayesflow exports through the identical HSSM-side code path (no HSSM changes required).

Branch relationship
This branch is stacked on `sbi-connector` (#79). The diff in this PR is exactly the bayesflow-specific additions; #79's commits are not duplicated. When #79 merges to `main`, GitHub will offer to auto-retarget this PR's base to `main`.

What's in this PR
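| File | Purpose |
| --- | --- |
| `src/lanfactory/onnx/bayesflow.py` | Exporter module (`transform_bayesflow_to_onnx`) |
| `src/lanfactory/onnx/__init__.py` | Exports the new function |
| `pyproject.toml` | `[bayesflow]` optional extra; refactor existing sbi+nflows pair into a new `[sbi]` extra (symmetric with `[bayesflow]`); both added to `[all]` and `[dev]` |
| `tests/test_bayesflow_nle_export.py` | NLE export tests |
| `tests/test_bayesflow_nre_export.py` | NRE export tests |
| `tests/test_bayesflow_hssm_integration.py` | End-to-end DDM smoke test (`pytest.importorskip("hssm")`, mirrors `test_sbi_hssm_integration.py`) |
| `docs/exporting_bayesflow_models.md` | Constraint catalog and quick-starts; companion to `exporting_sbi_models.md` |

Architectural contract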
Same I/O contract as the sbi exporter:

- Input: rank-1 tensor of shape `(theta_dim + x_dim,)`, parameters first then observations.
- Output: rank-0 scalar log-likelihood.
- Opset 17, for `jaxonnxruntime` reproducibility.

Key implementation choice: the wrapper bakes the bayesflow `Standardize` layer's accumulated `moving_mean`/`moving_std` as torch buffer constants at construction time. This sidesteps the dynamic-shape ops (`If`, `Size`, `Tile`) that the live Keras layer would emit at trace time; `jaxonnxruntime` has no `Size` handler. The constants are correct because training is complete by export time.

v1 constraints (documented, enforced where introspectable)
User must train with:

- `permutation=None` (FixedPermutation → `aten::ravel`, unsupported in opset 17/20)
- `use_actnorm=False` (untested in v1)
- `transform=AffineTransform(clamp=False)` as an explicit instance (`find_transform("affine")` silently drops `transform_kwargs`; bayesflow upstream bug, catalogued in companion HSSMSpine PR)
- `subnet_kwargs.activation="silu"` (default `hard_silu` exports as the fused ONNX op `HardSwish`, which has no jaxonnxruntime handler; silu decomposes to `Sigmoid + Mul`)
- identity `Adapter` (numpy-only Adapter ops can't be baked into ONNX)

Each violation produces an actionable error message at export time.
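The idea of collecting constraint violations into actionable messages can be sketched as follows. This is an illustration under stated assumptions, not the exporter's guard code: the flat `cfg` dictionary and its keys (`keras_backend`, `permutation`, `use_actnorm`, `clamp`, `activation`, `identity_adapter`) are hypothetical and are not bayesflow or lanfactory API; the real exporter inspects the trained approximator where that is possible.

```python
from typing import Any, Dict, List


def check_v1_constraints(cfg: Dict[str, Any]) -> List[str]:
    """Return one human-readable message per violated v1 constraint.

    `cfg` is a hypothetical flat summary of how the model was trained.
    """
    problems: List[str] = []
    if cfg.get("keras_backend") != "torch":
        problems.append('KERAS_BACKEND must be "torch"; set it before importing bayesflow.')
    if cfg.get("permutation") is not None:
        problems.append("permutation must be None; FixedPermutation traces to aten::ravel.")
    if cfg.get("use_actnorm") is not False:
        problems.append("use_actnorm must be False; actnorm is untested in v1.")
    if cfg.get("clamp") is not False:
        problems.append("pass transform=AffineTransform(clamp=False) as an explicit instance.")
    # A missing activation is treated as the default, hard_silu.
    if cfg.get("activation", "hard_silu") == "hard_silu":
        problems.append('activation "hard_silu" exports as HardSwish; use "silu" instead.')
    if cfg.get("identity_adapter") is not True:
        problems.append("the Adapter must be the identity; numpy-only ops cannot be baked in.")
    return problems


def assert_exportable(cfg: Dict[str, Any]) -> None:
    """Raise a single ValueError listing every violation, or return None."""
    problems = check_v1_constraints(cfg)
    if problems:
        raise ValueError("model is not exportable:\n- " + "\n- ".join(problems))


ok_cfg = {
    "keras_backend": "torch",
    "permutation": None,
    "use_actnorm": False,
    "clamp": False,
    "activation": "silu",
    "identity_adapter": True,
}
bad_cfg = dict(ok_cfg, activation="hard_silu", permutation="fixed")
```

Reporting all violations in one error, rather than failing on the first, lets a user fix the training configuration in a single pass before retraining.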
Test status
`torch.set_grad_enabled(True)` is called after importing bayesflow to undo the global autograd disable that bayesflow's torch backend performs at import time. This is a known upstream issue documented in the companion HSSMSpine PR's upstream-bugs catalog.

Companion PRs
- HSSM: `bayesflow-integration` branch (new), adds `docs/tutorials/bayesflow_nle_onnx_integration.ipynb`. Sibling, not child, of `sbi-integration` (#964); works against HSSM `main` standalone (Part 1 includes the manual `jaxort_only_allow_initializers_as_static_args=False` workaround that #964 plans to auto-handle inside `onnx2jax`).
- HSSMSpine: `bayesflow-onnx-plans` branch (stacked on #9 for the cross-reference to `sbi-onnx-integration.md`). Adds the design doc and an upstream-bugs catalog covering the five real upstream defects surfaced during this work.

Test plan
- Run the HSSM integration test (guarded by `pytest.importorskip("hssm")`) once both packages can be installed in the same env

🤖 Generated with Claude Code