diff --git a/tensorflow/contrib/distributions/BUILD b/tensorflow/contrib/distributions/BUILD
index c9cfc922079472..451a34320e0485 100644
--- a/tensorflow/contrib/distributions/BUILD
+++ b/tensorflow/contrib/distributions/BUILD
@@ -65,6 +65,17 @@ cuda_py_tests(
     ],
 )
 
+cuda_py_tests(
+    name = "student_t_test",
+    size = "small",
+    srcs = ["python/kernel_tests/student_t_test.py"],
+    additional_deps = [
+        ":distributions_py",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+    ],
+)
+
 cuda_py_tests(
     name = "uniform_test",
     size = "small",
diff --git a/tensorflow/contrib/distributions/__init__.py b/tensorflow/contrib/distributions/__init__.py
index 74cedaa251ed96..2c8a0343b28f50 100644
--- a/tensorflow/contrib/distributions/__init__.py
+++ b/tensorflow/contrib/distributions/__init__.py
@@ -31,6 +31,7 @@
 @@Exponential
 @@Gamma
 @@Gaussian
+@@StudentT
 @@Uniform
 
 ### Multivariate distributions
@@ -62,4 +63,5 @@
 from tensorflow.contrib.distributions.python.ops.gaussian import *
 from tensorflow.contrib.distributions.python.ops.gaussian_conjugate_posteriors import *
 from tensorflow.contrib.distributions.python.ops.mvn import *
+from tensorflow.contrib.distributions.python.ops.student_t import *
 from tensorflow.contrib.distributions.python.ops.uniform import *
diff --git a/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
new file mode 100644
index 00000000000000..625c421dc31195
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/kernel_tests/student_t_test.py
@@ -0,0 +1,321 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for Student t distribution."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+from scipy import stats
+import tensorflow as tf
+
+
+class StudentTTest(tf.test.TestCase):
+
+  def testStudentPDFAndLogPDF(self):
+    with tf.Session():
+      batch_size = 6
+      df = tf.constant([3.0] * batch_size)
+      mu = tf.constant([7.0] * batch_size)
+      sigma = tf.constant([8.0] * batch_size)
+      df_v = 3.0
+      mu_v = 7.0
+      sigma_v = 8.0
+      t = np.array([-2.5, 2.5, 8.0, 0.0, -1.0, 2.0], dtype=np.float32)
+      student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)
+
+      log_pdf = student.log_pdf(t)
+      self.assertEqual(log_pdf.get_shape(), (6,))
+      log_pdf_values = log_pdf.eval()
+      pdf = student.pdf(t)
+      self.assertEqual(pdf.get_shape(), (6,))
+      pdf_values = pdf.eval()
+
+      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+      self.assertAllClose(expected_log_pdf, log_pdf_values)
+      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+      self.assertAllClose(expected_pdf, pdf_values)
+      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+
+  def testStudentLogPDFMultidimensional(self):
+    with tf.Session():
+      batch_size = 6
+      df = tf.constant([[1.5, 7.2]] * batch_size)
+      mu = tf.constant([[3.0, -3.0]] * batch_size)
+      sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+      df_v = np.array([1.5, 7.2])
+      mu_v = np.array([3.0, -3.0])
+      sigma_v = np.array([np.sqrt(10.0), np.sqrt(15.0)])
+      t = np.array([[-2.5, 2.5, 4.0, 0.0, -1.0, 2.0]], dtype=np.float32).T
+      student = tf.contrib.distributions.StudentT(df, mu=mu, sigma=sigma)
+      log_pdf = student.log_pdf(t)
+      log_pdf_values = log_pdf.eval()
+      self.assertEqual(log_pdf.get_shape(), (6, 2))
+      pdf = student.pdf(t)
+      pdf_values = pdf.eval()
+      self.assertEqual(pdf.get_shape(), (6, 2))
+      expected_log_pdf = stats.t.logpdf(t, df_v, loc=mu_v, scale=sigma_v)
+      expected_pdf = stats.t.pdf(t, df_v, loc=mu_v, scale=sigma_v)
+      self.assertAllClose(expected_log_pdf, log_pdf_values)
+      self.assertAllClose(np.log(expected_pdf), log_pdf_values)
+      self.assertAllClose(expected_pdf, pdf_values)
+      self.assertAllClose(np.exp(expected_log_pdf), pdf_values)
+
+  def testStudentEntropy(self):
+    df_v = np.array([[2., 3., 7.]])  # 1x3
+    mu_v = np.array([[1., -1, 0]])  # 1x3
+    sigma_v = np.array([[1., 2., 3.]]).T  # transposed => 3x1
+    with tf.Session():
+      student = tf.contrib.distributions.StudentT(df=df_v,
+                                                  mu=mu_v,
+                                                  sigma=sigma_v)
+      ent = student.entropy()
+      ent_values = ent.eval()
+
+    # Help scipy broadcast to 3x3
+    ones = np.array([[1, 1, 1]])
+    sigma_bc = sigma_v * ones
+    mu_bc = ones.T * mu_v
+    df_bc = ones.T * df_v
+    expected_entropy = stats.t.entropy(
+        np.reshape(df_bc, [-1]),
+        loc=np.reshape(mu_bc, [-1]),
+        scale=np.reshape(sigma_bc, [-1]))
+    expected_entropy = np.reshape(expected_entropy, df_bc.shape)
+    self.assertAllClose(expected_entropy, ent_values)
+
+  def testStudentSample(self):
+    with tf.Session():
+      df = tf.constant(4.0)
+      mu = tf.constant(3.0)
+      sigma = tf.constant(math.sqrt(10.0))
+      df_v = 4.0
+      mu_v = 3.0
+      sigma_v = np.sqrt(10.0)
+      n = tf.constant(100000)
+      student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
+      samples = student.sample(n, seed=137)
+      sample_values = samples.eval()
+      n = 100000
+      self.assertEqual(sample_values.shape, (n,))
+      self.assertAllClose(sample_values.mean(), mu_v, atol=1e-2)
+      self.assertAllClose(sample_values.var(),
+                          sigma_v**2 * df_v / (df_v - 2),
+                          atol=.25)
+      self._checkKLApprox(df_v, mu_v, sigma_v, sample_values)
+
+  def testStudentSampleMultiDimensional(self):
+    with tf.Session():
+      batch_size = 7
+      df = tf.constant([[3.0, 7.0]] * batch_size)
+      mu = tf.constant([[3.0, -3.0]] * batch_size)
+      sigma = tf.constant([[math.sqrt(10.0), math.sqrt(15.0)]] * batch_size)
+      df_v = [3.0, 7.0]
+      mu_v = [3.0, -3.0]
+      sigma_v = [np.sqrt(10.0), np.sqrt(15.0)]
+      n = tf.constant(100000)
+      student = tf.contrib.distributions.StudentT(df=df, mu=mu, sigma=sigma)
+      samples = student.sample(n, seed=137)
+      sample_values = samples.eval()
+      self.assertEqual(samples.get_shape(), (100000, batch_size, 2))
+      self.assertAllClose(sample_values[:, 0, 0].mean(), mu_v[0], atol=.15)
+      self.assertAllClose(sample_values[:, 0, 0].var(),
+                          sigma_v[0]**2 * df_v[0] / (df_v[0] - 2),
+                          atol=1)
+      self._checkKLApprox(df_v[0], mu_v[0], sigma_v[0], sample_values[:, 0, 0])
+      self.assertAllClose(sample_values[:, 0, 1].mean(), mu_v[1], atol=.01)
+      self.assertAllClose(sample_values[:, 0, 1].var(),
+                          sigma_v[1]**2 * df_v[1] / (df_v[1] - 2),
+                          atol=.25)
+      self._checkKLApprox(df_v[1], mu_v[1], sigma_v[1], sample_values[:, 0, 1])
+
+  def _checkKLApprox(self, df, mu, sigma, samples):
+    n = samples.size
+    np.random.seed(137)
+    sample_scipy = stats.t.rvs(df, loc=mu, scale=sigma, size=n)
+    covg = 0.99
+    r = stats.t.interval(covg, df, loc=mu, scale=sigma)
+    bins = 100
+    hist, _ = np.histogram(samples, bins=bins, range=r)
+    hist_scipy, _ = np.histogram(sample_scipy, bins=bins, range=r)
+    self.assertGreater(hist.sum(), n * (covg - .01))
+    self.assertGreater(hist_scipy.sum(), n * (covg - .01))
+    hist_min1 = hist + 1.  # put at least one item in each bucket
+    hist_norm = hist_min1 / hist_min1.sum()
+    hist_scipy_min1 = hist_scipy + 1.  # put at least one item in each bucket
+    hist_scipy_norm = hist_scipy_min1 / hist_scipy_min1.sum()
+    kl_appx = np.sum(np.log(hist_scipy_norm / hist_norm) * hist_scipy_norm)
+    self.assertLess(kl_appx, 1)
+
+  def testBroadcastingParams(self):
+
+    def _check(student):
+      self.assertEqual(student.mean.get_shape(), (3,))
+      self.assertEqual(student.variance.get_shape(), (3,))
+      self.assertEqual(student.entropy().get_shape(), (3,))
+      self.assertEqual(student.log_pdf(2.).get_shape(), (3,))
+      self.assertEqual(student.pdf(2.).get_shape(), (3,))
+      self.assertEqual(student.sample(37).get_shape(), (37, 3,))
+
+    _check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
+    _check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
+    _check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
+
+  def testBroadcastingPdfArgs(self):
+
+    def _assert_shape(student, arg, shape):
+      self.assertEqual(student.log_pdf(arg).get_shape(), shape)
+      self.assertEqual(student.pdf(arg).get_shape(), shape)
+
+    def _check(student):
+      _assert_shape(student, 2., (3,))
+      xs = np.array([2., 3., 4.], dtype=np.float32)
+      _assert_shape(student, xs, (3,))
+      xs = np.array([xs])
+      _assert_shape(student, xs, (1, 3))
+      xs = xs.T
+      _assert_shape(student, xs, (3, 3))
+
+    _check(tf.contrib.distributions.StudentT(df=[2., 3., 4.,], mu=2., sigma=1.))
+    _check(tf.contrib.distributions.StudentT(df=7., mu=[2., 3., 4.,], sigma=1.))
+    _check(tf.contrib.distributions.StudentT(df=7., mu=3., sigma=[2., 3., 4.,]))
+
+    def _check2d(student):
+      _assert_shape(student, 2., (1, 3))
+      xs = np.array([2., 3., 4.], dtype=np.float32)
+      _assert_shape(student, xs, (1, 3))
+      xs = np.array([xs])
+      _assert_shape(student, xs, (1, 3))
+      xs = xs.T
+      _assert_shape(student, xs, (3, 3))
+
+    _check2d(tf.contrib.distributions.StudentT(
+        df=[[2., 3., 4.,]], mu=2., sigma=1.))
+    _check2d(tf.contrib.distributions.StudentT(
+        df=7., mu=[[2., 3., 4.,]], sigma=1.))
+    _check2d(tf.contrib.distributions.StudentT(
+        df=7., mu=3., sigma=[[2., 3., 4.,]]))
+
+    def _check2d_rows(student):
+      _assert_shape(student, 2., (3, 1))
+      xs = np.array([2., 3., 4.], dtype=np.float32)  # (3,)
+      _assert_shape(student, xs, (3, 3))
+      xs = np.array([xs])  # (1,3)
+      _assert_shape(student, xs, (3, 3))
+      xs = xs.T  # (3,1)
+      _assert_shape(student, xs, (3, 1))
+
+    _check2d_rows(tf.contrib.distributions.StudentT(
+        df=[[2.], [3.], [4.]], mu=2., sigma=1.))
+    _check2d_rows(tf.contrib.distributions.StudentT(
+        df=7., mu=[[2.], [3.], [4.]], sigma=1.))
+    _check2d_rows(tf.contrib.distributions.StudentT(
+        df=7., mu=3., sigma=[[2.], [3.], [4.]]))
+
+  def testMeanVar(self):
+    with tf.Session():
+      student = tf.contrib.distributions.StudentT(
+          df=[1., 2., 3., 5., 7.],
+          mu=np.exp(1, dtype=np.float32),
+          sigma=[5., 4., 3., 2., 1.])
+      # Test broadcast of mu across shape of df/sigma
+      mean = student.mean.eval()
+      self.assertAllClose([np.exp(1, dtype=np.float32)] * 5, mean)
+      var = student.variance.eval()
+      # loc does not affect variance, so we use 0.
+      self.assertAllClose([stats.t.var(1., loc=0., scale=5.),
+                           stats.t.var(2., loc=0., scale=4.),
+                           stats.t.var(3., loc=0., scale=3.),
+                           stats.t.var(5., loc=0., scale=2.),
+                           stats.t.var(7., loc=0., scale=1.)], var)
+
+  def testPdfOfSample(self):
+    with tf.Session() as sess:
+      student = tf.contrib.distributions.StudentT(df=3., mu=np.pi, sigma=1.)
+      num = 20000
+      samples = student.sample(num, seed=137)
+      pdfs = student.pdf(samples)
+      mean = student.mean
+      mean_pdf = student.pdf(mean)
+      sample_vals, pdf_vals, mean_val, mean_pdf_val = sess.run(
+          [samples, pdfs, mean, mean_pdf])
+      self.assertEqual(samples.get_shape(), (num,))
+      self.assertEqual(pdfs.get_shape(), (num,))
+      self.assertEqual(mean.get_shape(), ())
+      self.assertNear(np.pi, np.mean(sample_vals), err=0.02)
+      self.assertNear(np.pi, mean_val, err=1e-6)
+      self.assertNear(stats.t.pdf(np.pi, 3., loc=np.pi), mean_pdf_val, err=1e-6)
+      # Verify integral over sample*pdf ~= 1.
+      self._assertIntegral(sample_vals, pdf_vals)
+
+  def testPdfOfSampleMultiDims(self):
+    with tf.Session() as sess:
+      student = tf.contrib.distributions.StudentT(df=[7., 11.],
+                                                  mu=[[5.], [6.]],
+                                                  sigma=3.)
+      num = 50000
+      samples = student.sample(num, seed=137)
+      pdfs = student.pdf(samples)
+      sample_vals, pdf_vals = sess.run([samples, pdfs])
+      self.assertEqual(samples.get_shape(), (num, 2, 2))
+      self.assertEqual(pdfs.get_shape(), (num, 2, 2))
+      self.assertNear(5., np.mean(sample_vals[:, 0, :]), err=.03)
+      self.assertNear(6., np.mean(sample_vals[:, 1, :]), err=.03)
+      self.assertNear(stats.t.var(7., loc=0., scale=3.),  # loc d.n. affect var
+                      np.var(sample_vals[:, :, 0]),
+                      err=.25)
+      self.assertNear(stats.t.var(11., loc=0., scale=3.),  # loc d.n. affect var
+                      np.var(sample_vals[:, :, 1]),
+                      err=.25)
+      self._assertIntegral(sample_vals[:, 0, 0], pdf_vals[:, 0, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 0, 1], pdf_vals[:, 0, 1], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 0], pdf_vals[:, 1, 0], err=0.02)
+      self._assertIntegral(sample_vals[:, 1, 1], pdf_vals[:, 1, 1], err=0.02)
+
+  def _assertIntegral(self, sample_vals, pdf_vals, err=1e-3):
+    s_p = zip(sample_vals, pdf_vals)
+    prev = (sample_vals.min() - 1000, 0)
+    total = 0
+    for k in sorted(s_p, key=lambda x: x[0]):
+      pair_pdf = (k[1] + prev[1]) / 2
+      total += (k[0] - prev[0]) * pair_pdf
+      prev = k
+    self.assertNear(1., total, err=err)
+
+  def testNegativeDofFails(self):
+    with tf.Session():
+      student = tf.contrib.distributions.StudentT(df=[2, -5.],
+                                                  mu=0.,
+                                                  sigma=1.,
+                                                  name='S')
+      with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
+        student.mean.eval()
+
+  def testNegativeScaleFails(self):
+    with tf.Session():
+      student = tf.contrib.distributions.StudentT(df=[5.],
+                                                  mu=0.,
+                                                  sigma=[[3.], [-2.]],
+                                                  name='S')
+      with self.assertRaisesOpError(r'Condition x > 0 did not hold'):
+        student.mean.eval()
+
+
+if __name__ == '__main__':
+  tf.test.main()
diff --git a/tensorflow/contrib/distributions/python/ops/student_t.py b/tensorflow/contrib/distributions/python/ops/student_t.py
new file mode 100644
index 00000000000000..3e5d75b80108fb
--- /dev/null
+++ b/tensorflow/contrib/distributions/python/ops/student_t.py
@@ -0,0 +1,284 @@
+# Copyright 2016 Google Inc. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Student's t distribution class."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import numpy as np
+
+from tensorflow.contrib.distributions.python.ops.distribution import ContinuousDistribution  # pylint: disable=line-too-long
+from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util  # pylint: disable=line-too-long
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import tensor_shape
+from tensorflow.python.framework import tensor_util
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
+from tensorflow.python.ops import constant_op
+from tensorflow.python.ops import math_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import special_math_ops
+
+
+class StudentT(ContinuousDistribution):
+  """Student's t distribution with degree-of-freedom parameter df.
+
+  #### Mathematical details
+
+  The PDF of this distribution, writing `t = (x - mu) / sigma`, is:
+
+  `f(x) = gamma((df+1)/2)/sqrt(df*pi)/gamma(df/2)/sigma*(1+t^2/df)^(-(df+1)/2)`
+
+  #### Examples
+
+  Examples of initialization of one or a batch of distributions.
+
+  ```python
+  # Define a single scalar Student t distribution.
+  single_dist = tf.contrib.distributions.StudentT(df=3., mu=0., sigma=1.)
+
+  # Evaluate the pdf at 1, returning a scalar Tensor.
+  single_dist.pdf(1.)
+
+  # Define a batch of two scalar valued Student t's.
+  # The first has degrees of freedom 2, mean 1, and scale 11.
+  # The second 3, 2 and 22.
+  multi_dist = tf.contrib.distributions.StudentT(df=[2., 3.],
+                                                 mu=[1, 2.],
+                                                 sigma=[11, 22.])
+
+  # Evaluate the pdf of the first distribution on 0, and the second on 1.5,
+  # returning a length two tensor.
+  multi_dist.pdf([0, 1.5])
+
+  # Get 3 samples, returning a 3 x 2 tensor.
+  multi_dist.sample(3)
+  ```
+
+  Arguments are broadcast when possible.
+
+  ```python
+  # Define a batch of two Student's t distributions.
+  # Both have df 2 and mean 1, but different scales.
+  dist = tf.contrib.distributions.StudentT(df=2., mu=1., sigma=[11., 22.])
+
+  # Evaluate the pdf of both distributions on the same point, 3.0,
+  # returning a length 2 tensor.
+  dist.pdf(3.0)
+  ```
+  """
+
+  def __init__(self, df, mu, sigma, name="StudentT"):
+    """Construct Student's t distributions.
+
+    The distributions have degree of freedom `df`, mean `mu`, and scale `sigma`.
+
+    The parameters `df`, `mu`, and `sigma` must be shaped in a way that supports
+    broadcasting (e.g. `df + mu + sigma` is a valid operation).
+
+    Args:
+      df: `float` or `double` tensor, the degrees of freedom of the
+        distribution(s). `df` must contain only positive values.
+      mu: `float` or `double` tensor, the means of the distribution(s).
+      sigma: `float` or `double` tensor, the scaling factor for the
+        distribution(s). `sigma` must contain only positive values.
+        Note that `sigma` is not the standard deviation of this distribution.
+      name: The name to give Ops created by the initializer.
+
+    Raises:
+      TypeError: if `df`, `mu`, and `sigma` are not all the same float dtype.
+    """
+    super(StudentT, self).__init__()
+    with ops.op_scope([df, mu, sigma], name) as scope:
+      with ops.control_dependencies([check_ops.assert_positive(df),
+                                     check_ops.assert_positive(sigma)]):
+        self._df = array_ops.identity(df, name="df")  # new op keeps asserts
+        self._mu = ops.convert_to_tensor(mu, name="mu")
+        self._sigma = array_ops.identity(sigma, name="sigma")  # likewise
+        contrib_tensor_util.assert_same_float_dtype(
+            (self._df, self._mu, self._sigma))
+      self._name = scope
+      self._batch_shape = self._ones().get_shape()
+      self._event_shape = tensor_shape.TensorShape([])
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def dtype(self):
+    return self._df.dtype
+
+  @property
+  def df(self):
+    """Degrees of freedom in these Student's t distribution(s)."""
+    return self._df
+
+  @property
+  def mu(self):
+    """Locations of these Student's t distribution(s)."""
+    return self._mu
+
+  @property
+  def sigma(self):
+    """Scaling factors of these Student's t distribution(s)."""
+    return self._sigma
+
+  @property
+  def mean(self, name="mean"):
+    with ops.name_scope(self.name):
+      return math_ops.mul(self._mu, self._ones(), name=name)
+
+  @property
+  def variance(self, name="var"):
+    with ops.name_scope(self.name):
+      return math_ops.select(
+          (self._zeros() + self._df > 2),
+          self._zeros() + math_ops.square(self._sigma) * self._df /
+          (self._df - 2),
+          self._zeros() + np.inf,
+          name=name)
+
+  def batch_shape(self, name="batch_shape"):
+    with ops.name_scope(self.name):
+      return array_ops.shape(self._ones(), name=name)
+
+  def get_batch_shape(self):
+    return self._batch_shape
+
+  def event_shape(self, name="event_shape"):
+    with ops.name_scope(self.name):
+      return constant_op.constant([], dtype=np.int32, name=name)  # scalar
+
+  def get_event_shape(self):
+    return self._event_shape
+
+  def log_pdf(self, x, name="log_pdf"):
+    """Log pdf of observations in `x` under these Student's t-distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, broadcastable with `df`, `mu`, and `sigma`.
+      name: The name to give this op.
+
+    Returns:
+      log_pdf: tensor of dtype `dtype`, the log-PDFs of `x`.
+    """
+    with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x)
+        if x.dtype != self.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (x.dtype, self.dtype))
+        df_2 = self._df / 2
+        log_beta = (math.lgamma(0.5) + math_ops.lgamma(df_2) -
+                    math_ops.lgamma(0.5 + df_2))
+        return (-math_ops.log(self._df) / 2 - log_beta - (self._df + 1) / 2 *
+                math_ops.log(1 + math_ops.square((x - self._mu) / self._sigma) /
+                             self._df) - math_ops.log(self._sigma))
+
+  def pdf(self, x, name="pdf"):
+    """The PDF of observations in `x` under these Student's t distribution(s).
+
+    Args:
+      x: tensor of dtype `dtype`, must be broadcastable with `df`, `mu`, and
+        `sigma`.
+      name: The name to give this op.
+
+    Returns:
+      pdf: tensor of dtype `dtype`, the pdf values of `x`.
+    """
+    with ops.op_scope([self._df, self._mu, self._sigma, x], self.name):
+      with ops.name_scope(name):
+        x = ops.convert_to_tensor(x)
+        if x.dtype != self.dtype:
+          raise TypeError("Input x dtype does not match dtype: %s vs. %s" %
+                          (x.dtype, self.dtype))
+        reloc_scaled = (x - self._mu) / self._sigma
+        return (math_ops.exp(math_ops.lgamma((self._df + 1) / 2) -
+                             math_ops.lgamma(self._df / 2)) /
+                math_ops.sqrt(self._df) / math.sqrt(np.pi) *
+                math_ops.pow(1 + math_ops.square(reloc_scaled) / self._df,
+                             -(self._df + 1) / 2) / self._sigma)
+
+  def entropy(self, name="entropy"):
+    """The entropy of Student t distribution(s).
+
+    Args:
+      name: The name to give this op.
+
+    Returns:
+      entropy: tensor of dtype `dtype`, the entropy.
+    """
+    with ops.op_scope([self._df, self._mu, self._sigma], self.name):
+      with ops.name_scope(name):
+        u = array_ops.expand_dims(self._df + self._zeros(), -1)
+        v = array_ops.expand_dims(self._ones(), -1)
+        beta_arg = array_ops.concat(len(u.get_shape()) - 1, [u, v]) / 2
+        return ((self._df + 1) / 2 * (math_ops.digamma((self._df + 1) / 2) -
+                                      math_ops.digamma(self._df / 2)) +
+                math_ops.log(self._df) / 2 +
+                special_math_ops.lbeta(beta_arg) +
+                math_ops.log(self._sigma))
+
+  def sample(self, n, seed=None, name="sample"):
+    """Sample `n` observations from the Student t Distributions.
+
+    Args:
+      n: `Scalar`, type int32, the number of observations to sample.
+      seed: Python integer, the random seed.
+      name: The name to give this op.
+
+    Returns:
+      samples: a `Tensor` of shape `(n,) + self.batch_shape + self.event_shape`
+          with values of type `self.dtype`.
+    """
+    with ops.op_scope([self._df, self._mu, self._sigma, n], self.name):
+      with ops.name_scope(name):
+        n = ops.convert_to_tensor(n, name="n")
+        n_val = tensor_util.constant_value(n)
+
+        # We use 2 uniform random floats to generate polar random variates.
+        # http://dl.acm.org/citation.cfm?id=179631
+        # Theorem 2. Let G, H be iid variates, uniformly distributed on [0,1].
+        # Let theta = 2*pi*H, let R = sqrt(df*(G^(-2/df) - 1)) for df > 0.
+        # Let X = R*cos(theta), and let Y = R*sin(theta).
+        # Then X ~ t_df and Y ~ t_df.
+        # The variates X and Y are not independent.
+        shape = array_ops.concat(0, [array_ops.pack([2, n]),
+                                     self.batch_shape()])
+        uniform = random_ops.random_uniform(shape=shape,
+                                            dtype=self.dtype,
+                                            seed=seed)
+        samples_g, samples_h = array_ops.unpack(uniform, num=2)
+        theta = (2 * np.pi) * samples_h
+        r = math_ops.sqrt(self._df *
+                          (math_ops.pow(samples_g, -2 / self._df) - 1))
+        samples = r * math_ops.cos(theta)
+
+        # Provide some hints to shape inference
+        inferred_shape = tensor_shape.vector(n_val).concatenate(
+            self.get_batch_shape())
+        samples.set_shape(inferred_shape)
+
+        return samples * self._sigma + self._mu
+
+  def _ones(self):
+    return array_ops.ones_like(self._df + self._mu + self._sigma)
+
+  def _zeros(self):
+    return array_ops.zeros_like(self._df + self._mu + self._sigma)