Implementation of the Fisher-Snedecor Distribution (#4706)
vishwakftw authored and apaszke committed Jan 20, 2018
1 parent 8593c6f commit f033dd6
Showing 4 changed files with 139 additions and 9 deletions.
6 changes: 6 additions & 0 deletions docs/source/distributions.rst
@@ -61,6 +61,12 @@ Probability distributions - torch.distributions
.. autoclass:: Exponential
:members:

:hidden:`FisherSnedecor`
~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: FisherSnedecor
:members:

:hidden:`Gamma`
~~~~~~~~~~~~~~~~~~~~~~~

81 changes: 72 additions & 9 deletions test/test_distributions.py
@@ -30,9 +30,10 @@
import torch
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, grad, gradcheck
+from torch.distributions import Distribution
from torch.distributions import (Bernoulli, Beta, Binomial, Categorical, Cauchy, Chi2,
-                                 Dirichlet, Exponential, Gamma, Geometric, Gumbel, Laplace,
-                                 Normal, OneHotCategorical, Multinomial, Pareto,
+                                 Dirichlet, Exponential, FisherSnedecor, Gamma, Geometric,
+                                 Gumbel, Laplace, Normal, OneHotCategorical, Multinomial, Pareto,
StudentT, Uniform, kl_divergence)
from torch.distributions.dirichlet import _Dirichlet_backward
from torch.distributions.constraints import Constraint, is_dependent
@@ -114,6 +115,20 @@ def pairwise(Dist, *params):
{'rate': Variable(torch.randn(5, 5).abs(), requires_grad=True)},
{'rate': Variable(torch.randn(1).abs(), requires_grad=True)},
]),
Example(FisherSnedecor, [
{
'df1': Variable(torch.randn(5, 5).abs(), requires_grad=True),
'df2': Variable(torch.randn(5, 5).abs(), requires_grad=True),
},
{
'df1': Variable(torch.randn(1).abs(), requires_grad=True),
'df2': Variable(torch.randn(1).abs(), requires_grad=True),
},
{
'df1': Variable(torch.Tensor([1.0])),
'df2': 1.0,
}
]),
Example(Gamma, [
{
'concentration': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
@@ -277,6 +292,13 @@ def test_enumerate_support_type(self):
except NotImplementedError:
pass

def test_has_examples(self):
distributions_with_examples = set(e.Dist for e in EXAMPLES)
for Dist in globals().values():
if isinstance(Dist, type) and issubclass(Dist, Distribution) and Dist is not Distribution:
self.assertIn(Dist, distributions_with_examples,
"Please add {} to the EXAMPLES list in test_distributions.py".format(Dist.__name__))

def test_bernoulli(self):
p = Variable(torch.Tensor([0.7, 0.2, 0.4]), requires_grad=True)
r = Variable(torch.Tensor([0.3]), requires_grad=True)
@@ -730,7 +752,7 @@ def test_gamma_sample(self):
'Gamma(concentration={}, rate={})'.format(alpha, beta))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_pareto_shape(self):
+    def test_pareto(self):
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
alpha = Variable(torch.randn(2, 3).abs(), requires_grad=True)
scale_1d = torch.randn(1).abs()
@@ -759,7 +781,7 @@ def test_pareto_sample(self):
'Pareto(scale={}, alpha={})'.format(scale, alpha))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
-    def test_gumbel_shape(self):
+    def test_gumbel(self):
loc = Variable(torch.randn(2, 3), requires_grad=True)
scale = Variable(torch.randn(2, 3).abs(), requires_grad=True)
loc_1d = torch.randn(1)
@@ -787,6 +809,35 @@ def test_gumbel_sample(self):
scipy.stats.gumbel_r(loc=loc, scale=scale),
'Gumbel(loc={}, scale={})'.format(loc, scale))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor(self):
df1 = Variable(torch.randn(2, 3).abs(), requires_grad=True)
df2 = Variable(torch.randn(2, 3).abs(), requires_grad=True)
df1_1d = torch.randn(1).abs()
df2_1d = torch.randn(1).abs()
self.assertEqual(FisherSnedecor(df1, df2).sample().size(), (2, 3))
self.assertEqual(FisherSnedecor(df1, df2).sample_n(5).size(), (5, 2, 3))
self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample().size(), (1,))
self.assertEqual(FisherSnedecor(df1_1d, df2_1d).sample_n(1).size(), (1, 1))
self.assertEqual(FisherSnedecor(1.0, 1.0).sample().size(), (1,))
self.assertEqual(FisherSnedecor(1.0, 1.0).sample_n(1).size(), (1,))

def ref_log_prob(idx, x, log_prob):
f1 = df1.data.view(-1)[idx]
f2 = df2.data.view(-1)[idx]
expected = scipy.stats.f.logpdf(x, f1, f2)
self.assertAlmostEqual(log_prob, expected, places=3)

self._check_log_prob(FisherSnedecor(df1, df2), ref_log_prob)

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_fishersnedecor_sample(self):
set_rng_seed(1) # see note [Randomized statistical tests]
for df1, df2 in product([0.1, 0.5, 1.0, 5.0, 10.0], [0.1, 0.5, 1.0, 5.0, 10.0]):
self._check_sampler_sampler(FisherSnedecor(df1, df2),
scipy.stats.f(df1, df2),
                                        'FisherSnedecor(df1={}, df2={})'.format(df1, df2))

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
def test_chi2_shape(self):
df = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
@@ -931,6 +982,18 @@ def test_valid_parameter_broadcasting(self):
(1, 2)),
(Normal(loc=torch.Tensor([0]), scale=torch.Tensor([[1]])),
(1, 1)),
(FisherSnedecor(df1=torch.Tensor([1, 1]), df2=1),
(2,)),
(FisherSnedecor(df1=1, df2=torch.Tensor([1, 1])),
(2,)),
(FisherSnedecor(df1=torch.Tensor([1, 1]), df2=torch.Tensor([1])),
(2,)),
(FisherSnedecor(df1=torch.Tensor([1, 1]), df2=torch.Tensor([[1], [1]])),
(2, 2)),
(FisherSnedecor(df1=torch.Tensor([1, 1]), df2=torch.Tensor([[1]])),
(1, 2)),
(FisherSnedecor(df1=torch.Tensor([1]), df2=torch.Tensor([[1]])),
(1, 1)),
(Gamma(concentration=torch.Tensor([1, 1]), rate=1),
(2,)),
(Gamma(concentration=1, rate=torch.Tensor([1, 1])),
@@ -1010,6 +1073,10 @@ def test_invalid_parameter_broadcasting(self):
'loc': torch.Tensor([[[0, 0, 0], [0, 0, 0]]]),
'scale': torch.Tensor([1, 1])
}),
(FisherSnedecor, {
'df1': torch.Tensor([1, 1]),
'df2': torch.Tensor([1, 1, 1]),
}),
(Gumbel, {
'loc': torch.Tensor([[0, 0]]),
'scale': torch.Tensor([1, 1, 1, 1])
@@ -1030,10 +1097,6 @@ def test_invalid_parameter_broadcasting(self):
'scale': torch.Tensor([1, 1]),
'alpha': torch.Tensor([1, 1, 1])
}),
-            (Pareto, {
-                'scale': torch.Tensor([1, 1]),
-                'alpha': torch.Tensor([1, 1, 1])
-            }),
(StudentT, {
'df': torch.Tensor([1, 1]),
'scale': torch.Tensor([1, 1, 1])
@@ -1742,7 +1805,7 @@ def test_entropy_monte_carlo(self):
actual = dist.entropy()
except NotImplementedError:
continue
-        x = dist.sample(sample_shape=(20000,))
+        x = dist.sample(sample_shape=(50000,))
expected = -dist.log_prob(x).mean(0)
if isinstance(actual, Variable):
actual = actual.data
2 changes: 2 additions & 0 deletions torch/distributions/__init__.py
@@ -39,6 +39,7 @@
from .dirichlet import Dirichlet
from .distribution import Distribution
from .exponential import Exponential
from .fishersnedecor import FisherSnedecor
from .gamma import Gamma
from .geometric import Geometric
from .gumbel import Gumbel
@@ -61,6 +62,7 @@
'Dirichlet',
'Distribution',
'Exponential',
'FisherSnedecor',
'Gamma',
'Geometric',
'Gumbel',
59 changes: 59 additions & 0 deletions torch/distributions/fishersnedecor.py
@@ -0,0 +1,59 @@
from numbers import Number
import torch
import math
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.gamma import Gamma
from torch.distributions.utils import broadcast_all, _finfo


class FisherSnedecor(Distribution):
    r"""
    Creates a Fisher-Snedecor distribution parameterized by `df1` and `df2`.

    Example::

        >>> m = FisherSnedecor(torch.Tensor([1.0]), torch.Tensor([2.0]))
        >>> m.sample()  # Fisher-Snedecor-distributed with df1=1 and df2=2
         0.2453
        [torch.FloatTensor of size 1]

    Args:
        df1 (float or Tensor or Variable): degrees of freedom parameter 1
        df2 (float or Tensor or Variable): degrees of freedom parameter 2
    """
params = {'df1': constraints.positive, 'df2': constraints.positive}
support = constraints.positive
has_rsample = True

def __init__(self, df1, df2):
self.df1, self.df2 = broadcast_all(df1, df2)
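        # torch.distributions.Gamma is (concentration, rate)-parameterized,
        # so Gamma(df * 0.5, df) has shape df / 2 and scale 1 / df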
self._gamma1 = Gamma(self.df1 * 0.5, self.df1)
self._gamma2 = Gamma(self.df2 * 0.5, self.df2)

if isinstance(df1, Number) and isinstance(df2, Number):
batch_shape = torch.Size()
else:
batch_shape = self.df1.size()
super(FisherSnedecor, self).__init__(batch_shape)

def rsample(self, sample_shape=torch.Size(())):
shape = self._extended_shape(sample_shape)
        # X1 ~ Gamma(df1 / 2, scale = 1 / df1), so 2 * df1 * X1 ~ Chi2(df1)
        # X2 ~ Gamma(df2 / 2, scale = 1 / df2), so 2 * df2 * X2 ~ Chi2(df2)
        # Y = (2 * df1 * X1 / df1) / (2 * df2 * X2 / df2) = X1 / X2 ~ F(df1, df2)
X1 = self._gamma1.rsample(sample_shape).view(shape)
X2 = self._gamma2.rsample(sample_shape).view(shape)
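        # clamp X2 (the denominator) and Y to the smallest positive normal
        # float, so neither the division nor a later log_prob(Y) hits zero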
X2.clamp_(min=_finfo(X2).tiny)
Y = X1 / X2
Y.clamp_(min=_finfo(X2).tiny)
return Y

def log_prob(self, value):
self._validate_log_prob_arg(value)
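        # log f(x; d1, d2) = lgamma((d1 + d2) / 2) - lgamma(d1 / 2) - lgamma(d2 / 2)
        #                    + (d1 / 2) * log(d1 / d2) + (d1 / 2 - 1) * log(x)
        #                    - ((d1 + d2) / 2) * log1p(d1 * x / d2)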
ct1 = self.df1 * 0.5
ct2 = self.df2 * 0.5
ct3 = self.df1 / self.df2
t1 = (ct1 + ct2).lgamma() - ct1.lgamma() - ct2.lgamma()
t2 = ct1 * ct3.log() + (ct1 - 1) * torch.log(value)
t3 = (ct1 + ct2) * torch.log1p(ct3 * value)
return t1 + t2 - t3
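
For reviewers who want to exercise the new distribution locally, here is a minimal sanity check (a sketch, not part of the commit; it assumes SciPy is installed and simply mirrors the ref_log_prob comparison in the new tests):

import scipy.stats
import torch
from torch.distributions import FisherSnedecor

for df1, df2 in [(1.0, 2.0), (5.0, 10.0)]:
    m = FisherSnedecor(torch.Tensor([df1]), torch.Tensor([df2]))
    x = m.sample()  # one draw from F(df1, df2), shape (1,)
    log_p = m.log_prob(x)[0]
    ref = scipy.stats.f.logpdf(x[0], df1, df2)
    assert abs(log_p - ref) < 1e-3, (log_p, ref)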
