Add Cumulative Distribution Function, Inverse CDF methods to Distributions#122
Add Cumulative Distribution Function, Inverse CDF methods to Distributions#122vishwakftw wants to merge 9 commits into
Conversation
1. Cauchy 2. Exponential 3. Laplace (Only CDF) 4. Pareto
fritzo
left a comment
There was a problem hiding this comment.
Looks good! I only have minor comments about testing.
| set_rng_seed(0) # see Note [Randomized statistical tests] | ||
| for pytorch_dist, scipy_dist in self.distribution_pairs: | ||
| samples = pytorch_dist.sample((5,)) | ||
| try: |
There was a problem hiding this comment.
It's safer to enclose as little as needed in a try-except. Could you refactor to
try:
cdf = pytorch_dist.cdf(samples)
except NotImplementedError:
continue
self.assertEqual(cdf, scipy_dist.cdf(samples), message=pytorch_dist)There was a problem hiding this comment.
Ah, yes. I saw the discussion in TruncatedNormal. I will modify it accordingly.
| set_rng_seed(0) # see Note [Randomized statistical tests] | ||
| for pytorch_dist, scipy_dist in self.distribution_pairs: | ||
| samples = Variable(torch.rand((5,) + pytorch_dist.batch_shape)) | ||
| try: |
There was a problem hiding this comment.
ditto, enclose as little as possible in try-except
| self._validate_log_prob_arg(value) | ||
| return -torch.log(2 * self.scale) - torch.abs(value - self.loc) / self.scale | ||
|
|
||
| def cdf(self, value): |
There was a problem hiding this comment.
Laplace's .cdf is a piecewise function. I was doubtful about adding an inverse, and later realized that the inverse could be piecewise as well. Will update this too..
| self.assertEqual(pytorch_dist.variance, scipy_dist.var(), allow_inf=True, message=pytorch_dist) | ||
| self.assertEqual(pytorch_dist.stddev, scipy_dist.var() ** 0.5, message=pytorch_dist) | ||
|
|
||
| def test_cdf(self): |
There was a problem hiding this comment.
It would be nice to have an additional test that did not rely on scipy, e.g.
class TestDistributions(TestCase):
def test_cdf_icdf(self):
for Dist, params in EXAMPLES:
for i, param in enumerate(params):
dist = Dist(**param)
samples = dist.sample(sample_shape=(20,))
try:
cdf = dist.cdf(samples)
actual = dist.icdf(cdf)
except NotImplementedError:
continue
self.assertEqual(actual, samples, message='{} example {}/{}, icdf(cdf(x)) != x')or you could get even fancier by using grad() like
x = dist.sample(sample_shape=(20,))
expected_pdf = dist.log_prob(x).exp()
actual_pdf = grad(dist.cdf(x).sum(), [x])[0]
self.assertEqual(actual_pdf, expected_pdf)|
|
Minor: 1. Convert Pareto and Gumbel to TransformedDistribution 2. Add .cdf and .icdf for Uniform 3. Temporarily remove .cdf from Laplace
|
Three tests fail:
|
fritzo
left a comment
There was a problem hiding this comment.
Looks great! Just one minor comment about eps and tiny, then it's ready to send upstream.
| z = (value - self.loc) / self.scale | ||
| return -(self.scale.log() + z + torch.exp(-z)) | ||
| base_dist = Uniform(torch.zeros_like(self.loc), 1) | ||
| transforms = [ExpTransform().inv, AffineTransform(loc=0, scale=-1), |
| self._validate_log_prob_arg(value) | ||
| z = (value - self.loc) / self.scale | ||
| return -(self.scale.log() + z + torch.exp(-z)) | ||
| base_dist = Uniform(torch.zeros_like(self.loc), 1) |
There was a problem hiding this comment.
Maybe we should avoid infinity like
finfo = _finfo(self.loc)
base_dist = Uniform(self.loc.new([finfo.tiny]).expand_as(self.loc), 1 - finfo.eps) | Computes the inverse cumulative distribution function using transform(s) and computing | ||
| the score of the base distribution | ||
| """ | ||
| self.base_dist._validate_log_prob_arg(value) |
There was a problem hiding this comment.
I believe the base_dist.icdf() should call _validate_log_prob_arg(value) internally on the following line. Do you think it's worth having the extra check here? I'd be happy either way.
|
@vishwakftw Let me know if you want any help with the failing tests. I might have time today or tomorrow to help debug. |
|
@fritzo I have fixed the shaping failures with the Gumbel distribution. There is one issue however. Some how the |
|
I also tried implementing def __init__(self, loc, scale):
self.loc, self.scale = broadcast_all(loc, scale)
finfo = _finfo(self.loc)
if isinstance(loc, Number) and isinstance(scale, Number):
base_dist = Uniform(finfo.eps - 1, 1)
else:
base_dist = Uniform(self.loc.new([finfo.eps]).expand_as(self.loc) - 1, 1)
transforms = [AbsTransform(), AffineTransform(loc=1, scale=-1), ExpTransform().inv,
AffineTransform(loc=self.loc, scale=self.scale)]
super(Laplace, self).__init__(base_dist, transforms)I believe the sampling requires a SignTransform = AbsTransform / identity_transform |
|
Great. I am sending this upstream now!! |
Work in parallel with PR #121.
cc: @fritzo @alicanb