-
Notifications
You must be signed in to change notification settings - Fork 288
Fix enable_validation #2201
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Fix enable_validation #2201
Changes from all commits
7d1cb47
c3f06a8
c66e2fb
1c35a7c
955d50f
f853fc8
d0d3e4b
2eb84d4
cde1511
e6550ea
572dfcc
c94bfb0
b7cb852
8a75729
e03f20e
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -2,7 +2,9 @@ | |
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| from pyroapi import pyro_backend | ||
| from pyroapi.dispatch import distributions as dist, infer, ops, pyro | ||
| from pyroapi.tests import * # noqa F401 | ||
| from pyroapi.tests.test_svi import assert_ok | ||
| import pytest | ||
|
|
||
| from numpyro.infer import RenyiELBO, Trace_ELBO, TraceMeanField_ELBO | ||
|
|
@@ -26,3 +28,29 @@ | |
| def backend(): | ||
| with pyro_backend("numpy"): | ||
| yield | ||
|
|
||
|
|
||
| # pyroapi's test_constraints inits the simplex param `q` with an unnormalized | ||
| # exp(randn(3)); use a valid simplex so it passes with validation enabled. | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
To provide a valid simplex (as suggested) we therefore redefine |
||
| @pytest.mark.parametrize("jit", [False, True], ids=["py", "jit"]) | ||
| def test_constraints(backend, jit): | ||
| data = ops.tensor(0.5) | ||
|
|
||
| def model(): | ||
| locs = pyro.param("locs", ops.randn(3), constraint=dist.constraints.real) | ||
| scales = pyro.param( | ||
| "scales", ops.exp(ops.randn(3)), constraint=dist.constraints.positive | ||
| ) | ||
| p = ops.tensor([0.5, 0.3, 0.2]) | ||
| x = pyro.sample("x", dist.Categorical(p)) | ||
| pyro.sample("obs", dist.Normal(locs[x], scales[x]), obs=data) | ||
|
|
||
| def guide(): | ||
| q = pyro.param( | ||
| "q", ops.tensor([0.4, 0.3, 0.3]), constraint=dist.constraints.simplex | ||
| ) | ||
| pyro.sample("x", dist.Categorical(q)) | ||
|
|
||
| Elbo = infer.JitTrace_ELBO if jit else infer.Trace_ELBO | ||
| elbo = Elbo(ignore_jit_warnings=True) | ||
| assert_ok(model, guide, elbo) | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice solution to address NaN mu.