Skip to content

Ensure distribution moments/samples are jax arrays#2206

Open
tillahoffmann wants to merge 3 commits into
pyro-ppl:masterfrom
tillahoffmann:add-jax-array-output-test
Open

Ensure distribution moments/samples are jax arrays#2206
tillahoffmann wants to merge 3 commits into
pyro-ppl:masterfrom
tillahoffmann:add-jax-array-output-test

Conversation

@tillahoffmann

@tillahoffmann tillahoffmann commented Jun 8, 2026

Copy link
Copy Markdown
Collaborator

Summary

This PR makes numpyro distributions return concrete jax.Array from their public methods (mean, variance, entropy, sample, log_prob, cdf, icdf, …) and annotates them accordingly, so that typed downstream code can actually use distribution outputs. To make that true at runtime, parameters are coerced to jax.Array at construction.

The change aligns numpyro with JAX's own typing best practiceswide on input (ArrayLike), strict on output (Array) — and removes a class of type errors that typed consumers will hit once they rely on numpyro's types.

The near-term, standalone win is the parameter coercion: today whether a stored parameter (self.loc, self.rate, …) is a jax.Array, a numpy array, or a python scalar depends on the input — this PR makes it consistently a jax.Array. The output annotations are the cheap, correct-by-construction follow-through; they become observable to consumers once numpyro ships py.typed (a deliberate follow-up — see below).

This is necessarily runtime work, not just an annotation change: many moments return a stored parameter directly (Poisson.mean is self.rate), so an Array return type is only honest if that parameter is actually an array — typing the outputs without making the runtime match would just lie to the type checker.

It also changes one behavior deliberately: argument validation (validate_args=True) is now an eager-only operation. That is the part most worth discussing, so it is called out explicitly below.

Motivation: the output annotations are wrong, and it bites consumers

Public methods were annotated to return jax.typing.ArrayLike:

ArrayLike = Union[jax.Array, np.ndarray, np.bool, np.number, bool, int, float, complex]

ArrayLike is a deliberately wide input union. JAX's typing guidance is explicit that it is the right type for function inputs (be liberal in what you accept) and that outputs should be annotated Array (be strict in what you return). numpyro applied ArrayLike to outputs too, which is incorrect: mean/sample/log_prob/… (should) always produce a concrete array.

The cost lands on every typed consumer. The scalar members of the union (bool/int/float/complex) support neither indexing, .shape, .reshape, nor assignment to an Array parameter, so ordinary usage fails type checking:

d = dist.Normal(jnp.zeros(5), 1.0)
d.mean[3]                  # error: int/float/complex are not indexable
d.variance.shape           # error: .shape not defined on int/float/complex
some_jit_fn(d.mean)        # error: ArrayLike not assignable to Array
d.mean > 1                 # error: ">" not supported for "complex" and "Literal[1]"

Under pyright, every one of those lines is a type error when the method returns ArrayLike (each indexing/attribute access fails once per scalar member of the union — int, float, complex, …), and all of them type-check cleanly once it returns Array.

This is currently latent (numpyro ships no py.typed, so checkers treat it as untyped Any), but it is a trap: the moment a consumer relies on numpyro's types — or numpyro starts enforcing its own — these annotations break real code.

The benefit here is for downstream typed code. This PR does not aim to make numpyro's own modules type-clean: they are not mypy-enforced, and their residual errors are dominated by pre-existing [override] signature mismatches unrelated to these annotations (enabling enforcement is called out as out of scope below).

What this PR does

  1. Coerce parameters to jax.Array at construction. Previously parameter storage was inconsistent: most parameters flow through promote_shapes, which preserved input types and only produced arrays when a broadcast forced a reshape, while some distributions stored the raw argument directly — so whether self.loc ended up an array depended on the input shapes and dtype. Now promote_shapes takes an opt-in promote_array=True (typed to return list[Array]) used at the parameter-assignment sites, and the few directly-assigned parameters are wrapped in jnp.asarray. As a result self.loc, self.scale, … are statically and dynamically jax.Array.

This is runtime work, not just annotation, on purpose: annotating the outputs as Array without actually returning arrays would be a lie to the type checker — many moments simply return a stored parameter (Poisson.mean is self.rate), so the annotation is only honest if the parameter is genuinely an array. Typing that isn't backed by the runtime is worse than no typing.

  1. Annotate outputs as Array. Return-position ArrayLike becomes Array across the distribution modules. Constructor parameters and method arguments (value, q) keep ArrayLike, so callers can still pass python scalars / numpy arrays.

  2. Make instance types reachable through the metaclass. DistributionMeta.__call__ is annotated to return the concrete class being constructed, so dist.Normal(...) is typed Normal (not Any, and not the base Distribution) and the -> Array annotations on its methods are visible to consumers.

  3. Use functional array ops on inputs where appropriate. Where a method only needed a shape it now uses jnp.shape(value) rather than value.shape, which is correct for any ArrayLike (and incidentally more robust). A few methods that genuinely index value coerce it once.

Parameters that must remain non-array are preserved: pytree aux fields (total_count on the multinomial family, adj_matrix on CAR) are static metadata — coercing them to a jax.Array would make them tracers under jit/vmap and break sampling — so they are left un-coerced, as is a scipy-sparse adj_matrix (CAR(is_sparse=True)).

The coercion touches a parameter-assignment site in essentially every distribution, so correctness rests on tests that run across all distributions rather than on per-site review:

  • test_params_and_outputs_are_arrays — samples, moments, and stored (non-aux) parameters are jax.Array. This catches any data parameter the sweep missed.
  • test_aux_fields_are_not_jax_arrays — no pytree aux field is ever a jax.Array. This catches the opposite mistake: an aux field (e.g. total_count) wrongly coerced, which would break jit/vmap (for instance, a jitted Multinomial(...).sample(...) needs a concrete total_count).

The deliberate behavior change: validation is eager-only

Because parameters are coerced to arrays before Distribution.__init__ runs validate_args, an out-of-bounds constant parameter inside jit no longer raises. Previously (#775) it did, by keeping parameters as concrete (non-jax) values so the validity check could constant-fold. #775 added this so that a typo'd literal — Normal(0., -1.) in a jitted model — would fail loudly at trace time rather than silently produce NaNs. That is a real concern, so this is worth making the case for rather than glossing over.

The guarantee it provided was always narrow. It fired only for parameters that were compile-time constants; the moment the same bad value arrives through a traced argument (Normal(0., scale) with scale = -1.), validation cannot run — you can't branch on a tracer — and never has. So on master today, with the same validate_args=True and the same invalid value, whether you get an error depends on how the value was staged:

# constant invalid scale -> raises
jax.jit(lambda: dist.Normal(0.0, -1.0, validate_args=True).log_prob(0.5))()

# traced invalid scale -> silently does NOT raise (same bad value!)
jax.jit(lambda s: dist.Normal(0.0, s, validate_args=True).log_prob(0.5))(-1.0)

A safety net that catches -1.0 written as a literal but not -1.0 passed as an argument is not one a user can rely on without knowing how their parameters get staged.

The new behavior is simpler to state: validate_args=True validates eagerly and does nothing under jit. Eager validation is unchanged — invalid parameters raise exactly as before outside jit, which is where it is most useful for debugging (construct the distribution eagerly, or run once without jit, to surface the error).

This also lines up with how out-of-support sample validation already behaves. That check warns only when the support mask is concrete, so it too goes silent under jit whenever the sampled value is traced — which is the usual case, since in a model/MCMC/SVI the value fed to log_prob is the (traced) data. The only situation where it still fires under jit is a constant-support distribution with a constant value, the same accidental-constant-folding corner that parameter validation used to rely on. So under realistic jitted usage both parameter and sample validation are effectively eager-only.

Performance

Coercion is a jnp.asarray per parameter, but it is skipped for arguments that are already jax.Array (a cheap isinstance check instead of a dispatch), so the cost falls almost entirely on constructing from python scalars / numpy. Eager-construction microbenchmarks (min of repeated runs):

construction master branch
Normal(0., 1.) (python scalars) ~5.4 µs ~31 µs
Normal(jnp.asarray(0.), jnp.asarray(1.)) ~1.8 µs ~2.0 µs
Normal(jnp.zeros(100), 1.) ~9.7 µs ~10.1 µs
MultivariateNormal(jnp.zeros(5), jnp.eye(5)) ~30.7 µs ~30.6 µs

So constructing from jax arrays — the typical case in real code, and what happens under tracing — is unaffected. Only constructing from python-scalar literals pays the full conversion (Normal(0., 1.) ≈5×), and even then it is tens of microseconds. The jitted MCMC/SVI path is unchanged: the model and the distributions it builds are traced, so the coercion folds away (a jitted log-density constructing Normal + Beta is ~3.1 µs/call on both branches, with equal compile time).

Explicitly not in this PR

Enabling mypy enforcement for the distribution modules: the output annotations are now correct, but fully type-clean modules also require resolving pre-existing [override] signature mismatches and Optional handling, which are out of scope here.

Widening constructor inputs to ArrayLike — see Next steps.

Next steps

This PR fixes the output side of the convention (strict Array out). The input side is a natural follow-up: a number of constructor parameters are still annotated with the strict Array (e.g. CategoricalProbs.probs, Dirichlet.concentration, MatrixNormal.loc/scale_tril_*, LowRankMultivariateNormal.*, Distribution.mask), so passing a numpy array or a python scalar to them is a type error even though it works at runtime. These should be widened to ArrayLike (be liberal in what you accept).

Widening the annotations is the easy part; the work is that it surfaces constructor bodies that access .ndim/.shape/.reshape/etc. on a parameter before it is coerced to an array (those attributes are not defined on the scalar members of ArrayLike). So the follow-up is "widen the input annotations and coerce the parameters at the top of each affected constructor", done together. It is kept separate here to keep this PR focused on outputs and to give the input change its own review/CI cycle.

(CAR.adj_matrix is a deliberate exception — it also accepts a scipy.sparse matrix, so it needs a union rather than plain ArrayLike.)

Testing

test/test_distributions.py passes in full, including the new test_outputs_are_arrays and test_aux_fields_are_not_jax_arrays. The eager-only argument/sample validation behavior is covered by the existing validation tests, updated to reflect the new semantics.

@tillahoffmann tillahoffmann marked this pull request as draft June 8, 2026 21:57
@tillahoffmann tillahoffmann force-pushed the add-jax-array-output-test branch from ee74d90 to 2e2dcc1 Compare June 9, 2026 18:51
@tillahoffmann tillahoffmann marked this pull request as ready for review June 9, 2026 22:52
@juanitorduz

Copy link
Copy Markdown
Collaborator

Hi @tillahoffmann ! Coult we try adding these modified modules in https://github.com/pyro-ppl/numpyro/blob/master/pyproject.toml#L100 so that mypy runs the type checks? If we get many errors, we can tackle them in smaller PRs 🤞

@tillahoffmann

Copy link
Copy Markdown
Collaborator Author

Good idea, and yeah it'd need to be done carefully — mypy is a hard gate in CI (make lint runs mypy numpyro and fails the build on errors), not just informational, I think.

I tried adding the modified modules to the ignore_errors = false list locally: it surfaces ~224 errors, mostly pre-existing rather than from this PR — [override] signature mismatches (59), .shape/.ndim on ArrayLike unions (55), [arg-type]/[return-value] (66), etc.

So I think the path you're describing works best in the other order: drive each module's errors to 0 while it's still un-enforced, then flip it into the enforced list in that same PR (rather than adding it first, which would immediately red CI).

Happy to start chipping away at the modules in follow-ups — I've already got input-widening (ArrayLike params) started, which covers a good chunk of the union-attr ones.

@juanitorduz

Copy link
Copy Markdown
Collaborator

Great! @tillahoffmann ! Feel free to open issues and we can all help with this innitiative :)

@tillahoffmann

Copy link
Copy Markdown
Collaborator Author

Sounds good, how shall we proceed with this one?

@juanitorduz

Copy link
Copy Markdown
Collaborator

Sounds good, how shall we proceed with this one?

I would merge it, and then we could build on top of this and fix edge cases with mypy down the line :)

@tillahoffmann

Copy link
Copy Markdown
Collaborator Author

Sounds like a plan. I'll start tackling some of those after we merge this one—I think there might be a lot of merge conflicts otherwise.

@juanitorduz juanitorduz added the enhancement New feature or request label Jun 13, 2026
@tillahoffmann

tillahoffmann commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

@juanitorduz, @fehiepsi, @Qazalbash, are you happy for me to merge this?

@fehiepsi

Copy link
Copy Markdown
Member

Sorry I missed the scope of the CL. I'm not sure why we want to make this change. The jax best practice seems to apply for jax function. Those distribution methods could be flexible I guess. My main concern is the performance. It was intentional to have separate treatment for jax arrays and numpy arrays here and there.

I still don't understand why we need to enforce jax array outputs though. Could you elaborate?

@tillahoffmann

Copy link
Copy Markdown
Collaborator Author

Re-reading the diff, this PR conflated two separate claims, and your pushback applies to one — the other I think is worth keeping.

Parameters coerced to jax.Array (storage) — dropping this. You're right: it costs on python-scalar construction and overrides the deliberate numpy/jax separation.

Outputs annotated Array (return type) — the part worth keeping, and the answer to "why." It's for consumers writing typed code. The methods return ArrayLike, which includes bool/int/float/complex — not indexable, no .shape, and complex has no ordering. Under pyright, ordinary usage errors:

d = dist.Normal(jnp.zeros(5), 1.0)
d.mean[3]            # __getitem__ not defined on int/float/complex
d.variance.shape     # shape unknown on scalar members
d.mean < 0           # "<" not supported for complex

All clean once the return is Array. The return is the API boundary, and ArrayLike is an unusable promise to a caller — internals can stay as flexible as we like; only the boundary is normalized.

Revised PR: return-position ArrayLike → Array, plus jnp.asarray at the ~30 returns that hand back a stored parameter directly (e.g. Poisson.mean is self.rate). No constructor / promote_shapes / validation changes — the diff is signatures, those returns, and one test. Construction perf returns to master.

Does honest Array outputs with internals untouched work for you?

@fehiepsi

fehiepsi commented Jun 23, 2026

Copy link
Copy Markdown
Member

The type error in your example looks expected to me. The code assumes that the output is jax Array, which is unexpected (I understand that it might be expected to some users). My point is we need to use Array only for special distributions (the places that we know that they should be jax arrays). and keep ArrayLike for the rest.

@tillahoffmann

tillahoffmann commented Jun 23, 2026

Copy link
Copy Markdown
Collaborator Author

Makes sense — and I think we're almost there with narrowing outputs to Array anyway.

My main reason for wanting Array here is that ArrayLike isn't very meaningful as an output type. It's Array | np.ndarray | bool | int | float | …, so on Normal.sample() it advertises that the caller might get back a bare float, or a bool — for a method that always returns a jax array. It's not wrong (an array is an ArrayLike), but it does not constrain things and points the caller at cases that never happen. ArrayLike is the right type for an input, where the method really does accept any of those; on an output it just promises less than we actually deliver.

In practice we're most of the way there already. I parametrized the output-type check per method across the suite:

method already returns jax.Array
sample 97 / 98
log_prob 95 / 95
entropy 30 / 30
mean 62 / 81
variance 58 / 69

sample/log_prob/entropy are already jax.Array because they go through random.* / jnp / jax.scipy.special (except Delta.sample which returns the input unchanged). Likewise, ~80% of the moments are already Arrays.

The ones that don't are implementation-dependent rather than meaningful. Gamma is the clearest case:

@property
def mean(self):      return self.concentration / self.rate  # tracks input type
@property
def variance(self):  return self.concentration / jnp.power(self.rate, 2)  # jnp.power -> always Array

variance is always an Array and mean isn't because one line calls jnp and the other doesn't — same distribution, same params. From a consumer's side that's surprising: a type checker would reject dist.mean[0] but accept dist.variance[0] if we annotated the return type with what's actually returned. One jnp.asarray at those ~31 return sites is a smaller cost than that inconsistency, I think.

This would be return-only (in contrast to my earlier, overly broad attempt). Stored params and the numpy/jax separation stay exactly as they are (Poisson(2.0).rate is still 2.0), and it's only on cold moment properties, not sample/log_prob. So no hot-path or construction cost.

What do you think?

@Qazalbash Qazalbash left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree with this reasoning. ArrayLike makes sense for inputs because it reflects what callers are allowed to pass in, but for outputs it ends up advertising possibilities that never actually occur in practice. Returning jax.Array where we consistently produce jax.Array is a more useful and informative contract for users and type checkers.

The Gamma.mean vs Gamma.variance example illustrates the issue well; differences in output type arising purely from implementation details rather than semantics are surprising from a consumer perspective.

For cases where the output type currently follows from implementation details rather than intent, another option could be to prefer jax.numpy operators such as jnp.add, jnp.divide, jnp.multiply, etc., instead of Python operators. For example, Gamma.mean could use jnp.divide(self.concentration, self.rate), which would naturally align its behavior with Gamma.variance without requiring an explicit jnp.asarray at the return site. That said, I don't have a strong preference between the two approaches as long as the public API consistently returns jax.Array.

@fehiepsi

Copy link
Copy Markdown
Member

I understand that if users expect Array as the output, then type errors will happen. I don't expect Array to be the output though. It's unnecessary to support such behavior. Users should fix their code since we clearly annotate the output is ArrayLike.

As mentioned above, we can use Array for places that we know output is an Array. Using patterns like jnp.divide in the implementation will give Array output but in my opinion, using non-array output at those places is better. It is clearer to me that Normal(0,1) will have mean 0, not float 0 or Array 0.

Comment thread numpyro/distributions/conjugate.py Outdated
(self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
(self.concentration,) = promote_shapes(
concentration, shape=concentration_shape, promote_array=True
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why we need this promote?

Comment thread numpyro/distributions/conjugate.py Outdated
self.concentration, self.rate = promote_shapes(concentration, rate)
self.concentration, self.rate = promote_shapes(
concentration, rate, promote_array=True
)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Like above, we dont need to promote.

Comment thread numpyro/distributions/directional.py Outdated
def __init__(self, concentration: Array, *, validate_args: Optional[bool] = None):
assert jnp.ndim(concentration) >= 1
self.concentration = concentration
self.concentration = jnp.asarray(concentration)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We dont need this cast I think.

@fehiepsi

Copy link
Copy Markdown
Member

It would be fine to ensure mean and variance return jax Array since we already assume the semantic that they should have the expected shape. Those changes won't affect performance I guess. We might want to reconsider changes at promote logic though.

tillahoffmann and others added 2 commits June 24, 2026 10:26
test_output_is_array checks that sample/log_prob/mean/variance/entropy
return jax.Array rather than python/numpy scalars, parametrized per
method so the per-method coverage is visible. Methods whose output type
tracks the input (e.g. Poisson.mean is self.rate, Beta.mean is plain
python arithmetic) are recorded in OUTPUT_TYPE_XFAIL and marked strict
xfail, so the dict must be updated if any of them later returns an array.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
The distribution methods sample/log_prob/mean/variance/entropy return jax
arrays for every distribution exercised by test_output_is_array, so their
return type is narrowed from ArrayLike to Array. ArrayLike is the correct
type for an input (a method does accept python/numpy scalars), but on an
output it is an under-specified promise: it admits bool/int/float/complex
that the method never returns and that a consumer cannot index or reshape.

Input parameter annotations are left as ArrayLike. The implementation-
dependent tail recorded in OUTPUT_TYPE_XFAIL (e.g. Poisson.mean is
self.rate, Beta.mean is plain python arithmetic) keeps ArrayLike, since
those genuinely return a python/numpy scalar when given scalar params.

TransformedDistribution.sample keeps a `# type: ignore[return-value]`
rather than coercing, because Transform.__call__ is typed to return the
wide ArrayLike while the value is already a jax array at runtime.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@tillahoffmann

Copy link
Copy Markdown
Collaborator Author

Re-scoped, this PR is, for now, just type work — return annotations only, no constructor / promote_shapes / validation changes. Poisson(2.0).rate is still 2.0 and construction perf is back to master; the diff is signatures plus one parametrized test.

Concretely it narrows sample / log_prob / entropy and the ~80% of mean / variance that already return jax.Array to an Array return type. The handful of mean / variance that currently pass a python/numpy value straight through (e.g. Gamma.mean) I've left as ArrayLike for now, so every annotation is honest as it stands.

Sounds like we have consensus that those mean / variance returning Array is fine? I'm happy to promote those return types in this PR too — a one-line jnp.asarray at each of the ~30 sites — or keep this one narrow and do it as a follow-up.

I still think the promotion is the right contract, and the cost is negligible: I checked it directly. Under jit the jnp.asarray wrap constant-folds, with the jaxpr and optimized HLO byte-identical (no extra op), so inference sees zero difference. The only cost is a single scalar dispatch in eager (~tens of µs), never sample / log_prob or a per-step path. It's an Array under jit either way; promotion to Array would just make the eager return match what jitted code already produces.

@Qazalbash thanks for pointing that out re jnp.asarray(a / b) vs jnp.divide(a, b)! Maybe we can have a look at that trade-off if/when we promote the moments?

Under jax.jit every method output is materialised as a jax array, so the
moment exceptions in OUTPUT_TYPE_XFAIL only apply in the non-jitted case;
the jitted case asserts every method returns a jax array with no exceptions.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
@tillahoffmann tillahoffmann force-pushed the add-jax-array-output-test branch from bb03d8f to 776430d Compare June 24, 2026 16:50
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancement New feature or request

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants