Supporting logits as parameters in Bernoulli and Categorical (#4448)
* Supporting logits as parameters in Bernoulli and Categorical

* address comments

* fix lint

* modify binary_cross_entropy_with_logits

* address comments

* add descriptor for lazy attributes

* address comments
neerajprad authored and soumith committed Jan 5, 2018
1 parent 0afcc8e commit 408c84d
Showing 6 changed files with 248 additions and 28 deletions.
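
The snippets interleaved below are editorial sketches, not part of the commit. First, a minimal sketch of the API surface this change adds, assuming a build that contains it (the example uses the torch.tensor factory from later releases; at the time of this commit, torch.Tensor/Variable would be used instead):

import torch
from torch.distributions import Bernoulli, Categorical

# Bernoulli can now be parameterized either by probabilities or by logits
# (log-odds); both describe the same distribution.
p = torch.tensor([0.2, 0.5, 0.9])
b_probs = Bernoulli(probs=p)
b_logits = Bernoulli(logits=p.log() - (-p).log1p())  # logit(p)

x = torch.tensor([1.0, 0.0, 1.0])
print(b_probs.log_prob(x))   # computed from probs
print(b_logits.log_prob(x))  # same values, computed from logits

# Categorical accepts (possibly unnormalized) log-probabilities along the
# last dimension; probs is derived lazily when first accessed.
c = Categorical(logits=torch.tensor([0.0, 1.0, 2.0]))
print(c.probs)      # normalized probabilities
print(c.entropy())  # in this commit: -(logits * probs).sum(-1)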
115 changes: 115 additions & 0 deletions test/test_distributions.py
@@ -34,6 +34,8 @@
Dirichlet, Exponential, Gamma, Laplace,
Normal, OneHotCategorical, Pareto, Uniform)
from torch.distributions.constraints import Constraint, is_dependent
from torch.distributions.utils import _get_clamping_buffer


TEST_NUMPY = True
try:
@@ -240,8 +242,14 @@ def ref_log_prob(idx, val, log_prob):
self.assertEqual(log_prob, math.log(prob if val else 1 - prob))

self._check_log_prob(Bernoulli(p), ref_log_prob)
self._check_log_prob(Bernoulli(logits=p.log() - (-p).log1p()), ref_log_prob)
self.assertRaises(NotImplementedError, Bernoulli(r).rsample)

# check entropy computation
self.assertEqual(Bernoulli(p).entropy().data, torch.Tensor([0.6108, 0.5004, 0.6730]), prec=1e-4)
self.assertEqual(Bernoulli(torch.Tensor([0.0])).entropy(), torch.Tensor([0.0]))
self.assertEqual(Bernoulli(s).entropy(), torch.Tensor([0.6108]), prec=1e-4)

def test_bernoulli_enumerate_support(self):
examples = [
([0.1], [[0], [1]]),
@@ -286,6 +294,11 @@ def ref_log_prob(idx, val, log_prob):
self.assertEqual(log_prob, math.log(sample_prob))

self._check_log_prob(Categorical(p), ref_log_prob)
self._check_log_prob(Categorical(logits=p.log()), ref_log_prob)

# check entropy computation
self.assertEqual(Categorical(p).entropy().data, torch.Tensor([1.0114, 1.0297]), prec=1e-4)
self.assertEqual(Categorical(s).entropy().data, torch.Tensor([0.0, 0.0]))

def test_categorical_enumerate_support(self):
examples = [
@@ -1087,5 +1100,107 @@ def test_support_contains(self):
self.assertTrue(constraint.check(value).all(), msg=message)


class TestNumericalStability(TestCase):
def _test_pdf_score(self,
dist_class,
x,
expected_value,
probs=None,
logits=None,
expected_gradient=None,
prec=1e-5):
if probs is not None:
p = Variable(probs, requires_grad=True)
dist = dist_class(p)
else:
p = Variable(logits, requires_grad=True)
dist = dist_class(logits=p)
log_pdf = dist.log_prob(Variable(x))
log_pdf.sum().backward()
self.assertEqual(log_pdf.data,
expected_value,
prec=prec,
message='Failed for tensor type: {}. Expected = {}, Actual = {}'
.format(type(x), expected_value, log_pdf.data))
if expected_gradient is not None:
self.assertEqual(p.grad.data,
expected_gradient,
prec=prec,
message='Failed for tensor type: {}. Expected = {}, Actual = {}'
.format(type(x), expected_gradient, p.grad.data))

def test_bernoulli_gradient(self):
for tensor_type in [torch.FloatTensor, torch.DoubleTensor]:
self._test_pdf_score(dist_class=Bernoulli,
probs=tensor_type([0]),
x=tensor_type([0]),
expected_value=tensor_type([0]),
expected_gradient=tensor_type([0]))

self._test_pdf_score(dist_class=Bernoulli,
probs=tensor_type([0]),
x=tensor_type([1]),
expected_value=tensor_type([_get_clamping_buffer(tensor_type([]))]).log(),
expected_gradient=tensor_type([0]))

self._test_pdf_score(dist_class=Bernoulli,
probs=tensor_type([1e-4]),
x=tensor_type([1]),
expected_value=tensor_type([math.log(1e-4)]),
expected_gradient=tensor_type([10000]))

# Lower precision due to:
# >>> 1 / (1 - torch.FloatTensor([0.9999]))
# 9998.3408
# [torch.FloatTensor of size 1]
self._test_pdf_score(dist_class=Bernoulli,
probs=tensor_type([1 - 1e-4]),
x=tensor_type([0]),
expected_value=tensor_type([math.log(1e-4)]),
expected_gradient=tensor_type([-10000]),
prec=2)

self._test_pdf_score(dist_class=Bernoulli,
logits=tensor_type([math.log(9999)]),
x=tensor_type([0]),
expected_value=tensor_type([math.log(1e-4)]),
expected_gradient=tensor_type([-1]),
prec=1e-3)

def test_bernoulli_with_logits_underflow(self):
for tensor_type, lim in ([(torch.FloatTensor, -1e38),
(torch.DoubleTensor, -1e308)]):
self._test_pdf_score(dist_class=Bernoulli,
logits=tensor_type([lim]),
x=tensor_type([0]),
expected_value=tensor_type([0]),
expected_gradient=tensor_type([0]))

def test_bernoulli_with_logits_overflow(self):
for tensor_type, lim in ([(torch.FloatTensor, 1e38),
(torch.DoubleTensor, 1e308)]):
self._test_pdf_score(dist_class=Bernoulli,
logits=tensor_type([lim]),
x=tensor_type([1]),
expected_value=tensor_type([0]),
expected_gradient=tensor_type([0]))

def test_categorical_log_prob(self):
for tensor_type in ([torch.FloatTensor, torch.DoubleTensor]):
p = Variable(tensor_type([0, 1]), requires_grad=True)
categorical = OneHotCategorical(p)
log_pdf = categorical.log_prob(Variable(tensor_type([0, 1])))
self.assertEqual(log_pdf.data[0], 0)

def test_categorical_log_prob_with_logits(self):
for tensor_type in ([torch.FloatTensor, torch.DoubleTensor]):
p = Variable(tensor_type([-float('inf'), 0]), requires_grad=True)
categorical = OneHotCategorical(logits=p)
log_pdf_prob_1 = categorical.log_prob(Variable(tensor_type([0, 1])))
self.assertEqual(log_pdf_prob_1.data[0], 0)
log_pdf_prob_0 = categorical.log_prob(Variable(tensor_type([1, 0])))
self.assertEqual(log_pdf_prob_0.data[0], -float('inf'), allow_inf=True)


if __name__ == '__main__':
run_tests()
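
A sketch of the boundary behavior these numerical-stability tests pin down, assuming the commit's clamping buffer of 1e-6 for single precision (later releases may clamp with a different epsilon): probabilities are clamped away from 0 and 1 before being converted to logits, so an "impossible" observation gets a large but finite log-probability rather than -inf values or NaN gradients.

import math
import torch
from torch.distributions import Bernoulli

eps = 1e-6  # assumed float32 clamping buffer, per _get_clamping_buffer
d = Bernoulli(probs=torch.tensor([0.0]))
log_p = d.log_prob(torch.tensor([1.0]))   # observing 1 under p = 0
print(log_p)           # finite, approximately log(eps)
print(math.log(eps))   # ~= -13.8155, the reference the test compares against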
38 changes: 23 additions & 15 deletions torch/distributions/bernoulli.py
@@ -4,7 +4,8 @@
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import broadcast_all
from torch.distributions.utils import broadcast_all, probs_to_logits, logits_to_probs, lazy_property
from torch.nn.functional import binary_cross_entropy_with_logits


class Bernoulli(Distribution):
@@ -28,32 +29,39 @@ class Bernoulli(Distribution):
support = constraints.boolean
has_enumerate_support = True

def __init__(self, probs):
self.probs, = broadcast_all(probs)
if isinstance(probs, Number):
def __init__(self, probs=None, logits=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
self.probs, = broadcast_all(probs)
else:
self.logits, = broadcast_all(logits)
probs_or_logits = probs if probs is not None else logits
if isinstance(probs_or_logits, Number):
batch_shape = torch.Size()
else:
batch_shape = self.probs.size()
batch_shape = probs_or_logits.size()
super(Bernoulli, self).__init__(batch_shape)

@lazy_property
def logits(self):
return probs_to_logits(self.probs, is_binary=True)

@lazy_property
def probs(self):
return logits_to_probs(self.logits, is_binary=True)

def sample(self, sample_shape=torch.Size()):
shape = self._extended_shape(sample_shape)
return torch.bernoulli(self.probs.expand(shape))

def log_prob(self, value):
self._validate_log_prob_arg(value)
param_shape = value.size()
probs = self.probs.expand(param_shape)
# compute the log probabilities for 0 and 1
log_pmf = (torch.stack([1 - probs, probs], dim=-1)).log()
# evaluate using the values
return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)
logits, value = broadcast_all(self.logits, value)
return -binary_cross_entropy_with_logits(logits, value, reduce=False)

def entropy(self):
p = torch.stack([self.probs, 1.0 - self.probs])
p_log_p = torch.log(p) * p
p_log_p[p == 0] = 0
return -p_log_p.sum(0)
return binary_cross_entropy_with_logits(self.logits, self.probs, reduce=False)

def enumerate_support(self):
values = torch.arange(2).long()
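Why log_prob can be written as a negated BCE-with-logits call (an editorial sketch, using the newer reduction='none' spelling in place of the commit's reduce=False): for a logit l and binary value y, the Bernoulli log-pmf is y*log(sigmoid(l)) + (1 - y)*log(1 - sigmoid(l)), which is exactly -binary_cross_entropy_with_logits(l, y), evaluated without forming sigmoid(l) explicitly. The same identity with y replaced by probs gives the entropy used above.

import torch
import torch.nn.functional as F

logits = torch.tensor([-2.0, 0.5, 3.0])
value = torch.tensor([0.0, 1.0, 1.0])

# Direct log-pmf; note log(1 - sigmoid(l)) == logsigmoid(-l)
direct = value * F.logsigmoid(logits) + (1 - value) * F.logsigmoid(-logits)
via_bce = -F.binary_cross_entropy_with_logits(logits, value, reduction='none')
print(torch.allclose(direct, via_bce))  # True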
26 changes: 19 additions & 7 deletions torch/distributions/categorical.py
@@ -2,6 +2,7 @@
from torch.autograd import Variable
from torch.distributions import constraints
from torch.distributions.distribution import Distribution
from torch.distributions.utils import probs_to_logits, logits_to_probs, log_sum_exp, lazy_property


class Categorical(Distribution):
@@ -33,15 +34,28 @@ class Categorical(Distribution):
params = {'probs': constraints.simplex}
has_enumerate_support = True

def __init__(self, probs):
self.probs = probs
batch_shape = self.probs.size()[:-1]
def __init__(self, probs=None, logits=None):
if (probs is None) == (logits is None):
raise ValueError("Either `probs` or `logits` must be specified, but not both.")
if probs is not None:
self.probs = probs / probs.sum(-1, keepdim=True)
else:
self.logits = logits - log_sum_exp(logits)
batch_shape = self.probs.size()[:-1] if probs is not None else self.logits.size()[:-1]
super(Categorical, self).__init__(batch_shape)

@constraints.dependent_property
def support(self):
return constraints.integer_interval(0, self.probs.size()[-1] - 1)

@lazy_property
def logits(self):
return probs_to_logits(self.probs)

@lazy_property
def probs(self):
return logits_to_probs(self.logits)

def sample(self, sample_shape=torch.Size()):
num_events = self.probs.size()[-1]
sample_shape = self._extended_shape(sample_shape)
@@ -54,13 +68,11 @@ def sample(self, sample_shape=torch.Size()):
def log_prob(self, value):
self._validate_log_prob_arg(value)
param_shape = value.size() + self.probs.size()[-1:]
log_pmf = (self.probs / self.probs.sum(-1, keepdim=True)).log()
log_pmf = log_pmf.expand(param_shape)
log_pmf = self.logits.expand(param_shape)
return log_pmf.gather(-1, value.unsqueeze(-1).long()).squeeze(-1)

def entropy(self):
p_log_p = torch.log(self.probs) * self.probs
p_log_p[self.probs == 0] = 0
p_log_p = self.logits * self.probs
return -p_log_p.sum(-1)

def enumerate_support(self):
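A sketch of the two Categorical changes worth noting (editorial, using torch's built-in logsumexp in place of the commit's log_sum_exp helper): user-supplied logits are normalized by subtracting their log-sum-exp so they become genuine log-probabilities, and entropy is then -(logits * probs).sum(-1). When probs are supplied instead, zero entries map to finite clamped logits, so 0 * log(eps) = 0 keeps the entropy finite.

import torch

raw = torch.tensor([1.0, 2.0, 3.0])
logits = raw - raw.logsumexp(dim=-1, keepdim=True)  # normalized log-probs
probs = logits.exp()

entropy_from_logits = -(logits * probs).sum(-1)
entropy_direct = -(probs * probs.log()).sum(-1)
print(torch.allclose(entropy_from_logits, entropy_direct))  # True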
8 changes: 4 additions & 4 deletions torch/distributions/one_hot_categorical.py
@@ -30,10 +30,10 @@ class OneHotCategorical(Distribution):
support = constraints.simplex
has_enumerate_support = True

def __init__(self, probs):
self._categorical = Categorical(probs)
batch_shape = probs.size()[:-1]
event_shape = probs.size()[-1:]
def __init__(self, probs=None, logits=None):
self._categorical = Categorical(probs, logits)
batch_shape = self._categorical.probs.size()[:-1]
event_shape = self._categorical.probs.size()[-1:]
super(OneHotCategorical, self).__init__(batch_shape, event_shape)

def sample(self, sample_shape=torch.Size()):
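OneHotCategorical simply forwards either parameterization to the wrapped Categorical, so (in a sketch assuming a post-commit build) a one-hot observation scores the same as the corresponding class index:

import torch
from torch.distributions import Categorical, OneHotCategorical

logits = torch.tensor([0.1, 1.5, -0.3])
one_hot = OneHotCategorical(logits=logits)
cat = Categorical(logits=logits)

x = torch.tensor([0.0, 1.0, 0.0])        # one-hot encoding of class 1
print(one_hot.log_prob(x))               # log-probability of class 1
print(cat.log_prob(torch.tensor(1)))     # same value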
79 changes: 79 additions & 0 deletions torch/distributions/utils.py
@@ -1,7 +1,9 @@
from functools import update_wrapper
from numbers import Number

import torch
from torch.autograd import Variable
import torch.nn.functional as F


def expand_n(v, n):
@@ -66,3 +68,80 @@ def broadcast_all(*values):
for idx in scalar_idxs:
values[idx] = torch.Tensor([values[idx]])
return values


def _get_clamping_buffer(tensor):
clamp_eps = 1e-6
if isinstance(tensor, Variable):
tensor = tensor.data
if isinstance(tensor, (torch.DoubleTensor, torch.cuda.DoubleTensor)):
clamp_eps = 1e-15
return clamp_eps


def softmax(tensor):
"""
Wrapper around softmax to make it work with both Tensors and Variables.
TODO: Remove once https://github.com/pytorch/pytorch/issues/2633 is resolved.
"""
if not isinstance(tensor, Variable):
return F.softmax(Variable(tensor), -1).data
return F.softmax(tensor, -1)


def log_sum_exp(tensor, keepdim=True):
"""
Numerically stable implementation for the `LogSumExp` operation. The
summing is done along the last dimension.
Args:
tensor (torch.Tensor or torch.autograd.Variable)
keepdim (Boolean): Whether to retain the last dimension on summing.
"""
max_val = tensor.max(dim=-1, keepdim=True)[0]
return max_val + (tensor - max_val).exp().sum(dim=-1, keepdim=keepdim).log()


def logits_to_probs(logits, is_binary=False):
"""
Converts a tensor of logits into probabilities. Note that for the
binary case, each value denotes log odds, whereas for the
multi-dimensional case, the values along the last dimension denote
the log probabilities (possibly unnormalized) of the events.
"""
if is_binary:
return F.sigmoid(logits)
return softmax(logits)


def probs_to_logits(probs, is_binary=False):
"""
Converts a tensor of probabilities into logits. For the binary case,
this denotes the probability of occurrence of the event indexed by `1`.
For the multi-dimensional case, the values along the last dimension
denote the probabilities of occurrence of each of the events.
"""
eps = _get_clamping_buffer(probs)
ps_clamped = probs.clamp(min=eps, max=1 - eps)
if is_binary:
return torch.log(ps_clamped) - torch.log1p(-ps_clamped)
return torch.log(ps_clamped)


class lazy_property(object):
"""
Used as a decorator for lazy loading of class attributes. This uses a
non-data descriptor that calls the wrapped method to compute the property on
first call; thereafter the computed value replaces the wrapped method as an
instance attribute.
"""
def __init__(self, wrapped):
self.wrapped = wrapped
update_wrapper(self, wrapped)

def __get__(self, instance, obj_type=None):
if instance is None:
return self
value = self.wrapped(instance)
setattr(instance, self.wrapped.__name__, value)
return value
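
Finally, a sketch of how the lazy_property descriptor behaves (the Example class below is illustrative, not part of the commit): the first access runs the wrapped method, then caches the result as a plain instance attribute, so later accesses bypass the descriptor entirely. This is why Bernoulli and Categorical can expose both probs and logits while only computing the derived one on demand.

import torch
from torch.distributions.utils import lazy_property, logits_to_probs

class Example(object):
    def __init__(self, logits):
        self.logits = logits

    @lazy_property
    def probs(self):
        print("computing probs once")
        return logits_to_probs(self.logits, is_binary=True)

e = Example(torch.Tensor([0.0]))
e.probs  # prints "computing probs once" and stores the result on e
e.probs  # cached: plain attribute lookup, no recomputation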