diff --git a/numpyro/contrib/hsgp/approximation.py b/numpyro/contrib/hsgp/approximation.py index fe075142f..e179790f5 100644 --- a/numpyro/contrib/hsgp/approximation.py +++ b/numpyro/contrib/hsgp/approximation.py @@ -30,6 +30,10 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array: def _centered_approximation(phi: Array, spd: Array, m: int) -> Array: + # ``spd`` is the square root of the spectral density and can underflow to + # exactly 0 for high-frequency basis functions; floor it to a tiny positive + # value so the (degenerate) coefficient keeps a valid positive scale. + spd = jnp.clip(spd, jnp.finfo(spd.dtype).tiny) with numpyro.plate("basis", m): beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd)) diff --git a/numpyro/distributions/continuous.py b/numpyro/distributions/continuous.py index 91c504ad8..39a37615c 100644 --- a/numpyro/distributions/continuous.py +++ b/numpyro/distributions/continuous.py @@ -654,7 +654,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike: mu = xtm1 + dt * f sigma = jnp.sqrt(dt) * g - sde_log_prob = Normal(mu, sigma).to_event(self.event_dim).log_prob(xt) + # Normal is location-invariant, so evaluate the residual under a + # zero-mean Normal. This keeps the loc valid even for out-of-support + # ``value`` (where ``mu`` would be NaN); the public log_prob's + # @validate_sample still warns about out-of-support values. + sde_log_prob = Normal(0.0, sigma).to_event(self.event_dim).log_prob(xt - mu) init_log_prob = self.init_dist.log_prob(value0) return sde_log_prob + init_log_prob @@ -1096,7 +1100,11 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik def log_prob(self, value: ArrayLike) -> ArrayLike: init_prob = Normal(0.0, self.scale).log_prob(value[..., 0]) scale = jnp.expand_dims(self.scale, -1) - step_probs = Normal(value[..., :-1], scale).log_prob(value[..., 1:]) + # Normal is location-invariant, so evaluate the increments under a + # zero-mean Normal. This keeps the loc valid even for out-of-support + # ``value``; the public log_prob's @validate_sample still warns about + # out-of-support values. + step_probs = Normal(0.0, scale).log_prob(value[..., 1:] - value[..., :-1]) return init_prob + jnp.sum(step_probs, axis=-1) @property diff --git a/numpyro/distributions/distribution.py b/numpyro/distributions/distribution.py index 7b12879eb..ac5b51914 100644 --- a/numpyro/distributions/distribution.py +++ b/numpyro/distributions/distribution.py @@ -155,7 +155,7 @@ class Distribution(metaclass=DistributionMeta): arg_constraints: dict[str, Any] = {} _support: Optional[constraints.Constraint] = None - _validate_args: bool = False + _validate_args: bool = _VALIDATION_ENABLED _arg_names: ClassVar[Optional[tuple[str, ...]]] = None pytree_data_fields: tuple[str, ...] = () pytree_aux_fields: tuple[str, ...] = ("_batch_shape", "_event_shape") diff --git a/numpyro/infer/autoguide.py b/numpyro/infer/autoguide.py index d722d30ce..2e6cafb6c 100644 --- a/numpyro/infer/autoguide.py +++ b/numpyro/infer/autoguide.py @@ -2243,7 +2243,12 @@ def get_posterior(self, params): Returns a multivariate Normal posterior distribution. """ transform = self.get_transform(params) - return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril) + # When the Hessian at the MAP point is singular, ``get_transform`` warns + # and falls back to a zeroed ``scale_tril`` (a degenerate, constant + # posterior). That fallback is intentional, so skip argument validation. + return dist.MultivariateNormal( + transform.loc, scale_tril=transform.scale_tril, validate_args=False + ) def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs): latent_sample = self.get_posterior(params).sample(rng_key, sample_shape) diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py index 6bd5b5010..9a4afd569 100644 --- a/test/contrib/test_enum_elbo.py +++ b/test/contrib/test_enum_elbo.py @@ -746,7 +746,7 @@ def test_elbo_enumerate_plate_7(scale): [[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]] ) params["model_probs_e"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) - params["guide_probs_a"] = jnp.array([0.35, 0.64]) + params["guide_probs_a"] = jnp.array([0.35, 0.65]) params["guide_probs_c"] = jnp.array([[0.0, 1.0], [1.0, 0.0]]) # deterministic @handlers.scale(scale=scale) @@ -772,6 +772,7 @@ def auto_model(data, params): params["model_probs_e"], constraint=constraints.simplex, ) + a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample( "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} @@ -793,6 +794,7 @@ def auto_guide(data, params): probs_c = pyro.param( "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) with pyro.plate("data", 2): pyro.sample("c", dist.Categorical(probs_c[a])) @@ -820,6 +822,7 @@ def hand_model(data, params): params["model_probs_e"], constraint=constraints.simplex, ) + a = pyro.sample("a", dist.Categorical(probs_a)) b = pyro.sample( "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} @@ -841,6 +844,7 @@ def hand_guide(data, params): probs_c = pyro.param( "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex ) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) for i in range(2): pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) diff --git a/test/infer/test_hmc_util.py b/test/infer/test_hmc_util.py index d0998895f..d3cd4f861 100644 --- a/test/infer/test_hmc_util.py +++ b/test/infer/test_hmc_util.py @@ -468,8 +468,13 @@ def test_gaussian_subposterior(method, diagonal): @pytest.mark.parametrize("method", [consensus, parametric_draws]) def test_subposterior_structure(method): + # use non-degenerate draws so the per-subposterior covariance is invertible subposteriors = [ - {"x": jnp.ones((100, 3)), "y": jnp.zeros((100,))} for i in range(10) + { + "x": random.normal(random.key(i), (100, 3)), + "y": random.normal(random.key(10 + i), (100,)), + } + for i in range(10) ] draws = method(subposteriors, num_draws=9) assert draws["x"].shape == (9, 3) diff --git a/test/infer/test_infer_util.py b/test/infer/test_infer_util.py index 364a146dc..1207311ef 100644 --- a/test/infer/test_infer_util.py +++ b/test/infer/test_infer_util.py @@ -81,9 +81,8 @@ def model(X, y=None): def categorical_probs(): - probs0 = 0.5 nbatch0, nbatch1 = 2, 1 - probs = jnp.ones((nbatch0, nbatch1, 3)) * probs0 + probs = jnp.ones((nbatch0, nbatch1, 3)) / 3 def model(probs): probs = numpyro.deterministic("probs", probs) diff --git a/test/infer/test_inspect.py b/test/infer/test_inspect.py index 675b27f5b..19182beaf 100644 --- a/test/infer/test_inspect.py +++ b/test/infer/test_inspect.py @@ -18,7 +18,7 @@ class NonreparameterizedNormal(dist.Normal): def test_get_dependencies(): def model(data): a = numpyro.sample("a", dist.Normal(0, 1)) - b = numpyro.sample("b", NonreparameterizedNormal(a, 0)) + b = numpyro.sample("b", NonreparameterizedNormal(a, 1)) c = numpyro.sample("c", dist.Normal(b, 1)) d = numpyro.sample("d", dist.Normal(a, jnp.exp(c))) diff --git a/test/infer/test_mcmc.py b/test/infer/test_mcmc.py index 39637e23d..c83f8ff52 100644 --- a/test/infer/test_mcmc.py +++ b/test/infer/test_mcmc.py @@ -738,7 +738,9 @@ def test_functional_map(algo, map_fn): pytest.skip("pmap test requires device_count greater than 1.") true_mean, true_std = 1.0, 2.0 - num_warmup, num_samples = 1000, 8000 + # enough samples so the Monte Carlo error on the mean stays within tolerance + # for every map_fn (pmap and vmap accumulate floats differently). + num_warmup, num_samples = 2000, 15_000 def potential_fn(z): return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2) diff --git a/test/ops/test_indexing.py b/test/ops/test_indexing.py index fb8016319..2809d68b3 100644 --- a/test/ops/test_indexing.py +++ b/test/ops/test_indexing.py @@ -107,8 +107,8 @@ def test_shape(expression, expected_shape): @pytest.mark.parametrize("x_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str) def test_value(x_shape, i_shape, j_shape, event_shape): x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape))) - i = dist.Categorical(jnp.ones((5,))).sample(random.key(1), i_shape) - j = dist.Categorical(jnp.ones((6,))).sample(random.key(2), j_shape) + i = dist.Categorical(jnp.ones((5,)) / 5).sample(random.key(1), i_shape) + j = dist.Categorical(jnp.ones((6,)) / 6).sample(random.key(2), j_shape) if event_shape: actual = Vindex(x)[..., i, j, :] else: diff --git a/test/pyroapi/test_pyroapi.py b/test/pyroapi/test_pyroapi.py index 1454e496a..f640aa1d3 100644 --- a/test/pyroapi/test_pyroapi.py +++ b/test/pyroapi/test_pyroapi.py @@ -2,7 +2,9 @@ # SPDX-License-Identifier: Apache-2.0 from pyroapi import pyro_backend +from pyroapi.dispatch import distributions as dist, infer, ops, pyro from pyroapi.tests import * # noqa F401 +from pyroapi.tests.test_svi import assert_ok import pytest from numpyro.infer import RenyiELBO, Trace_ELBO, TraceMeanField_ELBO @@ -26,3 +28,29 @@ def backend(): with pyro_backend("numpy"): yield + + +# pyroapi's test_constraints inits the simplex param `q` with an unnormalized +# exp(randn(3)); use a valid simplex so it passes with validation enabled. +@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) +def test_constraints(backend, jit): + data = ops.tensor(0.5) + + def model(): + locs = pyro.param("locs", ops.randn(3), constraint=dist.constraints.real) + scales = pyro.param( + "scales", ops.exp(ops.randn(3)), constraint=dist.constraints.positive + ) + p = ops.tensor([0.5, 0.3, 0.2]) + x = pyro.sample("x", dist.Categorical(p)) + pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) + + def guide(): + q = pyro.param( + "q", ops.tensor([0.4, 0.3, 0.3]), constraint=dist.constraints.simplex + ) + pyro.sample("x", dist.Categorical(q)) + + Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO + elbo = Elbo(ignore_jit_warnings=True) + assert_ok(model, guide, elbo) diff --git a/test/test_distributions.py b/test/test_distributions.py index 256cb1f82..a27876b6f 100644 --- a/test/test_distributions.py +++ b/test/test_distributions.py @@ -24,6 +24,7 @@ from jax.scipy.special import expit, logsumexp from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm +import numpyro import numpyro.distributions as dist from numpyro.distributions import ( SineBivariateVonMises, @@ -1646,8 +1647,12 @@ def fn(args): continue args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)] args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)] - fn_lhs = fn(args_lhs) - fn_rhs = fn(args_rhs) + # the finite-difference reference perturbs parameters off their + # constraint manifold (e.g. scale_tril, simplex probs), so disable + # validation here; jax.grad above traces and is unaffected. + with numpyro.validation_enabled(False): + fn_lhs = fn(args_lhs) + fn_rhs = fn(args_rhs) # finite diff approximation expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps) assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i]) @@ -2044,8 +2049,8 @@ def test_beta_binomial_log_prob(total_count, shape): @pytest.mark.parametrize("n", [1, 2, 5, 10]) @pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)]) def test_beta_negative_binomial_log_prob(n, shape): - concentration0 = np.exp(np.random.normal(size=shape)) - concentration1 = np.exp(np.random.normal(size=shape)) + concentration0 = 1 + np.exp(np.random.normal(size=shape)) + concentration1 = 1 + np.exp(np.random.normal(size=shape)) value = jnp.arange(15) num_samples = 300000 @@ -2183,8 +2188,12 @@ def fn(*args): actual_grad = jax.grad(fn, i)(*params) args_lhs = [p if j != i else p - eps for j, p in enumerate(params)] args_rhs = [p if j != i else p + eps for j, p in enumerate(params)] - fn_lhs = fn(*args_lhs) - fn_rhs = fn(*args_rhs) + # the finite-difference reference perturbs parameters off their + # constraint manifold (e.g. scale_tril, simplex probs), so disable + # validation here; jax.grad above traces and is unaffected. + with numpyro.validation_enabled(False): + fn_lhs = fn(*args_lhs) + fn_rhs = fn(*args_rhs) # finite diff approximation expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps) assert jnp.shape(actual_grad) == jnp.shape(params[i]) @@ -2485,7 +2494,12 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape): # a > 0 and b > 0. Then, make b = a + b. valid_params[1] += valid_params[0] - assert jax_dist(*oob_params) + # Out-of-bounds params must still construct when validation is off. Use the + # context manager (not the validate_args kwarg) so that compound + # distributions, which build internal distributions without forwarding + # validate_args, also skip validation. + with numpyro.validation_enabled(False): + assert jax_dist(*oob_params) # Invalid parameter values throw ValueError if not dependent_constraint and (