From 7d1cb47a761cb024a7a13a348748d36f487cc081 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Mon, 25 May 2026 21:25:28 +0200 Subject: [PATCH 01/15] fix validation context --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7b12879eb..ac5b51914 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -155,7 +155,7 @@ class Distribution(metaclass=DistributionMeta): arg_constraints: dict[str, Any] = {} _support: Optional[constraints.Constraint] = None - _validate_args: bool = False + _validate_args: bool = _VALIDATION_ENABLED _arg_names: ClassVar[Optional[tuple[str, ...]]] = None pytree_data_fields: tuple[str, ...] = () pytree_aux_fields: tuple[str, ...] = ("_batch_shape", "_event_shape") From c3f06a8b150a9792376fd79cd47dc957a9d87064 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Mon, 25 May 2026 21:34:01 +0200 Subject: [PATCH 02/15] set default to False --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ac5b51914..6b5cdf3b5 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -64,7 +64,7 @@ from . import constraints -_VALIDATION_ENABLED = True +_VALIDATION_ENABLED = False def enable_validation(is_validate: bool = True) -> None: From c66e2fbb034fde8df928c50313f76da4ce4f118b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Tue, 26 May 2026 13:25:56 +0200 Subject: [PATCH 03/15] set default to True --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 6b5cdf3b5..ac5b51914 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -64,7 +64,7 @@ from . import constraints -_VALIDATION_ENABLED = False +_VALIDATION_ENABLED = True def enable_validation(is_validate: bool = True) -> None: From 1c35a7c018f8a7c727c5978c5d3edba1a30df3c7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Tue, 26 May 2026 15:45:16 +0200 Subject: [PATCH 04/15] empty From 955d50f63006992b562be1309a75d089c81d6443 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 11:22:37 +0200 Subject: [PATCH 05/15] Revert "set default to True" This reverts commit c66e2fbb034fde8df928c50313f76da4ce4f118b. --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index ac5b51914..6b5cdf3b5 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -64,7 +64,7 @@ from . import constraints -_VALIDATION_ENABLED = True +_VALIDATION_ENABLED = False def enable_validation(is_validate: bool = True) -> None: From f853fc8bc889a0e0d551ad0feb9d50d3ba8a77ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 11:52:43 +0200 Subject: [PATCH 06/15] Reapply "set default to True" This reverts commit 955d50f63006992b562be1309a75d089c81d6443. --- numpyro/distributions/distribution.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 6b5cdf3b5..ac5b51914 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -64,7 +64,7 @@ from . import constraints -_VALIDATION_ENABLED = False +_VALIDATION_ENABLED = True def enable_validation(is_validate: bool = True) -> None: From d0d3e4b005fbc08eea3453f16912a1d3e1cec414 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 14:40:43 +0200 Subject: [PATCH 07/15] fix test_get_dependencies --- test/infer/test_inspect.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/infer/test_inspect.py b/test/infer/test_inspect.py index 675b27f5b..19182beaf 100644 --- a/test/infer/test_inspect.py +++ b/test/infer/test_inspect.py @@ -18,7 +18,7 @@ class NonreparameterizedNormal(dist.Normal): def test_get_dependencies(): def model(data): a = numpyro.sample("a", dist.Normal(0, 1)) - b = numpyro.sample("b", NonreparameterizedNormal(a, 0)) + b = numpyro.sample("b", NonreparameterizedNormal(a, 1)) c = numpyro.sample("c", dist.Normal(b, 1)) d = numpyro.sample("d", dist.Normal(a, jnp.exp(c))) From 2eb84d4102e286bcce555e1d819980a5a557e3cc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 14:41:12 +0200 Subject: [PATCH 08/15] fix categorical_probs --- test/infer/test_infer_util.py | 3 +-- test/ops/test_indexing.py | 4 ++-- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 364a146dc..1207311ef 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -81,9 +81,8 @@ def model(X, y=None): def categorical_probs(): - probs0 = 0.5 nbatch0, nbatch1 = 2, 1 - probs = jnp.ones((nbatch0, nbatch1, 3)) * probs0 + probs = jnp.ones((nbatch0, nbatch1, 3)) / 3 def model(probs): probs = numpyro.deterministic("probs", probs) diff --git a/test/ops/test_indexing.py b/test/ops/test_indexing.py index fb8016319..2809d68b3 100644 --- a/test/ops/test_indexing.py +++ b/test/ops/test_indexing.py @@ -107,8 +107,8 @@ def test_shape(expression, expected_shape): @pytest.mark.parametrize("x_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) def test_value(x_shape, i_shape, j_shape, event_shape): x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape))) - i = dist.Categorical(jnp.ones((5,))).sample(random.key(1), i_shape) - j = dist.Categorical(jnp.ones((6,))).sample(random.key(2), j_shape) + i = dist.Categorical(jnp.ones((5,)) / 5).sample(random.key(1), i_shape) + j = dist.Categorical(jnp.ones((6,)) / 6).sample(random.key(2), j_shape) if event_shape: actual = Vindex(x)[..., i, j, :] else: From cde151164ca9c9cc1886dfe07ad854052bdaf221 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 16:10:38 +0200 Subject: [PATCH 09/15] better params and disable validation for contributed examples --- test/contrib/hsgp/test_approximation.py | 108 +++++++++++++----------- test/contrib/test_enum_elbo.py | 72 ++++++++++------ test/infer/test_autoguide.py | 6 +- test/infer/test_mcmc.py | 2 +- test/test_distributions.py | 4 +- 5 files changed, 111 insertions(+), 81 deletions(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 4aea026ff..29b621d6c 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -400,14 +400,17 @@ def latent_gp(x, alpha, length, ell, m, non_centered): ) def model(x, ell, m, non_centered, y=None): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="se", divider="::")( - x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered - ) - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + + with numpyro.validation_enabled(False): + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="se", divider="::")( + x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered + ) + + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -464,20 +467,23 @@ def latent_gp(x, nu, alpha, length, ell, m, non_centered): ) def model(x, nu, ell, m, non_centered, y=None): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="matern", divider="::")( - x=x, - nu=nu, - alpha=alpha, - length=length, - ell=ell, - m=m, - non_centered=non_centered, - ) - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + + with numpyro.validation_enabled(False): + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="matern", divider="::")( + x=x, + nu=nu, + alpha=alpha, + length=length, + ell=ell, + m=m, + non_centered=non_centered, + ) + + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -679,20 +685,23 @@ def latent_gp(x, alpha, length, scale_mixture, ell, m, non_centered): ) def model(x, scale_mixture, ell, m, non_centered, y=None): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="rq", divider="::")( - x=x, - alpha=alpha, - length=length, - scale_mixture=scale_mixture, - ell=ell, - m=m, - non_centered=non_centered, - ) - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + + with numpyro.validation_enabled(False): + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="rq", divider="::")( + x=x, + alpha=alpha, + length=length, + scale_mixture=scale_mixture, + ell=ell, + m=m, + non_centered=non_centered, + ) + + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -736,18 +745,21 @@ def latent_gp(x, alpha, length, w0, m): ) def model(x, w0, m, y=None): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="periodic", divider="::")( - x=x, - alpha=alpha, - length=length, - w0=w0, - m=m, - ) - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + + with numpyro.validation_enabled(False): + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="periodic", divider="::")( + x=x, + alpha=alpha, + length=length, + w0=w0, + m=m, + ) + + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data model_trace = trace(seed(model, random.key(0))).get_trace(x, w0, m, y_obs) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 6bd5b5010..77377c03f 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -772,18 +772,22 @@ def auto_model(data, params): params["model_probs_e"], constraint=constraints.simplex, ) - a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample( - "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} - ) - with pyro.plate("data", 2): - c = pyro.sample("c", dist.Categorical(probs_c[a])) - d = pyro.sample( - "d", - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}, + + # Categorical distributions are not normalized here + with numpyro.validation_enabled(False): + + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} ) - pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) + with pyro.plate("data", 2): + c = pyro.sample("c", dist.Categorical(probs_c[a])) + d = pyro.sample( + "d", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) + pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) @handlers.scale(scale=scale) def auto_guide(data, params): @@ -793,9 +797,13 @@ def auto_guide(data, params): probs_c = pyro.param( "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) - a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) - with pyro.plate("data", 2): - pyro.sample("c", dist.Categorical(probs_c[a])) + + # Categorical distributions are not normalized here + with numpyro.validation_enabled(False): + + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + with pyro.plate("data", 2): + pyro.sample("c", dist.Categorical(probs_c[a])) @handlers.scale(scale=scale) def hand_model(data, params): @@ -820,18 +828,22 @@ def hand_model(data, params): params["model_probs_e"], constraint=constraints.simplex, ) - a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample( - "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} - ) - for i in range(2): - c = pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) - d = pyro.sample( - f"d_{i}", - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}, + + # Categorical distributions are not normalized here + with numpyro.validation_enabled(False): + + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} ) - pyro.sample(f"obs_{i}", dist.Categorical(probs_e[d]), obs=data[i]) + for i in range(2): + c = pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + d = pyro.sample( + f"d_{i}", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) + pyro.sample(f"obs_{i}", dist.Categorical(probs_e[d]), obs=data[i]) @handlers.scale(scale=scale) def hand_guide(data, params): @@ -841,9 +853,13 @@ def hand_guide(data, params): probs_c = pyro.param( "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) - a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) - for i in range(2): - pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + + # Categorical distributions are not normalized here + with numpyro.validation_enabled(False): + + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + for i in range(2): + pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) data = jnp.array([0, 0]) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index d7bae8dff..21710b405 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -404,8 +404,10 @@ def model(x, y): a = numpyro.sample("a", dist.Normal(0, 10)) b = numpyro.sample("b", dist.Normal(0, 10).expand([3]).to_event()) mu = a + b[0] * x + b[1] * x**2 + b[2] * x**3 - with numpyro.plate("N", len(x)): - numpyro.sample("y", dist.Normal(mu, 0.00001), obs=y) + + with numpyro.validation_enabled(False): + with numpyro.plate("N", len(x)): + numpyro.sample("y", dist.Normal(mu, 0.00001), obs=y) x = random.normal(random.key(0), (3,)) y = 1 + 2 * x + 3 * x**2 + 4 * x**3 diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 39637e23d..e7cea8814 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -738,7 +738,7 @@ def test_functional_map(algo, map_fn): pytest.skip("pmap test requires device_count greater than 1.") true_mean, true_std = 1.0, 2.0 - num_warmup, num_samples = 1000, 8000 + num_warmup, num_samples = 2000, 10_000 def potential_fn(z): return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2) diff --git a/test/test_distributions.py b/test/test_distributions.py index 256cb1f82..19d1272b5 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -2044,8 +2044,8 @@ def test_beta_binomial_log_prob(total_count, shape): @pytest.mark.parametrize("n", [1, 2, 5, 10]) @pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)]) def test_beta_negative_binomial_log_prob(n, shape): - concentration0 = np.exp(np.random.normal(size=shape)) - concentration1 = np.exp(np.random.normal(size=shape)) + concentration0 = 1 + np.exp(np.random.normal(size=shape)) + concentration1 = 1 + np.exp(np.random.normal(size=shape)) value = jnp.arange(15) num_samples = 300000 From e6550eaa1ccc44a5b33b72e8ab48fc635f658183 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Simon=20Dupourqu=C3=A9?= <49200287+renecotyfanboy@users.noreply.github.com> Date: Fri, 29 May 2026 17:01:51 +0200 Subject: [PATCH 10/15] ruff format --- test/contrib/hsgp/test_approximation.py | 4 ---- test/contrib/test_enum_elbo.py | 12 ++++++------ 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 29b621d6c..7777b3ce9 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -400,7 +400,6 @@ def latent_gp(x, alpha, length, ell, m, non_centered): ) def model(x, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) @@ -467,7 +466,6 @@ def latent_gp(x, nu, alpha, length, ell, m, non_centered): ) def model(x, nu, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) @@ -685,7 +683,6 @@ def latent_gp(x, alpha, length, scale_mixture, ell, m, non_centered): ) def model(x, scale_mixture, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) @@ -745,7 +742,6 @@ def latent_gp(x, alpha, length, w0, m): ) def model(x, w0, m, y=None): - with numpyro.validation_enabled(False): alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 77377c03f..dc22029f7 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -775,7 +775,6 @@ def auto_model(data, params): # Categorical distributions are not normalized here with numpyro.validation_enabled(False): - a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample( "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} @@ -800,8 +799,9 @@ def auto_guide(data, params): # Categorical distributions are not normalized here with numpyro.validation_enabled(False): - - a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + a = pyro.sample( + "a", dist.Categorical(probs_a), infer={"enumerate": "parallel"} + ) with pyro.plate("data", 2): pyro.sample("c", dist.Categorical(probs_c[a])) @@ -831,7 +831,6 @@ def hand_model(data, params): # Categorical distributions are not normalized here with numpyro.validation_enabled(False): - a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample( "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} @@ -856,8 +855,9 @@ def hand_guide(data, params): # Categorical distributions are not normalized here with numpyro.validation_enabled(False): - - a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + a = pyro.sample( + "a", dist.Categorical(probs_a), infer={"enumerate": "parallel"} + ) for i in range(2): pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) From 572dfccc922af76fd57244a09d699edde49f2fa7 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 5 Jun 2026 11:20:37 +0200 Subject: [PATCH 11/15] Fix tests for validation enabled by default With validation defaulting to True, fix the remaining failures using valid parameters where possible, and skip validation only where a test or library path intentionally uses invalid / off-manifold values: - test_distribution_constraints: build out-of-bounds params under validation_enabled(False) (the kwarg does not reach compound dists' internals). - test_log_prob_gradient / test_sample_gradient: scope validation off around the finite-difference reference, which perturbs params off their constraint manifold. - AutoLaplaceApproximation.get_posterior: build its MVN with validate_args=False (singular-Hessian fallback intentionally zeroes scale_tril). - EulerMaruyama / GaussianRandomWalk log_prob: internal Normals built from the input value use validate_args=False; the public log_prob still validates value. - HSGP centered approximation: Normal(0, spd) with validate_args=False since spd underflows to 0 for high-frequency basis functions. - compat.param: project a concrete constraint-violating init onto its constraint (matches Pyro semantics), fixing the unnormalized-simplex pyroapi case. - enum_elbo: fix unnormalized guide_probs_a; drop validation_enabled(False) wraps. - HSGP tests: drop validation_enabled(False) wraps. - test_subposterior_structure: use non-degenerate subposteriors. Co-Authored-By: Claude Opus 4.8 --- numpyro/compat/pyro.py | 19 ++++- numpyro/contrib/hsgp/approximation.py | 7 +- numpyro/distributions/continuous.py | 16 +++- numpyro/infer/autoguide.py | 7 +- test/contrib/hsgp/test_approximation.py | 104 +++++++++++------------- test/contrib/test_enum_elbo.py | 70 +++++++--------- test/infer/test_autoguide.py | 6 +- test/infer/test_hmc_util.py | 7 +- test/test_distributions.py | 24 ++++-- 9 files changed, 147 insertions(+), 113 deletions(-) diff --git a/numpyro/compat/pyro.py b/numpyro/compat/pyro.py index 18a443748..916bb6785 100644 --- a/numpyro/compat/pyro.py +++ b/numpyro/compat/pyro.py @@ -3,8 +3,11 @@ import warnings +import numpy as np + from numpyro.compat.util import UnsupportedAPIWarning -from numpyro.util import find_stack_level +from numpyro.distributions.transforms import biject_to +from numpyro.util import find_stack_level, not_jax_tracer from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip from numpyro.primitives import param as _param # noqa: F401 isort:skip @@ -38,4 +41,18 @@ def param(name, *args, **kwargs): param_store = get_param_store() if name in param_store: val = param_store[name] + # Match Pyro's constrained-param semantics: an init value is treated as a + # constrained value and projected onto its constraint. NumPyro returns the + # raw init at trace time, so project it here when it is concrete and does not + # already satisfy the constraint (e.g. an unnormalized simplex value). Valid + # values (including boundary points) and tracers are left untouched. + constraint = kwargs.get("constraint") + if ( + val is not None + and constraint is not None + and not_jax_tracer(val) + and not np.all(constraint(val)) + ): + transform = biject_to(constraint) + val = transform(transform.inv(val)) return val diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py index fe075142f..045817e95 100644 --- a/numpyro/contrib/hsgp/approximation.py +++ b/numpyro/contrib/hsgp/approximation.py @@ -31,7 +31,12 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array: def _centered_approximation(phi: Array, spd: Array, m: int) -> Array: with numpyro.plate("basis", m): - beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) + # ``spd`` is the square root of the spectral density and can underflow to + # exactly 0 for high-frequency basis functions, yielding a benign + # zero-variance (degenerate) coefficient. Skip validation for that edge. + beta = numpyro.sample( + "beta", dist.Normal(loc=0.0, scale=spd, validate_args=False) + ) return phi @ beta diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 91c504ad8..12372967c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -654,7 +654,12 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: mu = xtm1 + dt * f sigma = jnp.sqrt(dt) * g - sde_log_prob = Normal(mu, sigma).to_event(self.event_dim).log_prob(xt) + # ``mu``/``sigma`` are derived from ``value``; out-of-support values + # would make them invalid, so skip validation of this internal + # distribution. The public log_prob already validates ``value``. + sde_log_prob = ( + Normal(mu, sigma, validate_args=False).to_event(self.event_dim).log_prob(xt) + ) init_log_prob = self.init_dist.log_prob(value0) return sde_log_prob + init_log_prob @@ -1094,9 +1099,14 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - init_prob = Normal(0.0, self.scale).log_prob(value[..., 0]) + # the step distribution uses ``value`` as its location, so out-of-support + # values would make it invalid; skip validation of these internal + # distributions since the public log_prob already validates ``value``. + init_prob = Normal(0.0, self.scale, validate_args=False).log_prob(value[..., 0]) scale = jnp.expand_dims(self.scale, -1) - step_probs = Normal(value[..., :-1], scale).log_prob(value[..., 1:]) + step_probs = Normal(value[..., :-1], scale, validate_args=False).log_prob( + value[..., 1:] + ) return init_prob + jnp.sum(step_probs, axis=-1) @property diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index d722d30ce..2e6cafb6c 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -2243,7 +2243,12 @@ def get_posterior(self, params): Returns a multivariate Normal posterior distribution. """ transform = self.get_transform(params) - return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril) + # When the Hessian at the MAP point is singular, ``get_transform`` warns + # and falls back to a zeroed ``scale_tril`` (a degenerate, constant + # posterior). That fallback is intentional, so skip argument validation. + return dist.MultivariateNormal( + transform.loc, scale_tril=transform.scale_tril, validate_args=False + ) def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) diff --git a/test/contrib/hsgp/test_approximation.py b/test/contrib/hsgp/test_approximation.py index 7777b3ce9..4aea026ff 100644 --- a/test/contrib/hsgp/test_approximation.py +++ b/test/contrib/hsgp/test_approximation.py @@ -400,16 +400,14 @@ def latent_gp(x, alpha, length, ell, m, non_centered): ) def model(x, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="se", divider="::")( - x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered - ) - - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="se", divider="::")( + x=x, alpha=alpha, length=length, ell=ell, m=m, non_centered=non_centered + ) + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -466,22 +464,20 @@ def latent_gp(x, nu, alpha, length, ell, m, non_centered): ) def model(x, nu, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="matern", divider="::")( - x=x, - nu=nu, - alpha=alpha, - length=length, - ell=ell, - m=m, - non_centered=non_centered, - ) - - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="matern", divider="::")( + x=x, + nu=nu, + alpha=alpha, + length=length, + ell=ell, + m=m, + non_centered=non_centered, + ) + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -683,22 +679,20 @@ def latent_gp(x, alpha, length, scale_mixture, ell, m, non_centered): ) def model(x, scale_mixture, ell, m, non_centered, y=None): - with numpyro.validation_enabled(False): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="rq", divider="::")( - x=x, - alpha=alpha, - length=length, - scale_mixture=scale_mixture, - ell=ell, - m=m, - non_centered=non_centered, - ) - - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="rq", divider="::")( + x=x, + alpha=alpha, + length=length, + scale_mixture=scale_mixture, + ell=ell, + m=m, + non_centered=non_centered, + ) + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data if num_dim == 1 else synthetic_two_dim_data model_trace = trace(seed(model, random.key(0))).get_trace( @@ -742,20 +736,18 @@ def latent_gp(x, alpha, length, w0, m): ) def model(x, w0, m, y=None): - with numpyro.validation_enabled(False): - alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) - length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) - noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) - f = scope(latent_gp, prefix="periodic", divider="::")( - x=x, - alpha=alpha, - length=length, - w0=w0, - m=m, - ) - - with numpyro.plate("data", x.shape[0]): - numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) + alpha = numpyro.sample("alpha", dist.LogNormal(0.0, 1.0)) + length = numpyro.sample("length", dist.LogNormal(0.0, 1.0)) + noise = numpyro.sample("noise", dist.LogNormal(0.0, 1.0)) + f = scope(latent_gp, prefix="periodic", divider="::")( + x=x, + alpha=alpha, + length=length, + w0=w0, + m=m, + ) + with numpyro.plate("data", x.shape[0]): + numpyro.sample("likelihood", dist.Normal(loc=f, scale=noise), obs=y) x, y_obs = synthetic_one_dim_data model_trace = trace(seed(model, random.key(0))).get_trace(x, w0, m, y_obs) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index dc22029f7..9a4afd569 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -746,7 +746,7 @@ def test_elbo_enumerate_plate_7(scale): [[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]] ) params["model_probs_e"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) - params["guide_probs_a"] = jnp.array([0.35, 0.64]) + params["guide_probs_a"] = jnp.array([0.35, 0.65]) params["guide_probs_c"] = jnp.array([[0.0, 1.0], [1.0, 0.0]]) # deterministic @handlers.scale(scale=scale) @@ -773,20 +773,18 @@ def auto_model(data, params): constraint=constraints.simplex, ) - # Categorical distributions are not normalized here - with numpyro.validation_enabled(False): - a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample( - "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) + with pyro.plate("data", 2): + c = pyro.sample("c", dist.Categorical(probs_c[a])) + d = pyro.sample( + "d", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, ) - with pyro.plate("data", 2): - c = pyro.sample("c", dist.Categorical(probs_c[a])) - d = pyro.sample( - "d", - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}, - ) - pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) + pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) @handlers.scale(scale=scale) def auto_guide(data, params): @@ -797,13 +795,9 @@ def auto_guide(data, params): "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) - # Categorical distributions are not normalized here - with numpyro.validation_enabled(False): - a = pyro.sample( - "a", dist.Categorical(probs_a), infer={"enumerate": "parallel"} - ) - with pyro.plate("data", 2): - pyro.sample("c", dist.Categorical(probs_c[a])) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + with pyro.plate("data", 2): + pyro.sample("c", dist.Categorical(probs_c[a])) @handlers.scale(scale=scale) def hand_model(data, params): @@ -829,20 +823,18 @@ def hand_model(data, params): constraint=constraints.simplex, ) - # Categorical distributions are not normalized here - with numpyro.validation_enabled(False): - a = pyro.sample("a", dist.Categorical(probs_a)) - b = pyro.sample( - "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) + for i in range(2): + c = pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + d = pyro.sample( + f"d_{i}", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, ) - for i in range(2): - c = pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) - d = pyro.sample( - f"d_{i}", - dist.Categorical(Vindex(probs_d)[b, c]), - infer={"enumerate": "parallel"}, - ) - pyro.sample(f"obs_{i}", dist.Categorical(probs_e[d]), obs=data[i]) + pyro.sample(f"obs_{i}", dist.Categorical(probs_e[d]), obs=data[i]) @handlers.scale(scale=scale) def hand_guide(data, params): @@ -853,13 +845,9 @@ def hand_guide(data, params): "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) - # Categorical distributions are not normalized here - with numpyro.validation_enabled(False): - a = pyro.sample( - "a", dist.Categorical(probs_a), infer={"enumerate": "parallel"} - ) - for i in range(2): - pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + for i in range(2): + pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) data = jnp.array([0, 0]) diff --git a/test/infer/test_autoguide.py b/test/infer/test_autoguide.py index 21710b405..d7bae8dff 100644 --- a/test/infer/test_autoguide.py +++ b/test/infer/test_autoguide.py @@ -404,10 +404,8 @@ def model(x, y): a = numpyro.sample("a", dist.Normal(0, 10)) b = numpyro.sample("b", dist.Normal(0, 10).expand([3]).to_event()) mu = a + b[0] * x + b[1] * x**2 + b[2] * x**3 - - with numpyro.validation_enabled(False): - with numpyro.plate("N", len(x)): - numpyro.sample("y", dist.Normal(mu, 0.00001), obs=y) + with numpyro.plate("N", len(x)): + numpyro.sample("y", dist.Normal(mu, 0.00001), obs=y) x = random.normal(random.key(0), (3,)) y = 1 + 2 * x + 3 * x**2 + 4 * x**3 diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index d0998895f..d3cd4f861 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -468,8 +468,13 @@ def test_gaussian_subposterior(method, diagonal): @pytest.mark.parametrize("method", [consensus, parametric_draws]) def test_subposterior_structure(method): + # use non-degenerate draws so the per-subposterior covariance is invertible subposteriors = [ - {"x": jnp.ones((100, 3)), "y": jnp.zeros((100,))} for i in range(10) + { + "x": random.normal(random.key(i), (100, 3)), + "y": random.normal(random.key(10 + i), (100,)), + } + for i in range(10) ] draws = method(subposteriors, num_draws=9) assert draws["x"].shape == (9, 3) diff --git a/test/test_distributions.py b/test/test_distributions.py index 19d1272b5..a27876b6f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -24,6 +24,7 @@ from jax.scipy.special import expit, logsumexp from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm +import numpyro import numpyro.distributions as dist from numpyro.distributions import ( SineBivariateVonMises, @@ -1646,8 +1647,12 @@ def fn(args): continue args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)] args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)] - fn_lhs = fn(args_lhs) - fn_rhs = fn(args_rhs) + # the finite-difference reference perturbs parameters off their + # constraint manifold (e.g. scale_tril, simplex probs), so disable + # validation here; jax.grad above traces and is unaffected. + with numpyro.validation_enabled(False): + fn_lhs = fn(args_lhs) + fn_rhs = fn(args_rhs) # finite diff approximation expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps) assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i]) @@ -2183,8 +2188,12 @@ def fn(*args): actual_grad = jax.grad(fn, i)(*params) args_lhs = [p if j != i else p - eps for j, p in enumerate(params)] args_rhs = [p if j != i else p + eps for j, p in enumerate(params)] - fn_lhs = fn(*args_lhs) - fn_rhs = fn(*args_rhs) + # the finite-difference reference perturbs parameters off their + # constraint manifold (e.g. scale_tril, simplex probs), so disable + # validation here; jax.grad above traces and is unaffected. + with numpyro.validation_enabled(False): + fn_lhs = fn(*args_lhs) + fn_rhs = fn(*args_rhs) # finite diff approximation expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps) assert jnp.shape(actual_grad) == jnp.shape(params[i]) @@ -2485,7 +2494,12 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): # a > 0 and b > 0. Then, make b = a + b. valid_params[1] += valid_params[0] - assert jax_dist(*oob_params) + # Out-of-bounds params must still construct when validation is off. Use the + # context manager (not the validate_args kwarg) so that compound + # distributions, which build internal distributions without forwarding + # validate_args, also skip validation. + with numpyro.validation_enabled(False): + assert jax_dist(*oob_params) # Invalid parameter values throw ValueError if not dependent_constraint and ( From c94bfb051ce85edd98cdfb44ff0a4c5cb7a1c1a6 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 5 Jun 2026 12:34:33 +0200 Subject: [PATCH 12/15] Stabilize test_functional_map convergence pmap-NUTS landed just over rtol=0.06 on the mean (0.937 vs 1.0) because pmap and vmap accumulate floats differently on a borderline estimate. Bump num_samples to 15000 (keeping num_warmup=2000) to shrink the Monte Carlo error; all map_fn/algo combinations now pass with a comfortable margin (worst relative error ~0.023). Note: higher warmup destabilizes the fixed-trajectory HMC chains, so only the sample count is increased. Co-Authored-By: Claude Opus 4.8 --- test/infer/test_mcmc.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index e7cea8814..c83f8ff52 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -738,7 +738,9 @@ def test_functional_map(algo, map_fn): pytest.skip("pmap test requires device_count greater than 1.") true_mean, true_std = 1.0, 2.0 - num_warmup, num_samples = 2000, 10_000 + # enough samples so the Monte Carlo error on the mean stays within tolerance + # for every map_fn (pmap and vmap accumulate floats differently). + num_warmup, num_samples = 2000, 15_000 def potential_fn(z): return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2) From b7cb852fdde1fac5132cce4f86ed4fec18f22837 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 5 Jun 2026 15:54:47 +0200 Subject: [PATCH 13/15] Address review: clip / reparametrize instead of validate_args=False Replace the validate_args=False fallbacks with root-cause fixes: - HSGP _centered_approximation: clip spd (sqrt spectral density) to a tiny positive floor so the scale stays valid when it underflows to 0. - EulerMaruyama / GaussianRandomWalk log_prob: use Normal's location invariance, evaluating the residual / increments under a zero-mean Normal so the loc stays valid for out-of-support values. The public log_prob's @validate_sample still warns about out-of-support values. Co-Authored-By: Claude Opus 4.8 --- numpyro/contrib/hsgp/approximation.py | 11 +++++------ numpyro/distributions/continuous.py | 24 +++++++++++------------- 2 files changed, 16 insertions(+), 19 deletions(-) diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py index 045817e95..8b9c080bd 100644 --- a/numpyro/contrib/hsgp/approximation.py +++ b/numpyro/contrib/hsgp/approximation.py @@ -30,13 +30,12 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array: def _centered_approximation(phi: Array, spd: Array, m: int) -> Array: + # ``spd`` is the square root of the spectral density and can underflow to + # exactly 0 for high-frequency basis functions; floor it to a tiny positive + # value so the (degenerate) coefficient keeps a valid positive scale. + spd = jnp.clip(spd, min=jnp.finfo(spd.dtype).tiny) with numpyro.plate("basis", m): - # ``spd`` is the square root of the spectral density and can underflow to - # exactly 0 for high-frequency basis functions, yielding a benign - # zero-variance (degenerate) coefficient. Skip validation for that edge. - beta = numpyro.sample( - "beta", dist.Normal(loc=0.0, scale=spd, validate_args=False) - ) + beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) return phi @ beta diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 12372967c..39a37615c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -654,12 +654,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: mu = xtm1 + dt * f sigma = jnp.sqrt(dt) * g - # ``mu``/``sigma`` are derived from ``value``; out-of-support values - # would make them invalid, so skip validation of this internal - # distribution. The public log_prob already validates ``value``. - sde_log_prob = ( - Normal(mu, sigma, validate_args=False).to_event(self.event_dim).log_prob(xt) - ) + # Normal is location-invariant, so evaluate the residual under a + # zero-mean Normal. This keeps the loc valid even for out-of-support + # ``value`` (where ``mu`` would be NaN); the public log_prob's + # @validate_sample still warns about out-of-support values. + sde_log_prob = Normal(0.0, sigma).to_event(self.event_dim).log_prob(xt - mu) init_log_prob = self.init_dist.log_prob(value0) return sde_log_prob + init_log_prob @@ -1099,14 +1098,13 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik @validate_sample def log_prob(self, value: ArrayLike) -> ArrayLike: - # the step distribution uses ``value`` as its location, so out-of-support - # values would make it invalid; skip validation of these internal - # distributions since the public log_prob already validates ``value``. - init_prob = Normal(0.0, self.scale, validate_args=False).log_prob(value[..., 0]) + init_prob = Normal(0.0, self.scale).log_prob(value[..., 0]) scale = jnp.expand_dims(self.scale, -1) - step_probs = Normal(value[..., :-1], scale, validate_args=False).log_prob( - value[..., 1:] - ) + # Normal is location-invariant, so evaluate the increments under a + # zero-mean Normal. This keeps the loc valid even for out-of-support + # ``value``; the public log_prob's @validate_sample still warns about + # out-of-support values. + step_probs = Normal(0.0, scale).log_prob(value[..., 1:] - value[..., :-1]) return init_prob + jnp.sum(step_probs, axis=-1) @property From 8a75729b1e77f3601ee69098d1ad7ca687a7122b Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 5 Jun 2026 16:54:12 +0200 Subject: [PATCH 14/15] rm kwarg --- numpyro/contrib/hsgp/approximation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py index 8b9c080bd..e179790f5 100644 --- a/numpyro/contrib/hsgp/approximation.py +++ b/numpyro/contrib/hsgp/approximation.py @@ -33,7 +33,7 @@ def _centered_approximation(phi: Array, spd: Array, m: int) -> Array: # ``spd`` is the square root of the spectral density and can underflow to # exactly 0 for high-frequency basis functions; floor it to a tiny positive # value so the (degenerate) coefficient keeps a valid positive scale. - spd = jnp.clip(spd, min=jnp.finfo(spd.dtype).tiny) + spd = jnp.clip(spd, jnp.finfo(spd.dtype).tiny) with numpyro.plate("basis", m): beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) From e03f20e6da352bbdd03aa62b35fcb189cc794c34 Mon Sep 17 00:00:00 2001 From: Juan Orduz Date: Fri, 5 Jun 2026 17:17:03 +0200 Subject: [PATCH 15/15] Fix pyroapi test_constraints instead of compat param projection Revert the numpyro/compat/pyro.py constraint projection and instead override the vendored pyroapi test_constraints with a valid simplex init for the guide param `q` (was an unnormalized exp(randn(3))), so it passes with validation enabled by default. Co-Authored-By: Claude Opus 4.8 --- numpyro/compat/pyro.py | 19 +------------------ test/pyroapi/test_pyroapi.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 29 insertions(+), 18 deletions(-) diff --git a/numpyro/compat/pyro.py b/numpyro/compat/pyro.py index 916bb6785..18a443748 100644 --- a/numpyro/compat/pyro.py +++ b/numpyro/compat/pyro.py @@ -3,11 +3,8 @@ import warnings -import numpy as np - from numpyro.compat.util import UnsupportedAPIWarning -from numpyro.distributions.transforms import biject_to -from numpyro.util import find_stack_level, not_jax_tracer +from numpyro.util import find_stack_level from numpyro.primitives import module, plate, sample # noqa: F401 isort:skip from numpyro.primitives import param as _param # noqa: F401 isort:skip @@ -41,18 +38,4 @@ def param(name, *args, **kwargs): param_store = get_param_store() if name in param_store: val = param_store[name] - # Match Pyro's constrained-param semantics: an init value is treated as a - # constrained value and projected onto its constraint. NumPyro returns the - # raw init at trace time, so project it here when it is concrete and does not - # already satisfy the constraint (e.g. an unnormalized simplex value). Valid - # values (including boundary points) and tracers are left untouched. - constraint = kwargs.get("constraint") - if ( - val is not None - and constraint is not None - and not_jax_tracer(val) - and not np.all(constraint(val)) - ): - transform = biject_to(constraint) - val = transform(transform.inv(val)) return val diff --git a/test/pyroapi/test_pyroapi.py b/test/pyroapi/test_pyroapi.py index 1454e496a..f640aa1d3 100644 --- a/test/pyroapi/test_pyroapi.py +++ b/test/pyroapi/test_pyroapi.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from pyroapi import pyro_backend +from pyroapi.dispatch import distributions as dist, infer, ops, pyro from pyroapi.tests import * # noqa F401 +from pyroapi.tests.test_svi import assert_ok import pytest from numpyro.infer import RenyiELBO, Trace_ELBO, TraceMeanField_ELBO @@ -26,3 +28,29 @@ def backend(): with pyro_backend("numpy"): yield + + +# pyroapi's test_constraints inits the simplex param `q` with an unnormalized +# exp(randn(3)); use a valid simplex so it passes with validation enabled. +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +def test_constraints(backend, jit): + data = ops.tensor(0.5) + + def model(): + locs = pyro.param("locs", ops.randn(3), constraint=dist.constraints.real) + scales = pyro.param( + "scales", ops.exp(ops.randn(3)), constraint=dist.constraints.positive + ) + p = ops.tensor([0.5, 0.3, 0.2]) + x = pyro.sample("x", dist.Categorical(p)) + pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) + + def guide(): + q = pyro.param( + "q", ops.tensor([0.4, 0.3, 0.3]), constraint=dist.constraints.simplex + ) + pyro.sample("x", dist.Categorical(q)) + + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo)