Skip to content

Commit

Permalink
FIX numerical stability in GMM with eigh sampling
Browse files Browse the repository at this point in the history
  • Loading branch information
larsmans committed May 15, 2014
1 parent 96e9275 commit c6b7e3b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 2 deletions.
6 changes: 4 additions & 2 deletions sklearn/mixture/gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,10 @@ def sample_gaussian(mean, covar, covariance_type='diag', n_samples=1,
rand = np.dot(np.diag(np.sqrt(covar)), rand)
else:
s, U = linalg.eigh(covar)
sqrt_covar = U * np.sqrt(s)
rand = np.dot(sqrt_covar, rand)
np.abs(s, out=s) # get rid of tiny negatives
np.sqrt(s, out=s)
U *= s
rand = np.dot(U, rand)

return (rand.T + mean).T

Expand Down
8 changes: 8 additions & 0 deletions sklearn/mixture/tests/test_gmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,14 @@ def test_sample_gaussian():
assert_true(np.allclose(samples.mean(axis), mu, atol=1.3))
assert_true(np.allclose(np.cov(samples), cv, atol=2.5))

# Numerical stability check: in SciPy 0.12.0 at least, eigh may return
# tiny negative eigenvalues (its first return value).
from sklearn.mixture import sample_gaussian
x = sample_gaussian([0, 0], [[4, 3], [1, .1]],
covariance_type='full', random_state=42)
print(x)
assert_true(np.isfinite(x).all())


def _naive_lmvnpdf_diag(X, mu, cv):
# slow and naive implementation of lmvnpdf
Expand Down

0 comments on commit c6b7e3b

Please sign in to comment.