Skip to content

Commit

Permalink
Gamma, Chi2 and Exponential Distributions for Tensorflow
Browse files Browse the repository at this point in the history
Change: 122546445
  • Loading branch information
A. Unique TensorFlower authored and tensorflower-gardener committed May 17, 2016
1 parent 43ff0e9 commit da10ae8
Show file tree
Hide file tree
Showing 8 changed files with 647 additions and 1 deletion.
28 changes: 27 additions & 1 deletion tensorflow/contrib/distributions/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,33 @@ cuda_py_tests(
],
)

cuda_py_tests(
name = "gamma_test",
srcs = ["python/kernel_tests/gamma_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)

cuda_py_tests(
name = "chi2_test",
srcs = ["python/kernel_tests/chi2_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)

cuda_py_tests(
name = "exponential_test",
srcs = ["python/kernel_tests/exponential_test.py"],
additional_deps = [
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)

cuda_py_tests(
name = "gaussian_test",
size = "small",
Expand Down Expand Up @@ -65,7 +92,6 @@ cuda_py_tests(
srcs = ["python/kernel_tests/gaussian_conjugate_posteriors_test.py"],
additional_deps = [
":distributions_py",
"//tensorflow/python:framework_test_lib",
"//tensorflow/python:platform_test",
],
)
Expand Down
7 changes: 7 additions & 0 deletions tensorflow/contrib/distributions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@
### Univariate (scalar) distributions
@@Chi2
@@Exponential
@@Gamma
@@Gaussian
@@Uniform
Expand All @@ -50,8 +53,12 @@
from __future__ import print_function

# pylint: disable=unused-import,wildcard-import,line-too-long

from tensorflow.contrib.distributions.python.ops.chi2 import *
from tensorflow.contrib.distributions.python.ops.dirichlet_multinomial import *
from tensorflow.contrib.distributions.python.ops.distribution import *
from tensorflow.contrib.distributions.python.ops.exponential import *
from tensorflow.contrib.distributions.python.ops.gamma import *
from tensorflow.contrib.distributions.python.ops.gaussian import *
from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
from tensorflow.contrib.distributions.python.ops.mvn import *
Expand Down
85 changes: 85 additions & 0 deletions tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class Chi2Test(tf.test.TestCase):

def testChi2LogPDF(self):
with tf.Session():
batch_size = 6
df = tf.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)
chi2 = tf.contrib.distributions.Chi2(df=df)
expected_log_pdf = stats.chi2.logpdf(x, df_v)

log_pdf = chi2.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)

pdf = chi2.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))

def testChi2CDF(self):
with tf.Session():
batch_size = 6
df = tf.constant([2.0] * batch_size, dtype=np.float64)
df_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64)

chi2 = tf.contrib.distributions.Chi2(df=df)
expected_cdf = stats.chi2.cdf(x, df_v)

cdf = chi2.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)

def testChi2Mean(self):
with tf.Session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_mean = stats.chi2.mean(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.mean.get_shape(), (3,))
self.assertAllClose(chi2.mean.eval(), expected_mean)

def testChi2Variance(self):
with tf.Session():
df_v = np.array([1., 3, 5], np.float64)
expected_variances = stats.chi2.var(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.variance.get_shape(), (3,))
self.assertAllClose(chi2.variance.eval(), expected_variances)

def testChi2Entropy(self):
with tf.Session():
df_v = np.array([1., 3, 5], dtype=np.float64)
expected_entropy = stats.chi2.entropy(df_v)
chi2 = tf.contrib.distributions.Chi2(df=df_v)
self.assertEqual(chi2.entropy().get_shape(), (3,))
self.assertAllClose(chi2.entropy().eval(), expected_entropy)


if __name__ == '__main__':
tf.test.main()
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class ExponentialTest(tf.test.TestCase):

def testExponentialLogPDF(self):
with tf.Session():
batch_size = 6
lam = tf.constant([2.0] * batch_size)
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
exponential = tf.contrib.distributions.Exponential(lam=lam)
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v)

log_pdf = exponential.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)

pdf = exponential.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))

def testExponentialCDF(self):
with tf.Session():
batch_size = 6
lam = tf.constant([2.0] * batch_size)
lam_v = 2.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

exponential = tf.contrib.distributions.Exponential(lam=lam)
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v)

cdf = exponential.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)

def testExponentialMean(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_mean = stats.expon.mean(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.mean.get_shape(), (3,))
self.assertAllClose(exponential.mean.eval(), expected_mean)

def testExponentialVariance(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_variance = stats.expon.var(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.variance.get_shape(), (3,))
self.assertAllClose(exponential.variance.eval(), expected_variance)

def testExponentialEntropy(self):
with tf.Session():
lam_v = np.array([1.0, 4.0, 2.5])
expected_entropy = stats.expon.entropy(scale=1 / lam_v)
exponential = tf.contrib.distributions.Exponential(lam=lam_v)
self.assertEqual(exponential.entropy().get_shape(), (3,))
self.assertAllClose(exponential.entropy().eval(), expected_entropy)


if __name__ == '__main__':
tf.test.main()
142 changes: 142 additions & 0 deletions tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for initializers."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from scipy import stats
import tensorflow as tf


class GammaTest(tf.test.TestCase):

def testGammaShape(self):
with tf.Session():
alpha = tf.constant([3.0] * 5)
beta = tf.constant(11.0)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)

self.assertEqual(gamma.batch_shape().eval(), (5,))
self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5]))
self.assertEqual(gamma.event_shape().eval(), 1)
self.assertEqual(gamma.get_event_shape(), tf.TensorShape([]))

def testGammaLogPDF(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
self.assertEqual(log_pdf.get_shape(), (6,))
self.assertAllClose(log_pdf.eval(), expected_log_pdf)

pdf = gamma.pdf(x)
self.assertEqual(pdf.get_shape(), (6,))
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf))

def testGammaLogPDFMultidimensional(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant([[3.0, 4.0]] * batch_size)
alpha_v = np.array([2.0, 4.0])
beta_v = np.array([3.0, 4.0])
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(log_pdf_values, expected_log_pdf)

pdf = gamma.pdf(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

def testGammaLogPDFMultidimensionalBroadcasting(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([[2.0, 4.0]] * batch_size)
beta = tf.constant(3.0)
alpha_v = np.array([2.0, 4.0])
beta_v = 3.0
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v)
log_pdf = gamma.log_pdf(x)
log_pdf_values = log_pdf.eval()
self.assertEqual(log_pdf.get_shape(), (6, 2))
self.assertAllClose(log_pdf_values, expected_log_pdf)

pdf = gamma.pdf(x)
pdf_values = pdf.eval()
self.assertEqual(pdf.get_shape(), (6, 2))
self.assertAllClose(pdf_values, np.exp(expected_log_pdf))

def testGammaCDF(self):
with tf.Session():
batch_size = 6
alpha = tf.constant([2.0] * batch_size)
beta = tf.constant([3.0] * batch_size)
alpha_v = 2.0
beta_v = 3.0
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32)

gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta)
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v)

cdf = gamma.cdf(x)
self.assertEqual(cdf.get_shape(), (6,))
self.assertAllClose(cdf.eval(), expected_cdf)

def testGammaMean(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.mean.get_shape(), (3,))
self.assertAllClose(gamma.mean.eval(), expected_means)

def testGammaVariance(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v)
self.assertEqual(gamma.variance.get_shape(), (3,))
self.assertAllClose(gamma.variance.eval(), expected_variances)

def testGammaEntropy(self):
with tf.Session():
alpha_v = np.array([1.0, 3.0, 2.5])
beta_v = np.array([1.0, 4.0, 5.0])
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v)
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v)
self.assertEqual(gamma.entropy().get_shape(), (3,))
self.assertAllClose(gamma.entropy().eval(), expected_entropy)


if __name__ == '__main__':
tf.test.main()
Loading

0 comments on commit da10ae8

Please sign in to comment.