forked from tensorflow/tensorflow
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Gamma, Chi2 and Exponential Distributions for Tensorflow
Change: 122546445
- Loading branch information
1 parent
43ff0e9
commit da10ae8
Showing
8 changed files
with
647 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
85 changes: 85 additions & 0 deletions
85
tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2016 Google Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for initializers.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
from scipy import stats | ||
import tensorflow as tf | ||
|
||
|
||
class Chi2Test(tf.test.TestCase): | ||
|
||
def testChi2LogPDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
df = tf.constant([2.0] * batch_size, dtype=np.float64) | ||
df_v = 2.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64) | ||
chi2 = tf.contrib.distributions.Chi2(df=df) | ||
expected_log_pdf = stats.chi2.logpdf(x, df_v) | ||
|
||
log_pdf = chi2.log_pdf(x) | ||
self.assertEqual(log_pdf.get_shape(), (6,)) | ||
self.assertAllClose(log_pdf.eval(), expected_log_pdf) | ||
|
||
pdf = chi2.pdf(x) | ||
self.assertEqual(pdf.get_shape(), (6,)) | ||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) | ||
|
||
def testChi2CDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
df = tf.constant([2.0] * batch_size, dtype=np.float64) | ||
df_v = 2.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float64) | ||
|
||
chi2 = tf.contrib.distributions.Chi2(df=df) | ||
expected_cdf = stats.chi2.cdf(x, df_v) | ||
|
||
cdf = chi2.cdf(x) | ||
self.assertEqual(cdf.get_shape(), (6,)) | ||
self.assertAllClose(cdf.eval(), expected_cdf) | ||
|
||
def testChi2Mean(self): | ||
with tf.Session(): | ||
df_v = np.array([1., 3, 5], dtype=np.float64) | ||
expected_mean = stats.chi2.mean(df_v) | ||
chi2 = tf.contrib.distributions.Chi2(df=df_v) | ||
self.assertEqual(chi2.mean.get_shape(), (3,)) | ||
self.assertAllClose(chi2.mean.eval(), expected_mean) | ||
|
||
def testChi2Variance(self): | ||
with tf.Session(): | ||
df_v = np.array([1., 3, 5], np.float64) | ||
expected_variances = stats.chi2.var(df_v) | ||
chi2 = tf.contrib.distributions.Chi2(df=df_v) | ||
self.assertEqual(chi2.variance.get_shape(), (3,)) | ||
self.assertAllClose(chi2.variance.eval(), expected_variances) | ||
|
||
def testChi2Entropy(self): | ||
with tf.Session(): | ||
df_v = np.array([1., 3, 5], dtype=np.float64) | ||
expected_entropy = stats.chi2.entropy(df_v) | ||
chi2 = tf.contrib.distributions.Chi2(df=df_v) | ||
self.assertEqual(chi2.entropy().get_shape(), (3,)) | ||
self.assertAllClose(chi2.entropy().eval(), expected_entropy) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
85 changes: 85 additions & 0 deletions
85
tensorflow/contrib/distributions/python/kernel_tests/exponential_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,85 @@ | ||
# Copyright 2016 Google Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for initializers.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
from scipy import stats | ||
import tensorflow as tf | ||
|
||
|
||
class ExponentialTest(tf.test.TestCase): | ||
|
||
def testExponentialLogPDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
lam = tf.constant([2.0] * batch_size) | ||
lam_v = 2.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) | ||
exponential = tf.contrib.distributions.Exponential(lam=lam) | ||
expected_log_pdf = stats.expon.logpdf(x, scale=1 / lam_v) | ||
|
||
log_pdf = exponential.log_pdf(x) | ||
self.assertEqual(log_pdf.get_shape(), (6,)) | ||
self.assertAllClose(log_pdf.eval(), expected_log_pdf) | ||
|
||
pdf = exponential.pdf(x) | ||
self.assertEqual(pdf.get_shape(), (6,)) | ||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) | ||
|
||
def testExponentialCDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
lam = tf.constant([2.0] * batch_size) | ||
lam_v = 2.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) | ||
|
||
exponential = tf.contrib.distributions.Exponential(lam=lam) | ||
expected_cdf = stats.expon.cdf(x, scale=1 / lam_v) | ||
|
||
cdf = exponential.cdf(x) | ||
self.assertEqual(cdf.get_shape(), (6,)) | ||
self.assertAllClose(cdf.eval(), expected_cdf) | ||
|
||
def testExponentialMean(self): | ||
with tf.Session(): | ||
lam_v = np.array([1.0, 4.0, 2.5]) | ||
expected_mean = stats.expon.mean(scale=1 / lam_v) | ||
exponential = tf.contrib.distributions.Exponential(lam=lam_v) | ||
self.assertEqual(exponential.mean.get_shape(), (3,)) | ||
self.assertAllClose(exponential.mean.eval(), expected_mean) | ||
|
||
def testExponentialVariance(self): | ||
with tf.Session(): | ||
lam_v = np.array([1.0, 4.0, 2.5]) | ||
expected_variance = stats.expon.var(scale=1 / lam_v) | ||
exponential = tf.contrib.distributions.Exponential(lam=lam_v) | ||
self.assertEqual(exponential.variance.get_shape(), (3,)) | ||
self.assertAllClose(exponential.variance.eval(), expected_variance) | ||
|
||
def testExponentialEntropy(self): | ||
with tf.Session(): | ||
lam_v = np.array([1.0, 4.0, 2.5]) | ||
expected_entropy = stats.expon.entropy(scale=1 / lam_v) | ||
exponential = tf.contrib.distributions.Exponential(lam=lam_v) | ||
self.assertEqual(exponential.entropy().get_shape(), (3,)) | ||
self.assertAllClose(exponential.entropy().eval(), expected_entropy) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
142 changes: 142 additions & 0 deletions
142
tensorflow/contrib/distributions/python/kernel_tests/gamma_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,142 @@ | ||
# Copyright 2016 Google Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
# ============================================================================== | ||
"""Tests for initializers.""" | ||
|
||
from __future__ import absolute_import | ||
from __future__ import division | ||
from __future__ import print_function | ||
|
||
import numpy as np | ||
from scipy import stats | ||
import tensorflow as tf | ||
|
||
|
||
class GammaTest(tf.test.TestCase): | ||
|
||
def testGammaShape(self): | ||
with tf.Session(): | ||
alpha = tf.constant([3.0] * 5) | ||
beta = tf.constant(11.0) | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) | ||
|
||
self.assertEqual(gamma.batch_shape().eval(), (5,)) | ||
self.assertEqual(gamma.get_batch_shape(), tf.TensorShape([5])) | ||
self.assertEqual(gamma.event_shape().eval(), 1) | ||
self.assertEqual(gamma.get_event_shape(), tf.TensorShape([])) | ||
|
||
def testGammaLogPDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
alpha = tf.constant([2.0] * batch_size) | ||
beta = tf.constant([3.0] * batch_size) | ||
alpha_v = 2.0 | ||
beta_v = 3.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) | ||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) | ||
log_pdf = gamma.log_pdf(x) | ||
self.assertEqual(log_pdf.get_shape(), (6,)) | ||
self.assertAllClose(log_pdf.eval(), expected_log_pdf) | ||
|
||
pdf = gamma.pdf(x) | ||
self.assertEqual(pdf.get_shape(), (6,)) | ||
self.assertAllClose(pdf.eval(), np.exp(expected_log_pdf)) | ||
|
||
def testGammaLogPDFMultidimensional(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
alpha = tf.constant([[2.0, 4.0]] * batch_size) | ||
beta = tf.constant([[3.0, 4.0]] * batch_size) | ||
alpha_v = np.array([2.0, 4.0]) | ||
beta_v = np.array([3.0, 4.0]) | ||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) | ||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) | ||
log_pdf = gamma.log_pdf(x) | ||
log_pdf_values = log_pdf.eval() | ||
self.assertEqual(log_pdf.get_shape(), (6, 2)) | ||
self.assertAllClose(log_pdf_values, expected_log_pdf) | ||
|
||
pdf = gamma.pdf(x) | ||
pdf_values = pdf.eval() | ||
self.assertEqual(pdf.get_shape(), (6, 2)) | ||
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) | ||
|
||
def testGammaLogPDFMultidimensionalBroadcasting(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
alpha = tf.constant([[2.0, 4.0]] * batch_size) | ||
beta = tf.constant(3.0) | ||
alpha_v = np.array([2.0, 4.0]) | ||
beta_v = 3.0 | ||
x = np.array([[2.5, 2.5, 4.0, 0.1, 1.0, 2.0]], dtype=np.float32).T | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) | ||
expected_log_pdf = stats.gamma.logpdf(x, alpha_v, scale=1 / beta_v) | ||
log_pdf = gamma.log_pdf(x) | ||
log_pdf_values = log_pdf.eval() | ||
self.assertEqual(log_pdf.get_shape(), (6, 2)) | ||
self.assertAllClose(log_pdf_values, expected_log_pdf) | ||
|
||
pdf = gamma.pdf(x) | ||
pdf_values = pdf.eval() | ||
self.assertEqual(pdf.get_shape(), (6, 2)) | ||
self.assertAllClose(pdf_values, np.exp(expected_log_pdf)) | ||
|
||
def testGammaCDF(self): | ||
with tf.Session(): | ||
batch_size = 6 | ||
alpha = tf.constant([2.0] * batch_size) | ||
beta = tf.constant([3.0] * batch_size) | ||
alpha_v = 2.0 | ||
beta_v = 3.0 | ||
x = np.array([2.5, 2.5, 4.0, 0.1, 1.0, 2.0], dtype=np.float32) | ||
|
||
gamma = tf.contrib.distributions.Gamma(alpha=alpha, beta=beta) | ||
expected_cdf = stats.gamma.cdf(x, alpha_v, scale=1 / beta_v) | ||
|
||
cdf = gamma.cdf(x) | ||
self.assertEqual(cdf.get_shape(), (6,)) | ||
self.assertAllClose(cdf.eval(), expected_cdf) | ||
|
||
def testGammaMean(self): | ||
with tf.Session(): | ||
alpha_v = np.array([1.0, 3.0, 2.5]) | ||
beta_v = np.array([1.0, 4.0, 5.0]) | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) | ||
expected_means = stats.gamma.mean(alpha_v, scale=1 / beta_v) | ||
self.assertEqual(gamma.mean.get_shape(), (3,)) | ||
self.assertAllClose(gamma.mean.eval(), expected_means) | ||
|
||
def testGammaVariance(self): | ||
with tf.Session(): | ||
alpha_v = np.array([1.0, 3.0, 2.5]) | ||
beta_v = np.array([1.0, 4.0, 5.0]) | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) | ||
expected_variances = stats.gamma.var(alpha_v, scale=1 / beta_v) | ||
self.assertEqual(gamma.variance.get_shape(), (3,)) | ||
self.assertAllClose(gamma.variance.eval(), expected_variances) | ||
|
||
def testGammaEntropy(self): | ||
with tf.Session(): | ||
alpha_v = np.array([1.0, 3.0, 2.5]) | ||
beta_v = np.array([1.0, 4.0, 5.0]) | ||
expected_entropy = stats.gamma.entropy(alpha_v, scale=1 / beta_v) | ||
gamma = tf.contrib.distributions.Gamma(alpha=alpha_v, beta=beta_v) | ||
self.assertEqual(gamma.entropy().get_shape(), (3,)) | ||
self.assertAllClose(gamma.entropy().eval(), expected_entropy) | ||
|
||
|
||
if __name__ == '__main__': | ||
tf.test.main() |
Oops, something went wrong.