feat(onnx): sbi → ONNX exporter (transform_sbi_to_onnx) for NLE + NRE #79

Open
AlexanderFengler wants to merge 8 commits into main from sbi-connector

@AlexanderFengler
Member

Summary

Adds lanfactory.onnx.transform_sbi_to_onnx, an exporter that converts a
trained sbi neural likelihood estimator
(NLE) or neural ratio estimator (NRE) to an ONNX file consumable by HSSM's
loglik_kind="approx_differentiable" path. The exported file behaves
identically to a LAN-trained ONNX from transform_to_onnx, so HSSM
consumes it with zero HSSM-side code changes for the basic case.

Lives as a sibling module to the existing LAN exporter
(lanfactory/onnx/transform_onnx.py alongside lanfactory/onnx/sbi.py). The new
deps (sbi>=0.26, nflows>=0.14) are gated under the existing [all]
extra; no impact on the default install.

What's in this branch

| # | Commit | What |
|---|---|---|
| C1 | ac16eda | Scaffolding: stub module, [all] extra extended with sbi + nflows |
| C2 | f7c93c8 | Spike tests: torch.onnx.export → {onnxruntime, jaxonnxruntime} round-trip on a vanilla MLP and an nflows MAF. Kept as permanent regression guards. |
| C3 | bdcabd3 | Core exporter, NLE path. Wraps estimator.log_prob(x, condition=θ) as a torch.nn.Module that takes a 1D concatenated (theta, x) input. Three-way numerical agreement (torch / ORT / jaxonnxruntime) at atol=1e-5, gradient agreement at atol=1e-4. |
| C4 | f4a54fe | NRE path: wraps the classifier logit as the log-ratio. Same numerical-agreement contract. |
| C5 | b1fd188 | Embedding-net coverage tests for FCEmbedding and CNNEmbedding. |
| C6 | 87704bb | Documentation: docs/exporting_sbi_models.md integration guide, README mention, nav update. |
| C7b | 4990e85 | End-to-end HSSM integration test (pytest.importorskip("hssm") skip-on-missing). |
| C9 | 222adf5 | Critical fix: export a rank-1 input contract for HSSM's vmap-over-trials. The previous 2D (1, D) dummy caused Slice ops with axes=[1], which fail under vmap with rank-1 per-trial input. |

Architectural contract

  • Single-trial graph: the exported ONNX takes a 1D vector of length
    theta_dim + x_dim, matching HSSM's make_jax_logp_funcs_from_onnx
    per-trial vmap contract.
  • Inside the wrapper: splits on axis 0, unsqueeze(0) to satisfy
    sbi's batched log_prob/forward API, returns scalar via reshape(()).
  • Opset pinned to 17 for reproducibility against jaxonnxruntime.
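
The contract above can be sketched in plain Python, with nested lists standing in for tensors. The names `single_trial_logp` and `gaussian_batched_log_prob`, and the toy Gaussian itself, are illustrative assumptions rather than the exporter's code; the real wrappers operate on torch tensors.

```python
import math


def gaussian_batched_log_prob(x_batch, theta_batch):
    """Toy stand-in for a batched estimator: x ~ N(theta, I), one row per trial.

    Returns a (batch, 1) nested list, mimicking a batched log_prob output.
    """
    out = []
    for x, theta in zip(x_batch, theta_batch):
        sq_dist = sum((xi - ti) ** 2 for xi, ti in zip(x, theta))
        out.append([-0.5 * sq_dist - 0.5 * len(x) * math.log(2 * math.pi)])
    return out


def single_trial_logp(combined, theta_dim, batched_log_prob):
    """Map a rank-1 (theta, x) vector to a scalar log-probability."""
    theta, x = combined[:theta_dim], combined[theta_dim:]  # split on axis 0
    theta_b, x_b = [theta], [x]  # add a batch axis, like unsqueeze(0)
    out = batched_log_prob(x_b, theta_b)  # shape (1, 1)
    return out[0][0]  # drop both axes, like reshape(())


# theta_dim = 2, x_dim = 2, and x equals theta, so the quadratic term is zero.
combined = [0.5, -1.0, 0.5, -1.0]
logp = single_trial_logp(combined, theta_dim=2, batched_log_prob=gaussian_batched_log_prob)
```

Because the function only ever sees one rank-1 vector, mapping it over trials (as HSSM's vmap does) never introduces a second axis inside the graph.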

Test status

  • 13 LANfactory sbi tests passing locally
    (pytest tests/test_sbi_spike_*.py tests/test_sbi_nle_export.py tests/test_sbi_nre_export.py tests/test_sbi_embeddings.py).
  • The C7b integration test is guarded by pytest.importorskip("hssm") and
    skips cleanly in the current LANfactory env. It is designed to run in
    coordinated cross-repo CI where both LANfactory and HSSM are installed together.

Known limitations (documented in C6 guide)

  • NSF flows blocked by missing SearchSorted op in jaxonnxruntime
    (planned upstream PR, ~50 lines). Tracked in
    HSSMSpine/plans/sbi-onnx-integration.md.
  • MNLE same SearchSorted blocker (categorical lookup uses
    torch.searchsorted). Same upstream PR unlocks both.
  • NLE-MAF on DDM produces qualitatively wrong posteriors (rt is
    continuous but choice is discrete; MAF can't represent the mixed
    structure). The exporter still supports mode="nle" correctly;
    this is an sbi-method/data-shape issue, not an exporter bug. MNLE is
    the correct sbi method for DDM-like data; it is deferred until the
    upstream PR lands.

Companion PRs

Consumed by the HSSM-side keystone tutorial in
lnccbrown/HSSM#sbi-integration.
The tutorial uses an importlib fallback to load this exporter even
without LANfactory installed, so this PR can be merged independently.
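
A minimal sketch of what such an importlib fallback could look like, using only the standard library. The helper name `load_function`, the stub file, and the placeholder package name are hypothetical; the tutorial's actual fallback may differ.

```python
import importlib
import importlib.util
import tempfile
from pathlib import Path


def load_function(module_name, function_name, fallback_path):
    """Import `function_name` from `module_name`, else load it from a file."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Package not installed: execute the source file as a standalone module.
        spec = importlib.util.spec_from_file_location("fallback_module", fallback_path)
        if spec is None or spec.loader is None:
            raise
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
    return getattr(module, function_name)


# Demo with a stub file standing in for a checked-out copy of the exporter.
with tempfile.TemporaryDirectory() as tmp:
    stub = Path(tmp) / "sbi_stub.py"
    stub.write_text("def transform_sbi_to_onnx():\n    return 'stub exporter'\n")
    loaded = load_function(
        "package_that_is_not_installed.onnx", "transform_sbi_to_onnx", str(stub)
    )
```

The fallback only works if the target file has no imports that themselves require the missing package.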

Test plan

  • LANfactory sbi tests pass (13/13)
  • Cross-repo CI exercise of the integration test (test_sbi_hssm_integration.py)
    once both packages can be installed in the same env

🤖 Generated with Claude Code

AlexanderFengler and others added 8 commits May 13, 2026 17:31
Adds a stub transform_sbi_to_onnx in lanfactory/onnx/sbi.py as a sibling
of the existing LAN exporter. Extends the `all` extra to pull sbi and
nflows, and adds jaxonnxruntime to the dev group for round-trip testing.

First commit of the sbi -> HSSM integration plan
(plans/sbi-onnx-integration.md in HSSMSpine). Implementation lands in
C3 (NLE) and C4 (NRE).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds two permanent regression-guard tests validating the
torch.onnx.export to {onnxruntime, jaxonnxruntime} toolchain that the
sbi exporter (C3) will sit on top of. Both tests assert three-way
numerical agreement to 1e-5 on fixed inputs.

The MAF spike surfaced a real friction: the nflows MAF exports a
Reshape whose shape argument is a Constant node rather than a model
initializer, which jaxonnxruntime default strict mode rejects. Setting
jaxort_only_allow_initializers_as_static_args = False works around it.

Architectural implication for C3: HSSM onnx2jax.py does not set this
flag today, so sbi-exported flow graphs will fail to load through the
HSSM make_jax_logp_funcs_from_onnx path as-is. C3 should either
constant-fold the exported graph (preferred, keeps HSSM untouched) or
we will need a small HSSM-side patch.

Also adds onnxruntime>=1.17 and nflows>=0.14 to the dev dependency
group so uv sync --group dev is sufficient to run these tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
…(C3)

Replaces the C1 stub with the real transform_sbi_to_onnx implementation
for mode=nle. The exporter wraps an sbi ConditionalDensityEstimator
(NLE_A trained estimator) as a torch.nn.Module whose forward(combined)
splits a concatenated (theta, x) input and returns log p(x | theta) with
sbi standardization Jacobian baked into the traced graph. Exports a
single-trial graph at opset 17, matching the LAN convention and HSSM
vmap-over-trials expectation.
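
For readers unfamiliar with the standardization Jacobian: if a density is modeled on z = (x - shift) / scale, then log p(x) = log p_z(z) - sum(log scale). The sketch below checks this identity on a toy diagonal Gaussian. The helper names and numbers are assumptions for illustration; how sbi applies the correction internally is not shown here.

```python
import math


def normal_logpdf(value, mean, std):
    """Log-density of a univariate normal distribution."""
    return -0.5 * ((value - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def logp_via_standardized(x, shift, scale, z_logpdf):
    """Evaluate log p(x) through a density defined on z = (x - shift) / scale."""
    z = [(xi - m) / s for xi, m, s in zip(x, shift, scale)]
    log_det = sum(math.log(s) for s in scale)  # log |dx/dz| for a diagonal scale
    return z_logpdf(z) - log_det


def standard_normal_logpdf(z):
    """Density in standardized space: independent N(0, 1) per dimension."""
    return sum(normal_logpdf(zi, 0.0, 1.0) for zi in z)


x, shift, scale = [1.3, -0.4], [1.0, 0.0], [2.0, 0.5]
direct = sum(normal_logpdf(xi, m, s) for xi, m, s in zip(x, shift, scale))
via_z = logp_via_standardized(x, shift, scale, standard_normal_logpdf)
```

Dropping the `log_det` term would shift every log-probability by a constant that depends on the training-data scale, which is why it has to be part of the traced graph.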

Rejection paths:
  - Score-based, flow-matching, TabPFN estimators raise ValueError.
  - NLE mode requires .log_prob(input, condition); clear TypeError if
    absent.
  - NRE mode currently raises NotImplementedError (lands in C4).

Tests in test_sbi_nle_export.py train a tiny 2D Gaussian NLE_A with MAF
and verify:
  1. Three-way numerical agreement (torch / onnxruntime / jaxonnxruntime)
     to atol=1e-5 on a fixed test point.
  2. Gradient agreement (torch.autograd vs jax.grad of the translated
     graph) to atol=1e-4.
  3. Sanity check that log-prob ordering matches the analytical Gaussian
     (near-mean point ranks above far point).
  4. Three rejection-path tests for the error contracts above.
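
The three-way agreement check in item 1 can be sketched as a small helper. This is an assumed shape for such a check, not the test file's code: it uses a pure absolute tolerance via `math.isclose`, and the numbers in the usage line are placeholders, not measured outputs.

```python
import math


def assert_three_way_agreement(torch_out, ort_out, jax_out, atol=1e-5):
    """Raise AssertionError unless all three backends agree pairwise within atol."""
    pairs = {
        "torch vs onnxruntime": (torch_out, ort_out),
        "torch vs jaxonnxruntime": (torch_out, jax_out),
        "onnxruntime vs jaxonnxruntime": (ort_out, jax_out),
    }
    for label, (a, b) in pairs.items():
        if not math.isclose(a, b, rel_tol=0.0, abs_tol=atol):
            raise AssertionError(f"{label}: |{a} - {b}| exceeds atol={atol}")


# Placeholder scalars standing in for the three backends' log-prob outputs.
assert_three_way_agreement(-1.2345670, -1.2345672, -1.2345668)
```

Checking all three pairs, rather than comparing each backend to torch alone, also catches the case where the two ONNX runtimes drift in opposite directions.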

Two findings surfaced during C3 that affect later commits:

  - 1D MAFs in sbi collapse to a degenerate Gaussian path with zero-width
    Gemm contractions that jaxonnxruntime cannot handle. The exporter
    must be exercised with >=2D theta and x. Documented in the simulator
    docstring.

  - jaxonnxruntime silently truncates int64 indices in exported flow
    graphs to int32, causing ~0.5 drift in log-prob outputs. The fix is
    to call jax.config.update("jax_enable_x64", True) at startup, right
    after importing jax and before any JAX computation runs. The test
    file sets this. C7 will decide whether HSSM onnx2jax.py should also
    set it globally (mirrors the C2.5 flag patch) or whether it stays a
    user responsibility documented in C6/C8.
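
To make this class of failure concrete, the sketch below shows what a 64-bit to 32-bit narrowing does to integer values. It is a generic two's-complement illustration: the commit does not say which indices are affected, and INT64_MAX is used only as an assumed example because it commonly appears in exported graphs as a "slice to the end" sentinel.

```python
def narrow_to_int32(value):
    """Keep the low 32 bits of an integer and reinterpret them as signed."""
    return (value + 2**31) % 2**32 - 2**31


INT64_MAX = 2**63 - 1

small = narrow_to_int32(7)             # small values survive unchanged
wrapped = narrow_to_int32(2**31 + 5)   # just past int32 range: wraps negative
sentinel = narrow_to_int32(INT64_MAX)  # all low bits set: becomes -1
```

A sentinel that silently turns into -1 would change which elements a slice selects without raising an error, which is consistent with drift in outputs rather than a crash.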

Also adds sbi-logs/ to .gitignore (sbi auto-writes tensorboard logs
during training).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Extends transform_sbi_to_onnx to support mode="nre". The wrapper splits
the concatenated (theta, x) input and routes through the trained
RatioEstimator forward, returning the logit log r(x, theta). Up to a
theta-independent constant, the logit is log p(x | theta), so MCMC and
HSSM's posterior path can treat it as the likelihood. No Jacobian
correction is needed since ratios are invariant to z-score standardization.
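
A toy check of the claim that the log-ratio differs from the log-likelihood only by a theta-independent term. The setup (prior theta ~ N(0, 1), likelihood x | theta ~ N(theta, 1), hence marginal x ~ N(0, sqrt(2))) is an assumption chosen because the marginal is analytic; it is not the model used in the tests.

```python
import math


def normal_logpdf(value, mean, std):
    """Log-density of a univariate normal distribution."""
    return -0.5 * ((value - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


def log_ratio(x, theta):
    """log r(x, theta) = log p(x | theta) - log p(x) for the toy model."""
    return normal_logpdf(x, theta, 1.0) - normal_logpdf(x, 0.0, math.sqrt(2.0))


x, theta_a, theta_b = 0.8, 0.7, -1.5

# Differences across theta at fixed x: log p(x) cancels.
ratio_gap = log_ratio(x, theta_a) - log_ratio(x, theta_b)
lik_gap = normal_logpdf(x, theta_a, 1.0) - normal_logpdf(x, theta_b, 1.0)

# An ideal classifier outputs d = r / (1 + r); its logit recovers log r.
r = math.exp(log_ratio(x, theta_a))
d = r / (1.0 + r)
logit = math.log(d) - math.log(1.0 - d)
```

Since a sampler only compares log-densities across theta values for fixed data, the missing log p(x) term does not affect the posterior over theta.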

Rejection: passing an estimator with .log_prob in mode="nre" raises
TypeError, since that signals a density estimator (NLE) rather than a
ratio classifier (NRE). The NLE path has the symmetric check.

New test file tests/test_sbi_nre_export.py trains a tiny 2D Gaussian
NRE_A and verifies the same three-way numerical agreement (atol=1e-5)
and gradient agreement (atol=1e-4) as the NLE path, plus a sanity
ordering check (log-ratio higher at near-theta than far-theta).

The C3 NRE-not-implemented test was repurposed into a cross-mode
rejection test: passing an NLE density estimator with mode="nre" now
raises a clear TypeError instead of NotImplementedError.
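
The rejection rules described in C3 and C4 can be summarized in a short validator. This is a hypothetical sketch: `check_estimator` and the toy classes are invented for illustration, matching on the class name alone is an assumption, and the ValueError for an unknown mode is not stated in the PR. The class names in the set are copied from the `_UNSUPPORTED_ESTIMATORS` excerpt in this PR.

```python
UNSUPPORTED_ESTIMATORS = frozenset(
    {
        "ScoreEstimator",
        "ConditionalScoreEstimator",
        "FlowMatchingEstimator",
        "ConditionalFlowMatchingEstimator",
        "TabPFNEstimator",
    }
)


def check_estimator(estimator, mode):
    """Hypothetical validator mirroring the rejection rules described above."""
    if mode not in ("nle", "nre"):
        raise ValueError(f"unknown mode {mode!r}")  # assumption, not from the PR
    name = type(estimator).__name__
    if name in UNSUPPORTED_ESTIMATORS:
        raise ValueError(f"{name} is not supported for ONNX export")
    has_log_prob = hasattr(estimator, "log_prob")
    if mode == "nle" and not has_log_prob:
        raise TypeError(f"NLE mode requires .log_prob; {name} lacks it")
    if mode == "nre" and has_log_prob:
        raise TypeError(f"{name} has .log_prob, which signals a density estimator")
    return mode


class ToyDensityEstimator:
    def log_prob(self, input, condition):
        return 0.0


class ToyRatioEstimator:
    def forward(self, theta, x):
        return 0.0


class ScoreEstimator:  # name collides with the unsupported set on purpose
    pass


def outcome(estimator, mode):
    """Return the accepted mode, or the name of the exception raised."""
    try:
        return check_estimator(estimator, mode)
    except (TypeError, ValueError) as exc:
        return type(exc).__name__


results = {
    "density as nle": outcome(ToyDensityEstimator(), "nle"),
    "ratio as nre": outcome(ToyRatioEstimator(), "nre"),
    "density as nre": outcome(ToyDensityEstimator(), "nre"),
    "ratio as nle": outcome(ToyRatioEstimator(), "nle"),
    "score as nle": outcome(ScoreEstimator(), "nle"),
}
```

Note that a `hasattr` check of this kind cannot distinguish a likelihood estimator from a posterior estimator, since both expose `.log_prob`.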

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/test_sbi_embeddings.py exercising NRE_A with two embedding
nets on x:
  - FCEmbedding (representative flat-MLP embedding)
  - CNNEmbedding (1D conv stack; validates Conv / MaxPool / etc. survive
    torch.onnx.export and translate cleanly into jaxonnxruntime)

Both tests train a tiny 2D-theta / 10-dim-x linear-Gaussian classifier
and assert three-way numerical agreement (torch / onnxruntime /
jaxonnxruntime, atol=1e-5).

Other sbi embeddings (PermutationInvariantEmbedding, ResNetEmbedding1D,
ResNetEmbedding2D, LRUEmbedding, TransformerEmbedding, CausalCNNEmbedding,
SpectralConvEmbedding) are out of v1 scope; can be added as follow-up
regressions if a user needs them.

C5 finding: sbi build_mlp_classifier defaults to nn.LayerNorm between
hidden layers, and jaxonnxruntime does NOT implement the
LayerNormalization op (raises NotImplementedError at translation time).
The fix is to pass norm_layer=nn.Identity through classifier_nn kwargs.
This constraint will be documented in C6.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds docs/exporting_sbi_models.md as a Guides entry alongside the
MLflow and HuggingFace integration guides. Wires it into mkdocs.yml
nav. Adds a one-line mention in README.md pointing users to the
guide.

The guide covers:
  - Installation (pip install lanfactory[all])
  - Quick-start examples for NLE and NRE
  - Supported architecture matrix (NLE+MAF, NRE+MLP/FC/CNN embeddings)
  - Explicitly-out-of-scope list (NSF, FMPE, NPSE, NPE, TabPFN) with
    one-sentence rationales each
  - Known constraints surfaced during C2-C5:
      * Use 2D+ for theta and x (1D MAFs degenerate in sbi)
      * Disable LayerNorm in NRE MLP classifiers (norm_layer=Identity)
      * Enable jax_enable_x64 before importing JAX in the consumer
  - Numerical guarantees from the regression tests (atol=1e-5 forward,
    atol=1e-4 gradients)
  - Float precision interaction with PyMC

The new function transform_sbi_to_onnx is auto-documented on
docs/api/onnx.md via the existing :::lanfactory.onnx mkdocstrings
directive — no manual API page changes needed.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Adds tests/test_sbi_hssm_integration.py exercising the full keystone
pipeline:
  1. Train tiny sbi NLE_A on synthetic DDM data (ssm-simulators).
  2. Export via lanfactory.onnx.transform_sbi_to_onnx.
  3. Build HSSM model with model="ddm",
     loglik_kind="approx_differentiable", loglik=<path>.
  4. Short MCMC (500 draws + 500 tune, 2 chains) and verify posterior
     mean recovery within +/- 2 sigma and r_hat < 1.05.
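
The acceptance criteria in step 4 can be sketched with the standard library. Two assumptions are made here: "sigma" is read as the posterior standard deviation, and R-hat is computed with the classic Gelman-Rubin formula rather than the rank-normalized split variant that libraries such as ArviZ report. The chains are tiny hand-written placeholders, not sampler output.

```python
import math
import statistics


def gelman_rubin(chains):
    """Classic (non-split, non-rank-normalized) R-hat for equal-length chains."""
    n = len(chains[0])
    chain_means = [statistics.fmean(c) for c in chains]
    within = statistics.fmean([statistics.variance(c) for c in chains])
    between = n * statistics.variance(chain_means)
    var_hat = (n - 1) / n * within + between / n
    return math.sqrt(var_hat / within)


def recovered(draws, true_value, n_sigma=2.0):
    """True if the posterior mean lies within n_sigma posterior SDs of the truth."""
    return abs(statistics.fmean(draws) - true_value) <= n_sigma * statistics.stdev(draws)


mixed_chains = [[0.9, 1.0, 1.1, 1.0], [1.0, 1.1, 0.9, 1.0]]
stuck_chains = [[-0.1, 0.0, 0.1, 0.0], [4.9, 5.0, 5.1, 5.0]]

rhat_mixed = gelman_rubin(mixed_chains)
rhat_stuck = gelman_rubin(stuck_chains)
pooled = mixed_chains[0] + mixed_chains[1]
close_enough = recovered(pooled, true_value=1.05)
too_far = recovered(pooled, true_value=2.0)
```

With only four draws per chain the classic statistic can fall below 1; real checks on 500 draws per chain would sit near 1 when chains mix.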

Two test functions:
  - test_hssm_model_builds_from_sbi_onnx: verifies the exported ONNX
    loads cleanly into hssm.HSSM (no sampling).
  - test_hssm_mcmc_recovers_ddm_parameters: full MCMC + recovery
    assertion.

Skip guard via pytest.importorskip("hssm") so the test no-ops when HSSM
is not in the env. Currently the test is a no-op in LANfactory's local
uv venv because LANfactory's flax>=0.10.6 pin pulls a JAX version
incompatible with HSSM's numpyro 0.21.0 pin. The test is intended to
run only in a coordinated cross-repo CI environment that resolves both
packages together. Plan tracks this as future ecosystem cleanup.

The C7a HSSM patch (commit d1d7ffe on HSSM sbi-integration branch)
makes jax_enable_x64 self-managed inside HSSM's onnx2jax, so this test
does not need to set it explicitly.

Marked @pytest.mark.flaky(reruns=2, reruns_delay=5) on both test
functions to match HSSM's existing ONNX-test convention.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Surfaced by running the C8 notebook in HSSM: pymc.sample raised
"IndexError: list assignment index out of range" inside
jaxonnxruntime/onnx_ops/slice.py:113 (sub_indx[axis] = slices[i]).

Root cause: HSSM's make_jax_logp_funcs_from_onnx vmaps the per-trial
loglik over a 1D concatenated input vector (param_vector + data) — see
HSSM repos/HSSM/src/hssm/distribution_utils/onnx.py around line 115:

    input_vector = jnp.concatenate((param_vector, data))
    return jax_func(input_vector)

But the C3/C4 exporter was tracing with a 2D dummy of shape
(1, theta_dim + x_dim), which made torch.onnx.export emit Slice ops with
axes=[1]. Under HSSM's vmap the per-trial input is rank-1, so axes=[1]
is out of bounds for the inner Slice handler.
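
The failure can be reproduced in miniature without any ONNX tooling. The function below is a simplified, hypothetical reconstruction of an index-building loop of the kind quoted above (`sub_indx[axis] = slices[i]`); it is not the jaxonnxruntime source.

```python
def build_slice_index(ndim, axes, starts, ends):
    """Build a per-axis indexing tuple for an input of rank `ndim`."""
    sub_indx = [slice(None)] * ndim
    for i, axis in enumerate(axes):
        # Fails when `axis` refers to a dimension the input does not have.
        sub_indx[axis] = slice(starts[i], ends[i])
    return tuple(sub_indx)


# Graph traced with a 2D (1, D) dummy: Slice carries axes=[1].
ok_rank2 = build_slice_index(ndim=2, axes=[1], starts=[0], ends=[2])

# Graph traced with a rank-1 dummy: Slice carries axes=[0].
ok_rank1 = build_slice_index(ndim=1, axes=[0], starts=[0], ends=[2])

# Under vmap each trial is rank-1, so a graph with axes=[1] breaks.
try:
    build_slice_index(ndim=1, axes=[1], starts=[0], ends=[2])
    error_message = None
except IndexError as exc:
    error_message = str(exc)
```

The same graph works when fed a 2D array directly, which is why the C3/C4 tests passed before the HSSM notebook exercised the vmap path.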

LAN exports don't trip on this because the LAN graph is pure
MatMul/Add/activation — broadcast-rank-agnostic. Ours has explicit Slice
ops from `combined[..., :theta_dim]` and `combined[theta_dim:]`.

Fix:
  - Trace the wrapper with a rank-1 dummy (shape (theta_dim+x_dim,))
    so Slice ops emit axes=[0], which survives HSSM's vmap.
  - Inside _NLELogProbWrapper.forward and _NRELogRatioWrapper.forward,
    take a 1D combined input, split on axis 0, then .unsqueeze(0) the
    two halves to satisfy sbi's batched log_prob / classifier APIs.
    Reshape the (1, 1) output back to () so HSSM's downstream .squeeze()
    sees a clean scalar.
  - Updated module docstring to document the rank-1 contract and why.

Tests:
  - test_sbi_nle_export.py, test_sbi_nre_export.py, test_sbi_embeddings.py:
    pass rank-1 inputs through onnxruntime and jaxonnxruntime; rank-1
    theta_np_1d / x_np_1d for the gradient tests.
  - All 13 sbi tests still green at the same atol thresholds
    (1e-5 forward, 1e-4 gradients).

User impact: anyone who already exported a .onnx with the old C3/C4
code needs to re-export with this commit. The exported .onnx is the
durable artifact — no API change in the call site.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Copilot AI review requested due to automatic review settings May 18, 2026 03:01
Contributor

Copilot AI left a comment

Pull request overview

Adds an ONNX exporter for trained sbi NLE/NRE estimators so they can be consumed by HSSM’s differentiable likelihood path, alongside documentation and regression coverage for ONNX/JAX round-trips.

Changes:

  • Introduces lanfactory.onnx.transform_sbi_to_onnx for rank-1 single-trial NLE/NRE ONNX export.
  • Adds sbi/nflows-related optional/dev dependencies and lockfile updates.
  • Adds docs and tests covering MLP/MAF round-trips, NLE/NRE exports, embeddings, and optional HSSM integration.

Reviewed changes

Copilot reviewed 12 out of 14 changed files in this pull request and generated 2 comments.

| File | Description |
|---|---|
| src/lanfactory/onnx/sbi.py | Implements the new sbi-to-ONNX exporter and wrappers. |
| src/lanfactory/onnx/__init__.py | Exposes transform_sbi_to_onnx from the ONNX package. |
| pyproject.toml | Adds optional and dev dependencies for sbi ONNX export/testing. |
| uv.lock | Locks new transitive dependencies. |
| tests/test_sbi_spike_mlp_roundtrip.py | Adds baseline MLP ONNX/JAX/ORT round-trip test. |
| tests/test_sbi_spike_maf_roundtrip.py | Adds nflows MAF ONNX/JAX/ORT round-trip test. |
| tests/test_sbi_nle_export.py | Adds NLE exporter numerical, gradient, and validation tests. |
| tests/test_sbi_nre_export.py | Adds NRE exporter numerical, gradient, and sanity tests. |
| tests/test_sbi_embeddings.py | Adds NRE embedding-net export round-trip tests. |
| tests/test_sbi_hssm_integration.py | Adds optional end-to-end HSSM integration tests. |
| docs/exporting_sbi_models.md | Documents installation, usage, supported architectures, and constraints. |
| README.md | Mentions the new sbi ONNX exporter. |
| mkdocs.yml | Adds the sbi exporter guide to docs navigation. |
| .gitignore | Ignores sbi-generated TensorBoard logs. |
Comments suppressed due to low confidence (1)

src/lanfactory/onnx/sbi.py:107

  • NLE mode accepts any module with .log_prob, which also includes posterior-shaped sbi estimators such as NPE/SNPE. Those estimators model p(theta | x) with the opposite input/condition semantics, so this wrapper can silently export a posterior density as if it were a likelihood (especially when theta_dim == x_dim). Add an explicit estimator-family check for true likelihood estimators before building the NLE wrapper, or require callers to pass a likelihood-specific estimator type that can be validated.
    if mode == "nle":
        if not hasattr(estimator, "log_prob"):
            raise TypeError(
                f"NLE mode requires an estimator with "
                f".log_prob(input, condition); got {estimator_cls} which lacks "
                f"it. If this is an NRE ratio classifier, use mode='nre' "
                f"instead."
            )
        wrapper: nn.Module = _NLELogProbWrapper(

Comment thread pyproject.toml
@@ -74,6 +79,9 @@ dev = [
"ruff>=0.14.4",
"types-PyYAML",
"mlflow>=3.6.0",
Comment on lines +36 to +43
_UNSUPPORTED_ESTIMATORS: frozenset[str] = frozenset(
    {
        "ScoreEstimator",
        "ConditionalScoreEstimator",
        "FlowMatchingEstimator",
        "ConditionalFlowMatchingEstimator",
        "TabPFNEstimator",
    }

2 participants