Ensure distribution moments/samples are jax arrays #2206
|
Hi @tillahoffmann! Could we try adding these modified modules in https://github.com/pyro-ppl/numpyro/blob/master/pyproject.toml#L100 so that mypy runs the type checks? If we get many errors, we can tackle them in smaller PRs 🤞
|
Good idea, and yeah it'd need to be done carefully — mypy is a hard gate in CI ( I tried adding the modified modules to the So I think the path you're describing works best in the other order: drive each module's errors to 0 while it's still un-enforced, then flip it into the enforced list in that same PR (rather than adding it first, which would immediately red CI). Happy to start chipping away at the modules in follow-ups — I've already got input-widening ( |
|
Great, @tillahoffmann! Feel free to open issues and we can all help with this initiative :)
|
Sounds good, how shall we proceed with this one? |
I would merge it, and then we could build on top of this and fix edge cases with mypy down the line :) |
|
Sounds like a plan. I'll start tackling some of those after we merge this one—I think there might be a lot of merge conflicts otherwise. |
|
@juanitorduz, @fehiepsi, @Qazalbash, are you happy for me to merge this? |
|
Sorry I missed the scope of the CL. I'm not sure why we want to make this change. The jax best practice seems to apply for jax function. Those distribution methods could be flexible I guess. My main concern is the performance. It was intentional to have separate treatment for jax arrays and numpy arrays here and there. I still don't understand why we need to enforce jax array outputs though. Could you elaborate? |
|
Re-reading the diff, this PR conflated two separate claims, and your pushback applies to one; the other I think is worth keeping. Parameters coerced to Outputs annotated
d = dist.Normal(jnp.zeros(5), 1.0)
d.mean[3]         # __getitem__ not defined on int/float/complex
d.variance.shape  # shape unknown on scalar members
d.mean < 0        # "<" not supported for complex
All clean once the return is Revised PR: return-position Does honest |
|
The type error in your example looks expected to me. The code assumes that the output is jax Array, which is unexpected (I understand that it might be expected to some users). My point is we need to use Array only for special distributions (the places that we know that they should be jax arrays). and keep ArrayLike for the rest. |
|
Makes sense — and I think we're almost there with narrowing outputs to My main reason for wanting In practice we're most of the way there already. I parametrized the output-type check per method across the suite:
The ones that don't are implementation-dependent rather than meaningful.
@property
def mean(self):
    return self.concentration / self.rate  # tracks input type

@property
def variance(self):
    return self.concentration / jnp.power(self.rate, 2)  # jnp.power -> always Array
This would be return-only (in contrast to my earlier, overly broad attempt). Stored params and the numpy/jax separation stay exactly as they are ( What do you think? |
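The divergence between the two properties can be reproduced without numpyro. This is a minimal sketch, assuming a hypothetical `ToyGamma` class that stores its parameters exactly as passed; it is not numpyro's `Gamma`, and only the two expressions quoted in the comment above are taken from the thread.

```python
import jax
import jax.numpy as jnp


class ToyGamma:
    """Hypothetical stand-in that stores parameters without any coercion."""

    def __init__(self, concentration, rate):
        self.concentration = concentration
        self.rate = rate

    @property
    def mean(self):
        # Plain Python division: the result type follows the input types.
        return self.concentration / self.rate

    @property
    def variance(self):
        # jnp.power returns a jax.Array, so the quotient is a jax.Array too.
        return self.concentration / jnp.power(self.rate, 2)


scalar_params = ToyGamma(2.0, 4.0)
array_params = ToyGamma(jnp.asarray(2.0), jnp.asarray(4.0))

print(type(scalar_params.mean))      # Python float: tracks the scalar input
print(type(scalar_params.variance))  # a jax.Array, despite scalar input
print(type(array_params.mean))       # a jax.Array: tracks the array input
```

With scalar parameters the two moments of the same object come back as different types, which is the implementation-dependent behavior under discussion.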
Qazalbash
left a comment
I agree with this reasoning. ArrayLike makes sense for inputs because it reflects what callers are allowed to pass in, but for outputs it ends up advertising possibilities that never actually occur in practice. Returning jax.Array where we consistently produce jax.Array is a more useful and informative contract for users and type checkers.
The Gamma.mean vs Gamma.variance example illustrates the issue well; differences in output type arising purely from implementation details rather than semantics are surprising from a consumer perspective.
For cases where the output type currently follows from implementation details rather than intent, another option could be to prefer jax.numpy operators such as jnp.add, jnp.divide, jnp.multiply, etc., instead of Python operators. For example, Gamma.mean could use jnp.divide(self.concentration, self.rate), which would naturally align its behavior with Gamma.variance without requiring an explicit jnp.asarray at the return site. That said, I don't have a strong preference between the two approaches as long as the public API consistently returns jax.Array.
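The two options mentioned in this comment can be compared side by side. The snippet below is only a sketch on bare Python scalars (the variable names are illustrative, not numpyro attributes); it shows that both `jnp.divide` and an explicit `jnp.asarray` at the return site yield a `jax.Array`, while the bare operator keeps the Python type.

```python
import jax
import jax.numpy as jnp

concentration, rate = 2.0, 4.0  # Python scalars, as a caller might pass them

# Option 1: use the jax.numpy function instead of the Python operator.
mean_via_divide = jnp.divide(concentration, rate)

# Option 2: keep the Python operator and coerce once at the return site.
mean_via_asarray = jnp.asarray(concentration / rate)

# For comparison: the bare operator leaves the result as a Python float.
mean_plain = concentration / rate

print(type(mean_via_divide), type(mean_via_asarray), type(mean_plain))
```

Either option gives the same value; the choice is mostly about whether the coercion should be visible at the return site or implied by the operator used.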
|
I understand that if users expect Array as the output, then type errors will happen. I don't expect Array to be the output though. It's unnecessary to support such behavior. Users should fix their code since we clearly annotate the output is ArrayLike. As mentioned above, we can use Array for places that we know output is an Array. Using patterns like jnp.divide in the implementation will give Array output but in my opinion, using non-array output at those places is better. It is clearer to me that Normal(0,1) will have mean 0, not float 0 or Array 0. |
- (self.concentration,) = promote_shapes(concentration, shape=concentration_shape)
+ (self.concentration,) = promote_shapes(
+     concentration, shape=concentration_shape, promote_array=True
+ )

- self.concentration, self.rate = promote_shapes(concentration, rate)
+ self.concentration, self.rate = promote_shapes(
+     concentration, rate, promote_array=True
+ )
Like above, we don't need to promote.
  def __init__(self, concentration: Array, *, validate_args: Optional[bool] = None):
      assert jnp.ndim(concentration) >= 1
-     self.concentration = concentration
+     self.concentration = jnp.asarray(concentration)
We don't need this cast, I think.
|
It would be fine to ensure mean and variance return jax Array since we already assume the semantic that they should have the expected shape. Those changes won't affect performance I guess. We might want to reconsider changes at promote logic though. |
test_output_is_array checks that sample/log_prob/mean/variance/entropy return jax.Array rather than python/numpy scalars, parametrized per method so the per-method coverage is visible. Methods whose output type tracks the input (e.g. Poisson.mean is self.rate, Beta.mean is plain python arithmetic) are recorded in OUTPUT_TYPE_XFAIL and marked strict xfail, so the dict must be updated if any of them later returns an array.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
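The strict-xfail bookkeeping described in this commit message could look roughly like the sketch below. It is an illustration only: the `MOMENTS` table, its two toy callables, and the `_param` helper are hypothetical, and the real test iterates over numpyro's distributions rather than lambdas. Only the names `test_output_is_array` and `OUTPUT_TYPE_XFAIL` come from the commit message.

```python
import jax
import jax.numpy as jnp
import pytest

# Hypothetical toy "moments" standing in for distribution methods.
MOMENTS = {
    # Returns its argument unchanged, so the output type tracks the input.
    "passthrough_mean": lambda rate: rate,
    # Goes through a jax.numpy op, so the output is a jax.Array.
    "power_variance": lambda rate: 1.0 / jnp.power(rate, 2),
}

# Entries expected NOT to return a jax.Array when given a Python scalar.
OUTPUT_TYPE_XFAIL = {"passthrough_mean": "returns the input unchanged"}


def _param(name):
    """Attach a strict xfail mark to entries listed in OUTPUT_TYPE_XFAIL."""
    marks = []
    if name in OUTPUT_TYPE_XFAIL:
        # strict=True turns an unexpected pass into a test failure,
        # which forces the dict to be updated when behavior changes.
        marks.append(pytest.mark.xfail(reason=OUTPUT_TYPE_XFAIL[name], strict=True))
    return pytest.param(name, marks=marks, id=name)


@pytest.mark.parametrize("name", [_param(n) for n in MOMENTS])
def test_output_is_array(name):
    assert isinstance(MOMENTS[name](2.0), jax.Array)
```

Because the xfail is strict, an entry that starts returning an array shows up as a failing test rather than passing silently, which keeps the exception list in step with the code.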
The distribution methods sample/log_prob/mean/variance/entropy return jax arrays for every distribution exercised by test_output_is_array, so their return type is narrowed from ArrayLike to Array.

ArrayLike is the correct type for an input (a method does accept python/numpy scalars), but on an output it is an under-specified promise: it admits bool/int/float/complex that the method never returns and that a consumer cannot index or reshape. Input parameter annotations are left as ArrayLike.

The implementation-dependent tail recorded in OUTPUT_TYPE_XFAIL (e.g. Poisson.mean is self.rate, Beta.mean is plain python arithmetic) keeps ArrayLike, since those genuinely return a python/numpy scalar when given scalar params.

TransformedDistribution.sample keeps a `# type: ignore[return-value]` rather than coercing, because Transform.__call__ is typed to return the wide ArrayLike while the value is already a jax array at runtime.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
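The "wide on input, narrow on output" annotation pattern this commit applies can be illustrated on a free function. This is a hedged sketch: `standard_normal_log_prob` is a hypothetical helper written for this example, not a numpyro function, and it coerces its input only so that the `Array` return annotation holds for scalar arguments.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


def standard_normal_log_prob(value: ArrayLike) -> Array:
    """Log-density of N(0, 1); accepts Python scalars, numpy arrays, or jax arrays."""
    value = jnp.asarray(value)  # coerce once so the Array return type is honest
    return -0.5 * jnp.square(value) - 0.5 * jnp.log(2.0 * jnp.pi)


out = standard_normal_log_prob(0.0)  # Python scalar in
print(out.shape)                     # .shape is defined because the result is an Array
```

A caller can pass a Python scalar, yet a type checker still knows the result supports `.shape`, indexing, and reshaping.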
|
Re-scoped, this PR is, for now, just type work — return annotations only, no constructor / Concretely it narrows Sounds like we have consensus that those I still think the promotion is the right contract, and the cost is negligible: I checked it directly. Under @Qazalbash thanks for pointing that out re |
Under jax.jit every method output is materialised as a jax array, so the moment exceptions in OUTPUT_TYPE_XFAIL only apply in the non-jitted case; the jitted case asserts every method returns a jax array with no exceptions.

Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
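The difference between the eager and jitted cases can be seen with a function that simply returns its argument. This is a small sketch with a hypothetical `passthrough_mean` standing in for a moment that returns a stored parameter.

```python
import jax


def passthrough_mean(rate):
    # Mirrors a moment that simply returns its stored parameter.
    return rate


eager_out = passthrough_mean(2.0)            # stays a Python float
jitted_out = jax.jit(passthrough_mean)(2.0)  # jit materialises the output as a jax.Array

print(type(eager_out), type(jitted_out))
```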
Summary
This PR makes numpyro distributions return concrete jax.Array from their public methods (mean, variance, entropy, sample, log_prob, cdf, icdf, …) and annotates them accordingly, so that typed downstream code can actually use distribution outputs. To make that true at runtime, parameters are coerced to jax.Array at construction.

The change aligns numpyro with JAX's own typing best practices (wide on input with ArrayLike, strict on output with Array) and removes a class of type errors that typed consumers will hit once they rely on numpyro's types.

The near-term, standalone win is the parameter coercion: today whether a stored parameter (self.loc, self.rate, …) is a jax.Array, a numpy array, or a python scalar depends on the input; this PR makes it consistently a jax.Array. The output annotations are the cheap, correct-by-construction follow-through; they become observable to consumers once numpyro ships py.typed (a deliberate follow-up, see below).

This is necessarily runtime work, not just an annotation change: many moments return a stored parameter directly (Poisson.mean is self.rate), so an Array return type is only honest if that parameter is actually an array. Typing the outputs without making the runtime match would just lie to the type checker.

It also changes one behavior deliberately: argument validation (validate_args=True) is now an eager-only operation. That is the part most worth discussing, so it is called out explicitly below.

Motivation: the output annotations are wrong, and it bites consumers
Public methods were annotated to return jax.typing.ArrayLike.

ArrayLike is a deliberately wide input union. JAX's typing guidance is explicit that it is the right type for function inputs (be liberal in what you accept) and that outputs should be annotated Array (be strict in what you return). numpyro applied ArrayLike to outputs too, which is incorrect: mean/sample/log_prob/… (should) always produce a concrete array.

The cost lands on every typed consumer. The scalar members of the union (bool/int/float/complex) support neither indexing, .shape, .reshape, nor assignment to an Array parameter, so ordinary usage fails type checking.

Under pyright, every such use is a type error when the method returns ArrayLike (each indexing/attribute access fails once per scalar member of the union: int, float, complex, …), and all of them type-check cleanly once it returns Array.

This is currently latent (numpyro ships no py.typed, so checkers treat it as untyped Any), but it is a trap: the moment a consumer relies on numpyro's types, or numpyro starts enforcing its own, these annotations break real code.

The benefit here is for downstream typed code. This PR does not aim to make numpyro's own modules type-clean: they are not mypy-enforced, and their residual errors are dominated by pre-existing [override] signature mismatches unrelated to these annotations (enabling enforcement is called out as out of scope below).

What this PR does
Coerce parameters to jax.Array at construction. Previously parameter storage was inconsistent: most parameters flow through promote_shapes, which preserved input types and only produced arrays when a broadcast forced a reshape, while some distributions stored the raw argument directly, so whether self.loc ended up an array depended on the input shapes and dtype. Now promote_shapes takes an opt-in promote_array=True (typed to return list[Array]) used at the parameter-assignment sites, and the few directly-assigned parameters are wrapped in jnp.asarray. As a result self.loc, self.scale, … are statically and dynamically jax.Array.

This is runtime work, not just annotation, on purpose: annotating the outputs as Array without actually returning arrays would be a lie to the type checker. Many moments simply return a stored parameter (Poisson.mean is self.rate), so the annotation is only honest if the parameter is genuinely an array. Typing that isn't backed by the runtime is worse than no typing.

Annotate outputs as Array. Return-position ArrayLike becomes Array across the distribution modules. Constructor parameters and method arguments (value, q) keep ArrayLike, so callers can still pass python scalars / numpy arrays.

Make instance types reachable through the metaclass. DistributionMeta.__call__ is annotated to return the concrete class being constructed, so dist.Normal(...) is typed Normal (not Any, and not the base Distribution) and the -> Array annotations on its methods are visible to consumers.

Use functional array ops on inputs where appropriate. Where a method only needed a shape it now uses jnp.shape(value) rather than value.shape, which is correct for any ArrayLike (and incidentally more robust). A few methods that genuinely index value coerce it once.

Parameters that must remain non-array are preserved: pytree aux fields (total_count on the multinomial family, adj_matrix on CAR) are static metadata. Coercing them to a jax.Array would make them tracers under jit/vmap and break sampling, so they are left un-coerced, as is a scipy-sparse adj_matrix (CAR(is_sparse=True)).

The coercion touches a parameter-assignment site in essentially every distribution, so correctness rests on tests that run across all distributions rather than on per-site review:

test_params_and_outputs_are_arrays: samples, moments, and stored (non-aux) parameters are jax.Array. This catches any data parameter the sweep missed.

test_aux_fields_are_not_jax_arrays: no pytree aux field is ever a jax.Array. This catches the opposite mistake: an aux field (e.g. total_count) wrongly coerced, which would break jit/vmap (for instance, a jitted Multinomial(...).sample(...) needs a concrete total_count).

The deliberate behavior change: validation is eager-only
Because parameters are coerced to arrays before Distribution.__init__ runs validate_args, an out-of-bounds constant parameter inside jit no longer raises. Previously (#775) it did, by keeping parameters as concrete (non-jax) values so the validity check could constant-fold. #775 added this so that a typo'd literal, such as Normal(0., -1.) in a jitted model, would fail loudly at trace time rather than silently produce NaNs. That is a real concern, so this is worth making the case for rather than glossing over.

The guarantee it provided was always narrow. It fired only for parameters that were compile-time constants; the moment the same bad value arrives through a traced argument (Normal(0., scale) with scale = -1.), validation cannot run (you can't branch on a tracer) and never has. So on master today, with the same validate_args=True and the same invalid value, whether you get an error depends on how the value was staged.

A safety net that catches -1.0 written as a literal but not -1.0 passed as an argument is not one a user can rely on without knowing how their parameters get staged.

The new behavior is simpler to state: validate_args=True validates eagerly and does nothing under jit. Eager validation is unchanged: invalid parameters raise exactly as before outside jit, which is where it is most useful for debugging (construct the distribution eagerly, or run once without jit, to surface the error).

This also lines up with how out-of-support sample validation already behaves. That check warns only when the support mask is concrete, so it too goes silent under jit whenever the sampled value is traced, which is the usual case, since in a model/MCMC/SVI the value fed to log_prob is the (traced) data. The only situation where it still fires under jit is a constant-support distribution with a constant value, the same accidental-constant-folding corner that parameter validation used to rely on. So under realistic jitted usage both parameter and sample validation are effectively eager-only.

Performance
Coercion is a jnp.asarray per parameter, but it is skipped for arguments that are already jax.Array (a cheap isinstance check instead of a dispatch), so the cost falls almost entirely on constructing from python scalars / numpy. Eager-construction microbenchmarks (min of repeated runs) covered these cases:

Normal(0., 1.) (python scalars)
Normal(jnp.asarray(0.), jnp.asarray(1.))
Normal(jnp.zeros(100), 1.)
MultivariateNormal(jnp.zeros(5), jnp.eye(5))

So constructing from jax arrays (the typical case in real code, and what happens under tracing) is unaffected. Only constructing from python-scalar literals pays the full conversion (Normal(0., 1.) ≈5×), and even then it is tens of microseconds. The jitted MCMC/SVI path is unchanged: the model and the distributions it builds are traced, so the coercion folds away (a jitted log-density constructing Normal+Beta is ~3.1 µs/call on both branches, with equal compile time).

Explicitly not in this PR
Enabling mypy enforcement for the distribution modules: the output annotations are now correct, but fully type-clean modules also require resolving pre-existing [override] signature mismatches and Optional handling, which are out of scope here.

Widening constructor inputs to ArrayLike: see Next steps.

Next steps
This PR fixes the output side of the convention (strict Array out). The input side is a natural follow-up: a number of constructor parameters are still annotated with the strict Array (e.g. CategoricalProbs.probs, Dirichlet.concentration, MatrixNormal.loc/scale_tril_*, LowRankMultivariateNormal.*, Distribution.mask), so passing a numpy array or a python scalar to them is a type error even though it works at runtime. These should be widened to ArrayLike (be liberal in what you accept).

Widening the annotations is the easy part; the work is that it surfaces constructor bodies that access .ndim/.shape/.reshape/etc. on a parameter before it is coerced to an array (those attributes are not defined on the scalar members of ArrayLike). So the follow-up is "widen the input annotations and coerce the parameters at the top of each affected constructor", done together. It is kept separate here to keep this PR focused on outputs and to give the input change its own review/CI cycle.

(CAR.adj_matrix is a deliberate exception: it also accepts a scipy.sparse matrix, so it needs a union rather than plain ArrayLike.)

Testing
test/test_distributions.py passes in full, including the new test_outputs_are_arrays and test_aux_fields_are_not_jax_arrays. The eager-only argument/sample validation behavior is covered by the existing validation tests, updated to reflect the new semantics.
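The staging dependence that the description attributes to the previous behavior (a literal is checked under jit, a traced argument is not) can be sketched with a toy validator. This is a hedged illustration: `checked_scale`, `literal_model`, `argument_model`, and `raises_value_error` are hypothetical names, and the code does not reproduce numpyro's validate_args machinery. It only shows why a Python-level check can see a literal but not a traced argument.

```python
import jax


def checked_scale(scale):
    """Hypothetical validator: checks only values that are concrete at call time."""
    if isinstance(scale, jax.core.Tracer):
        return scale  # cannot branch on a traced value, so validation is skipped
    if not scale > 0:
        raise ValueError("scale must be positive")
    return scale


@jax.jit
def literal_model(x):
    return x / checked_scale(-1.0)  # literal: still a Python float while tracing


@jax.jit
def argument_model(x, scale):
    return x / checked_scale(scale)  # argument: a tracer while tracing


def raises_value_error(fn, *args):
    """Return True if calling fn(*args) raises ValueError."""
    try:
        fn(*args)
    except ValueError:
        return True
    return False


literal_raised = raises_value_error(literal_model, 1.0)
argument_raised = raises_value_error(argument_model, 1.0, -1.0)
print(literal_raised, argument_raised)
```

The same invalid value is rejected in one function and accepted in the other purely because of how it reaches the check, which is the inconsistency the eager-only rule removes.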