Skip to content

Commit

Permalink
Implement OneHotCategorical distribution (#4357)
Browse files Browse the repository at this point in the history
  • Loading branch information
fritzo authored and apaszke committed Dec 28, 2017
1 parent 3a16978 commit 5c33400
Show file tree
Hide file tree
Showing 6 changed files with 209 additions and 53 deletions.
12 changes: 12 additions & 0 deletions docs/source/distributions.rst
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,15 @@ Probability distributions - torch.distributions

.. autoclass:: Normal
:members:

:hidden:`OneHotCategorical`
~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: OneHotCategorical
:members:

:hidden:`Uniform`
~~~~~~~~~~~~~~~~~~~~~~~

.. autoclass:: Uniform
:members:
8 changes: 7 additions & 1 deletion test/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,12 @@ def to_gpu(obj, type_map={}):
return deepcopy(obj)


def set_rng_seed(seed):
torch.manual_seed(seed)
if TEST_NUMPY:
numpy.random.seed(seed)


@contextlib.contextmanager
def freeze_rng_state():
rng_state = torch.get_rng_state()
Expand Down Expand Up @@ -129,7 +135,7 @@ class TestCase(unittest.TestCase):
maxDiff = None

def setUp(self):
torch.manual_seed(SEED)
set_rng_seed(SEED)

def assertTensorsSlowEqual(self, x, y, prec=None, message=''):
max_err = 0
Expand Down
141 changes: 96 additions & 45 deletions test/test_distributions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,38 @@
"""
Note [Randomized statistical tests]
-----------------------------------
This note describes how to maintain tests in this file as random sources
change. This file contains two types of randomized tests:
1. The easier type of randomized test are tests that should always pass but are
initialized with random data. If these fail something is wrong, but it's
fine to use a fixed seed by inheriting from common.TestCase.
2. The trickier tests are statistical tests. These tests explicitly call
set_rng_seed(n) and are marked "see Note [Randomized statistical tests]".
These statistical tests have a known positive failure rate
(we set failure_rate=1e-3 by default). We need to balance strength of these
tests with annoyance of false alarms. One way that works is to specifically
set seeds in each of the randomized tests. When a random generator
occasionally changes (as in #4312 vectorizing the Box-Muller sampler), some
of these statistical tests may (rarely) fail. If one fails in this case,
it's fine to increment the seed of the failing test (but you shouldn't need
to increment it more than once; otherwise something is probably actually
wrong).
"""

import math
import unittest
from collections import namedtuple
from itertools import product

import torch
from common import TestCase, run_tests
from common import TestCase, run_tests, set_rng_seed
from torch.autograd import Variable, gradcheck
from torch.distributions import (Bernoulli, Beta, Categorical, Dirichlet,
Exponential, Gamma, Laplace, Normal)
from torch.distributions.uniform import Uniform
Exponential, Gamma, Laplace, Normal,
OneHotCategorical, Uniform)

TEST_NUMPY = True
try:
Expand Down Expand Up @@ -41,6 +65,10 @@
{'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
{'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
]),
Example(OneHotCategorical, [
{'probs': Variable(torch.Tensor([[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]), requires_grad=True)},
{'probs': Variable(torch.Tensor([[1.0, 0.0], [0.0, 1.0]]), requires_grad=True)},
]),
Example(Gamma, [
{
'alpha': Variable(torch.exp(torch.randn(2, 3)), requires_grad=True),
Expand Down Expand Up @@ -105,17 +133,15 @@


class TestDistributions(TestCase):
def _set_rng_seed(self, seed=0):
torch.manual_seed(seed)
if TEST_NUMPY:
np.random.seed(seed)

def _gradcheck_log_prob(self, dist_ctor, ctor_params):
# performs gradient checks on log_prob
distribution = dist_ctor(*ctor_params)
s = distribution.sample()

self.assertEqual(s.size(), distribution.log_prob(s).size())
expected_shape = distribution.batch_shape + distribution.event_shape
if not expected_shape:
expected_shape = torch.Size((1,)) # Work around lack of scalars.
self.assertEqual(s.size(), expected_shape)

def apply_fn(*params):
return dist_ctor(*params).log_prob(s)
Expand Down Expand Up @@ -185,10 +211,7 @@ def ref_log_prob(idx, val, log_prob):
self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

self._check_log_prob(Bernoulli(p), ref_log_prob)

def call_rsample():
return Bernoulli(r).rsample()
self.assertRaises(NotImplementedError, call_rsample)
self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

def test_bernoulli_enumerate_support(self):
examples = [
Expand All @@ -212,10 +235,7 @@ def test_categorical_1d(self):
self.assertEqual(Categorical(p).sample((2, 2)).size(), (2, 2))
self.assertEqual(Categorical(p).sample_n(1).size(), (1,))
self._gradcheck_log_prob(Categorical, (p,))

def call_rsample():
return Categorical(p).rsample()
self.assertRaises(NotImplementedError, call_rsample)
self.assertRaises(NotImplementedError, Categorical(p).rsample)

def test_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
Expand All @@ -228,7 +248,7 @@ def test_categorical_2d(self):
self._gradcheck_log_prob(Categorical, (p,))

# sample check for extreme value of probs
self._set_rng_seed(0)
set_rng_seed(0)
self.assertEqual(Categorical(s).sample(sample_shape=(2,)).data,
torch.Tensor([[0, 1], [0, 1]]))

Expand All @@ -245,6 +265,35 @@ def test_categorical_enumerate_support(self):
]
self._check_enumerate_support(Categorical, examples)

def test_one_hot_categorical_1d(self):
p = Variable(torch.Tensor([0.1, 0.2, 0.3]), requires_grad=True)
self.assertEqual(OneHotCategorical(p).sample().size(), (3,))
self.assertEqual(OneHotCategorical(p).sample((2, 2)).size(), (2, 2, 3))
self.assertEqual(OneHotCategorical(p).sample_n(1).size(), (1, 3))
self._gradcheck_log_prob(OneHotCategorical, (p,))
self.assertRaises(NotImplementedError, OneHotCategorical(p).rsample)

def test_one_hot_categorical_2d(self):
probabilities = [[0.1, 0.2, 0.3], [0.5, 0.3, 0.2]]
probabilities_1 = [[1.0, 0.0], [0.0, 1.0]]
p = Variable(torch.Tensor(probabilities), requires_grad=True)
s = Variable(torch.Tensor(probabilities_1), requires_grad=True)
self.assertEqual(OneHotCategorical(p).sample().size(), (2, 3))
self.assertEqual(OneHotCategorical(p).sample(sample_shape=(3, 4)).size(), (3, 4, 2, 3))
self.assertEqual(OneHotCategorical(p).sample_n(6).size(), (6, 2, 3))
self._gradcheck_log_prob(OneHotCategorical, (p,))

dist = OneHotCategorical(p)
x = dist.sample()
self.assertEqual(dist.log_prob(x), Categorical(p).log_prob(x.max(-1)[1]))

def test_one_hot_categorical_enumerate_support(self):
examples = [
([0.1, 0.2, 0.7], [[1, 0, 0], [0, 1, 0], [0, 0, 1]]),
([[0.1, 0.9], [0.3, 0.7]], [[[1, 0], [1, 0]], [[0, 1], [0, 1]]]),
]
self._check_enumerate_support(OneHotCategorical, examples)

def test_uniform(self):
low = Variable(torch.zeros(5, 5), requires_grad=True)
high = Variable(torch.ones(5, 5) * 3, requires_grad=True)
Expand All @@ -263,7 +312,7 @@ def test_uniform(self):
self.assertEqual(uniform.log_prob(above_high).data[0], -float('inf'), allow_inf=True)
self.assertEqual(uniform.log_prob(below_low).data[0], -float('inf'), allow_inf=True)

self._set_rng_seed(1)
set_rng_seed(1)
self._gradcheck_log_prob(Uniform, (low, high))
self._gradcheck_log_prob(Uniform, (low, 1.0))
self._gradcheck_log_prob(Uniform, (0.0, high))
Expand Down Expand Up @@ -293,7 +342,7 @@ def test_normal(self):
self.assertEqual(Normal(-0.7, 50.0).sample_n(1).size(), (1,))

# sample check for extreme value of mean, std
self._set_rng_seed(1)
set_rng_seed(1)
self.assertEqual(Normal(mean_delta, std_delta).sample(sample_shape=(1, 2)),
torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
prec=1e-4)
Expand Down Expand Up @@ -322,10 +371,9 @@ def ref_log_prob(idx, x, log_prob):

self._check_log_prob(Normal(mean, std), ref_log_prob)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_normal_sample(self):
self._set_rng_seed()
set_rng_seed(0) # see Note [Randomized statistical tests]
for mean, std in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Normal(mean, std),
scipy.stats.norm(loc=mean, scale=std),
Expand Down Expand Up @@ -358,10 +406,9 @@ def ref_log_prob(idx, x, log_prob):

self._check_log_prob(Exponential(rate), ref_log_prob)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_exponential_sample(self):
self._set_rng_seed(1)
set_rng_seed(1) # see Note [Randomized statistical tests]
for rate in [1e-5, 1.0, 10.]:
self._check_sampler_sampler(Exponential(rate),
scipy.stats.expon(scale=1. / rate),
Expand All @@ -382,7 +429,7 @@ def test_laplace(self):
self.assertEqual(Laplace(-0.7, 50.0).sample_n(1).size(), (1,))

# sample check for extreme value of mean, std
self._set_rng_seed()
set_rng_seed(0)
self.assertEqual(Laplace(loc_delta, scale_delta).sample(sample_shape=(1, 2)),
torch.Tensor([[[1.0, 0.0], [1.0, 0.0]]]),
prec=1e-4)
Expand Down Expand Up @@ -410,16 +457,14 @@ def ref_log_prob(idx, x, log_prob):

self._check_log_prob(Laplace(loc, scale), ref_log_prob)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_laplace_sample(self):
self._set_rng_seed(1)
set_rng_seed(1) # see Note [Randomized statistical tests]
for loc, scale in product([-1.0, 0.0, 1.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Laplace(loc, scale),
scipy.stats.laplace(loc=loc, scale=scale),
'Laplace(loc={}, scale={})'.format(loc, scale))

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_shape(self):
alpha = Variable(torch.exp(torch.randn(2, 3)), requires_grad=True)
Expand All @@ -441,19 +486,17 @@ def ref_log_prob(idx, x, log_prob):

self._check_log_prob(Gamma(alpha, beta), ref_log_prob)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_sample(self):
self._set_rng_seed()
set_rng_seed(0) # see Note [Randomized statistical tests]
for alpha, beta in product([0.1, 1.0, 5.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Gamma(alpha, beta),
scipy.stats.gamma(alpha, scale=1.0 / beta),
'Gamma(alpha={}, beta={})'.format(alpha, beta))

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_gamma_sample_grad(self):
self._set_rng_seed(1)
set_rng_seed(1) # see Note [Randomized statistical tests]
num_samples = 100
for alpha in [1e-3, 1e-2, 1e-1, 1e0, 1e1, 1e2, 1e3, 1e4]:
alphas = Variable(torch.Tensor([alpha] * num_samples), requires_grad=True)
Expand Down Expand Up @@ -498,10 +541,9 @@ def test_dirichlet_log_prob(self):
expected_log_prob = scipy.stats.dirichlet.logpdf(x[i].numpy(), alpha.numpy())
self.assertAlmostEqual(actual_log_prob[i], expected_log_prob, places=3)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_dirichlet_sample(self):
self._set_rng_seed()
set_rng_seed(0) # see Note [Randomized statistical tests]
alpha = torch.exp(torch.randn(3))
self._check_sampler_sampler(Dirichlet(alpha),
scipy.stats.dirichlet(alpha.numpy()),
Expand Down Expand Up @@ -531,10 +573,9 @@ def test_beta_log_prob(self):
expected_log_prob = scipy.stats.beta.logpdf(x, alpha, beta)[0]
self.assertAlmostEqual(actual_log_prob, expected_log_prob, places=3, allow_inf=True)

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_beta_sample(self):
self._set_rng_seed(1)
set_rng_seed(1) # see Note [Randomized statistical tests]
for alpha, beta in product([0.1, 1.0, 10.0], [0.1, 1.0, 10.0]):
self._check_sampler_sampler(Beta(alpha, beta),
scipy.stats.beta(alpha, beta),
Expand All @@ -544,10 +585,9 @@ def test_beta_sample(self):
x = Beta(Tensor([1e-6]), Tensor([1e-6])).sample()[0]
self.assertTrue(np.isfinite(x) and x > 0, 'Invalid Beta.sample(): {}'.format(x))

# This is a randomized test.
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_beta_sample_grad(self):
self._set_rng_seed()
set_rng_seed(0) # see Note [Randomized statistical tests]
num_samples = 20
for alpha, beta in product([1e-2, 1e0, 1e2], [1e-2, 1e0, 1e2]):
alphas = Variable(torch.Tensor([alpha] * num_samples), requires_grad=True)
Expand Down Expand Up @@ -649,6 +689,7 @@ def test_invalid_parameter_broadcasting(self):

class TestDistributionShapes(TestCase):
def setUp(self):
super(TestCase, self).setUp()
self.scalar_sample = 1
self.tensor_sample_1 = torch.ones(3, 2)
self.tensor_sample_2 = torch.ones(3, 2, 3)
Expand Down Expand Up @@ -705,13 +746,23 @@ def test_beta_shape_tensor_params(self):
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)

def test_categorical_shape(self):
categorical = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(categorical._batch_shape, torch.Size((3,)))
self.assertEqual(categorical._event_shape, torch.Size(()))
self.assertEqual(categorical.sample().size(), torch.Size((3,)))
self.assertEqual(categorical.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
self.assertRaises(ValueError, categorical.log_prob, self.tensor_sample_1)
self.assertEqual(categorical.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))
dist = Categorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size(()))
self.assertEqual(dist.sample().size(), torch.Size((3,)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_1)
self.assertEqual(dist.log_prob(self.tensor_sample_2).size(), torch.Size((3, 2, 3)))

def test_one_hot_categorical_shape(self):
dist = OneHotCategorical(torch.Tensor([[0.6, 0.3], [0.6, 0.3], [0.6, 0.3]]))
self.assertEqual(dist._batch_shape, torch.Size((3,)))
self.assertEqual(dist._event_shape, torch.Size((2,)))
self.assertEqual(dist.sample().size(), torch.Size((3, 2)))
self.assertEqual(dist.sample((3, 2)).size(), torch.Size((3, 2, 3, 2)))
self.assertEqual(dist.log_prob(self.tensor_sample_1).size(), torch.Size((3,)))
self.assertRaises(ValueError, dist.log_prob, self.tensor_sample_2)
self.assertEqual(dist.log_prob(dist.enumerate_support()).size(), torch.Size((2, 3)))

def test_dirichlet_shape(self):
dist = Dirichlet(torch.Tensor([[0.6, 0.3], [1.6, 1.3], [2.6, 2.3]]))
Expand Down
23 changes: 18 additions & 5 deletions torch/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,23 @@
from .categorical import Categorical
from .dirichlet import Dirichlet
from .distribution import Distribution
from .gamma import Gamma
from .normal import Normal
from .exponential import Exponential
from .gamma import Gamma
from .laplace import Laplace


__all__ = ['Distribution', 'Bernoulli', 'Beta', 'Categorical', 'Dirichlet', 'Gamma', 'Normal', 'Exponential', 'Laplace']
from .normal import Normal
from .one_hot_categorical import OneHotCategorical
from .uniform import Uniform

__all__ = [
'Bernoulli',
'Beta',
'Categorical',
'Dirichlet',
'Distribution',
'Exponential',
'Gamma',
'Laplace',
'Normal',
'OneHotCategorical',
'Uniform',
]
Loading

0 comments on commit 5c33400

Please sign in to comment.