Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions numpyro/contrib/hsgp/approximation.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def _non_centered_approximation(phi: Array, spd: Array, m: int) -> Array:


def _centered_approximation(phi: Array, spd: Array, m: int) -> Array:
# ``spd`` is the square root of the spectral density and can underflow to
# exactly 0 for high-frequency basis functions; floor it to a tiny positive
# value so the (degenerate) coefficient keeps a valid positive scale.
spd = jnp.clip(spd, jnp.finfo(spd.dtype).tiny)
with numpyro.plate("basis", m):
beta = numpyro.sample("beta", dist.Normal(loc=0.0, scale=spd))

Expand Down
12 changes: 10 additions & 2 deletions numpyro/distributions/continuous.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,11 @@ def log_prob(self, value: ArrayLike) -> ArrayLike:
mu = xtm1 + dt * f
sigma = jnp.sqrt(dt) * g

sde_log_prob = Normal(mu, sigma).to_event(self.event_dim).log_prob(xt)
# Normal is location-invariant, so evaluate the residual under a
# zero-mean Normal. This keeps the loc valid even for out-of-support
# ``value`` (where ``mu`` would be NaN); the public log_prob's
# @validate_sample still warns about out-of-support values.
sde_log_prob = Normal(0.0, sigma).to_event(self.event_dim).log_prob(xt - mu)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice solution to address NaN mu.

init_log_prob = self.init_dist.log_prob(value0)

return sde_log_prob + init_log_prob
Expand Down Expand Up @@ -1096,7 +1100,11 @@ def sample(self, key: jax.Array, sample_shape: tuple[int, ...] = ()) -> ArrayLik
def log_prob(self, value: ArrayLike) -> ArrayLike:
init_prob = Normal(0.0, self.scale).log_prob(value[..., 0])
scale = jnp.expand_dims(self.scale, -1)
step_probs = Normal(value[..., :-1], scale).log_prob(value[..., 1:])
# Normal is location-invariant, so evaluate the increments under a
# zero-mean Normal. This keeps the loc valid even for out-of-support
# ``value``; the public log_prob's @validate_sample still warns about
# out-of-support values.
step_probs = Normal(0.0, scale).log_prob(value[..., 1:] - value[..., :-1])
return init_prob + jnp.sum(step_probs, axis=-1)

@property
Expand Down
2 changes: 1 addition & 1 deletion numpyro/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ class Distribution(metaclass=DistributionMeta):

arg_constraints: dict[str, Any] = {}
_support: Optional[constraints.Constraint] = None
_validate_args: bool = False
_validate_args: bool = _VALIDATION_ENABLED
_arg_names: ClassVar[Optional[tuple[str, ...]]] = None
pytree_data_fields: tuple[str, ...] = ()
pytree_aux_fields: tuple[str, ...] = ("_batch_shape", "_event_shape")
Expand Down
7 changes: 6 additions & 1 deletion numpyro/infer/autoguide.py
Original file line number Diff line number Diff line change
Expand Up @@ -2243,7 +2243,12 @@ def get_posterior(self, params):
Returns a multivariate Normal posterior distribution.
"""
transform = self.get_transform(params)
return dist.MultivariateNormal(transform.loc, scale_tril=transform.scale_tril)
# When the Hessian at the MAP point is singular, ``get_transform`` warns
# and falls back to a zeroed ``scale_tril`` (a degenerate, constant
# posterior). That fallback is intentional, so skip argument validation.
return dist.MultivariateNormal(
transform.loc, scale_tril=transform.scale_tril, validate_args=False
)

def sample_posterior(self, rng_key, params, *args, sample_shape=(), **kwargs):
latent_sample = self.get_posterior(params).sample(rng_key, sample_shape)
Expand Down
6 changes: 5 additions & 1 deletion test/contrib/test_enum_elbo.py
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,7 @@ def test_elbo_enumerate_plate_7(scale):
[[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]]
)
params["model_probs_e"] = jnp.array([[0.75, 0.25], [0.55, 0.45]])
params["guide_probs_a"] = jnp.array([0.35, 0.64])
params["guide_probs_a"] = jnp.array([0.35, 0.65])
params["guide_probs_c"] = jnp.array([[0.0, 1.0], [1.0, 0.0]]) # deterministic

@handlers.scale(scale=scale)
Expand All @@ -772,6 +772,7 @@ def auto_model(data, params):
params["model_probs_e"],
constraint=constraints.simplex,
)

a = pyro.sample("a", dist.Categorical(probs_a))
b = pyro.sample(
"b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"}
Expand All @@ -793,6 +794,7 @@ def auto_guide(data, params):
probs_c = pyro.param(
"guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex
)

a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"})
with pyro.plate("data", 2):
pyro.sample("c", dist.Categorical(probs_c[a]))
Expand Down Expand Up @@ -820,6 +822,7 @@ def hand_model(data, params):
params["model_probs_e"],
constraint=constraints.simplex,
)

a = pyro.sample("a", dist.Categorical(probs_a))
b = pyro.sample(
"b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"}
Expand All @@ -841,6 +844,7 @@ def hand_guide(data, params):
probs_c = pyro.param(
"guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex
)

a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"})
for i in range(2):
pyro.sample(f"c_{i}", dist.Categorical(probs_c[a]))
Expand Down
7 changes: 6 additions & 1 deletion test/infer/test_hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,8 +468,13 @@ def test_gaussian_subposterior(method, diagonal):

@pytest.mark.parametrize("method", [consensus, parametric_draws])
def test_subposterior_structure(method):
# use non-degenerate draws so the per-subposterior covariance is invertible
subposteriors = [
{"x": jnp.ones((100, 3)), "y": jnp.zeros((100,))} for i in range(10)
{
"x": random.normal(random.key(i), (100, 3)),
"y": random.normal(random.key(10 + i), (100,)),
}
for i in range(10)
]
draws = method(subposteriors, num_draws=9)
assert draws["x"].shape == (9, 3)
Expand Down
3 changes: 1 addition & 2 deletions test/infer/test_infer_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,8 @@ def model(X, y=None):


def categorical_probs():
probs0 = 0.5
nbatch0, nbatch1 = 2, 1
probs = jnp.ones((nbatch0, nbatch1, 3)) * probs0
probs = jnp.ones((nbatch0, nbatch1, 3)) / 3

def model(probs):
probs = numpyro.deterministic("probs", probs)
Expand Down
2 changes: 1 addition & 1 deletion test/infer/test_inspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class NonreparameterizedNormal(dist.Normal):
def test_get_dependencies():
def model(data):
a = numpyro.sample("a", dist.Normal(0, 1))
b = numpyro.sample("b", NonreparameterizedNormal(a, 0))
b = numpyro.sample("b", NonreparameterizedNormal(a, 1))
c = numpyro.sample("c", dist.Normal(b, 1))
d = numpyro.sample("d", dist.Normal(a, jnp.exp(c)))

Expand Down
4 changes: 3 additions & 1 deletion test/infer/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -738,7 +738,9 @@ def test_functional_map(algo, map_fn):
pytest.skip("pmap test requires device_count greater than 1.")

true_mean, true_std = 1.0, 2.0
num_warmup, num_samples = 1000, 8000
# enough samples so the Monte Carlo error on the mean stays within tolerance
# for every map_fn (pmap and vmap accumulate floats differently).
num_warmup, num_samples = 2000, 15_000

def potential_fn(z):
return 0.5 * jnp.sum(((z - true_mean) / true_std) ** 2)
Expand Down
4 changes: 2 additions & 2 deletions test/ops/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,8 +107,8 @@ def test_shape(expression, expected_shape):
@pytest.mark.parametrize("x_shape", [(), (2,), (3, 1), (4, 1, 1), (4, 3, 2)], ids=str)
def test_value(x_shape, i_shape, j_shape, event_shape):
x = jnp.array(np.random.rand(*(x_shape + (5, 6) + event_shape)))
i = dist.Categorical(jnp.ones((5,))).sample(random.key(1), i_shape)
j = dist.Categorical(jnp.ones((6,))).sample(random.key(2), j_shape)
i = dist.Categorical(jnp.ones((5,)) / 5).sample(random.key(1), i_shape)
j = dist.Categorical(jnp.ones((6,)) / 6).sample(random.key(2), j_shape)
if event_shape:
actual = Vindex(x)[..., i, j, :]
else:
Expand Down
28 changes: 28 additions & 0 deletions test/pyroapi/test_pyroapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
# SPDX-License-Identifier: Apache-2.0

from pyroapi import pyro_backend
from pyroapi.dispatch import distributions as dist, infer, ops, pyro
from pyroapi.tests import * # noqa F401
from pyroapi.tests.test_svi import assert_ok
import pytest

from numpyro.infer import RenyiELBO, Trace_ELBO, TraceMeanField_ELBO
Expand All @@ -26,3 +28,29 @@
def backend():
with pyro_backend("numpy"):
yield


# pyroapi's test_constraints inits the simplex param `q` with an unnormalized
# exp(randn(3)); use a valid simplex so it passes with validation enabled.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `exp(randn(3))` initializer for `q` lives in the vendored upstream pyroapi package (pyroapi/tests/test_svi.py); see https://github.com/pyro-ppl/pyro-api/blob/master/pyroapi/tests/test_svi.py#L168

To provide a valid simplex (as suggested), we therefore redefine `test_constraints` in our own test/pyroapi/test_pyroapi.py. It is identical to upstream except that `q` is initialized to a valid simplex, [0.4, 0.3, 0.3], instead of the unnormalized `exp(randn(3))`. Because the test name is imported via `from pyroapi.tests import *`, redefining it shadows the upstream test; it does not add a new one.

@pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"])
def test_constraints(backend, jit):
    """Check that an ELBO runs on a model/guide pair using constrained params.

    The model registers a real-valued ``locs`` and a positive ``scales``
    parameter, draws a categorical index ``x`` from fixed probabilities and
    conditions a Normal likelihood on a scalar observation. The guide samples
    ``x`` from a Categorical whose probabilities are the simplex-constrained
    parameter ``q``, initialised to ``[0.4, 0.3, 0.3]`` (already sums to one).

    :param backend: fixture value; not used directly in the body.
        NOTE(review): presumably requested only so the pyro backend context
        is active while the test runs -- confirm against the fixture.
    :param jit: when true use ``infer.JitTrace_ELBO``, otherwise
        ``infer.Trace_ELBO``.
    """
    observation = ops.tensor(0.5)

    def model():
        # Draw the initial values in the same order as the param sites are
        # registered: ``locs`` first, then ``scales``.
        loc_init = ops.randn(3)
        locs = pyro.param("locs", loc_init, constraint=dist.constraints.real)
        scale_init = ops.exp(ops.randn(3))
        scales = pyro.param(
            "scales", scale_init, constraint=dist.constraints.positive
        )
        mixture_probs = ops.tensor([0.5, 0.3, 0.2])
        x = pyro.sample("x", dist.Categorical(mixture_probs))
        likelihood = dist.Normal(locs[x], scales[x])
        pyro.sample("obs", likelihood, obs=observation)

    def guide():
        # A properly normalised starting point for the simplex parameter.
        q_init = ops.tensor([0.4, 0.3, 0.3])
        q = pyro.param("q", q_init, constraint=dist.constraints.simplex)
        pyro.sample("x", dist.Categorical(q))

    # Only look up the ELBO class that is actually needed.
    if jit:
        elbo_cls = infer.JitTrace_ELBO
    else:
        elbo_cls = infer.Trace_ELBO
    assert_ok(model, guide, elbo_cls(ignore_jit_warnings=True))
28 changes: 21 additions & 7 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from jax.scipy.special import expit, logsumexp
from jax.scipy.stats import norm as jax_norm, truncnorm as jax_truncnorm

import numpyro
import numpyro.distributions as dist
from numpyro.distributions import (
SineBivariateVonMises,
Expand Down Expand Up @@ -1646,8 +1647,12 @@ def fn(args):
continue
args_lhs = [p if j != i else p - eps for j, p in enumerate(repara_params)]
args_rhs = [p if j != i else p + eps for j, p in enumerate(repara_params)]
fn_lhs = fn(args_lhs)
fn_rhs = fn(args_rhs)
# the finite-difference reference perturbs parameters off their
# constraint manifold (e.g. scale_tril, simplex probs), so disable
# validation here; jax.grad above traces and is unaffected.
with numpyro.validation_enabled(False):
fn_lhs = fn(args_lhs)
fn_rhs = fn(args_rhs)
# finite diff approximation
expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps)
assert jnp.shape(actual_grad[i]) == jnp.shape(repara_params[i])
Expand Down Expand Up @@ -2044,8 +2049,8 @@ def test_beta_binomial_log_prob(total_count, shape):
@pytest.mark.parametrize("n", [1, 2, 5, 10])
@pytest.mark.parametrize("shape", [(1,), (3, 1), (2, 3, 1)])
def test_beta_negative_binomial_log_prob(n, shape):
concentration0 = np.exp(np.random.normal(size=shape))
concentration1 = np.exp(np.random.normal(size=shape))
concentration0 = 1 + np.exp(np.random.normal(size=shape))
concentration1 = 1 + np.exp(np.random.normal(size=shape))
value = jnp.arange(15)

num_samples = 300000
Expand Down Expand Up @@ -2183,8 +2188,12 @@ def fn(*args):
actual_grad = jax.grad(fn, i)(*params)
args_lhs = [p if j != i else p - eps for j, p in enumerate(params)]
args_rhs = [p if j != i else p + eps for j, p in enumerate(params)]
fn_lhs = fn(*args_lhs)
fn_rhs = fn(*args_rhs)
# the finite-difference reference perturbs parameters off their
# constraint manifold (e.g. scale_tril, simplex probs), so disable
# validation here; jax.grad above traces and is unaffected.
with numpyro.validation_enabled(False):
fn_lhs = fn(*args_lhs)
fn_rhs = fn(*args_rhs)
# finite diff approximation
expected_grad = (fn_rhs - fn_lhs) / (2.0 * eps)
assert jnp.shape(actual_grad) == jnp.shape(params[i])
Expand Down Expand Up @@ -2485,7 +2494,12 @@ def test_distribution_constraints(jax_dist, sp_dist, params, prepend_shape):
# a > 0 and b > 0. Then, make b = a + b.
valid_params[1] += valid_params[0]

assert jax_dist(*oob_params)
# Out-of-bounds params must still construct when validation is off. Use the
# context manager (not the validate_args kwarg) so that compound
# distributions, which build internal distributions without forwarding
# validate_args, also skip validation.
with numpyro.validation_enabled(False):
assert jax_dist(*oob_params)

# Invalid parameter values throw ValueError
if not dependent_constraint and (
Expand Down
Loading