Merge pull request scipy#9037 from vyasr/kmeans
ENH: add a new init method for k-means
tylerjereddy authored Nov 6, 2018
2 parents f0bb86c + f56452b commit f79b239
Showing 2 changed files with 87 additions and 20 deletions.
26 changes: 16 additions & 10 deletions scipy/cluster/tests/test_vq.py
@@ -6,13 +6,13 @@

import numpy as np
from numpy.testing import (assert_array_equal, assert_array_almost_equal,
                           assert_allclose, assert_equal, assert_)
from scipy._lib._numpy_compat import suppress_warnings
import pytest
from pytest import raises as assert_raises

from scipy.cluster.vq import (kmeans, kmeans2, py_vq, vq, whiten,
                              ClusterError, _krandinit)
from scipy.cluster import _vq


@@ -58,8 +58,8 @@

# Global data
X = np.array([[3.0, 3], [4, 3], [4, 2],
              [9, 2], [5, 1], [6, 2], [9, 4],
              [5, 2], [5, 4], [7, 4], [6, 5]])

CODET1 = np.array([[3.0000, 3.0000],
[6.2000, 4.0000],
@@ -201,7 +201,7 @@ def test_kmeans_lost_cluster(self):
data = TESTDATA_2D
initk = np.array([[-1.8127404, -0.67128041],
                  [2.04621601, 0.07401111],
-                 [-2.31149087,-0.05160469]])
+                 [-2.31149087, -0.05160469]])

kmeans(data, initk)
with suppress_warnings() as sup:
@@ -250,10 +250,17 @@ def test_kmeans2_init(self):
kmeans2(data, 3, minit='points')
kmeans2(data[:, :1], 3, minit='points')  # special case (1-D)

-kmeans2(data, 3, minit='random')
-kmeans2(data[:, :1], 3, minit='random')  # special case (1-D)
+kmeans2(data, 3, minit='++')
+kmeans2(data[:, :1], 3, minit='++')  # special case (1-D)
+
+# minit='random' can give warnings, filter those
+with suppress_warnings() as sup:
+    sup.filter(message="One of the clusters is empty. Re-run")
+    kmeans2(data, 3, minit='random')
+    kmeans2(data[:, :1], 3, minit='random')  # special case (1-D)

-@pytest.mark.skipif(sys.platform == 'win32', reason='Fails with MemoryError in Wine.')
+@pytest.mark.skipif(sys.platform == 'win32',
+                    reason='Fails with MemoryError in Wine.')
def test_krandinit(self):
data = TESTDATA_2D
datas = [data.reshape((200, 2)), data.reshape((20, 20))[:10]]
@@ -277,8 +284,7 @@ def test_kmeans_0k(self):

def test_kmeans_large_thres(self):
# Regression test for gh-1774
-x = np.array([1,2,3,4,10], dtype=float)
+x = np.array([1, 2, 3, 4, 10], dtype=float)
res = kmeans(x, 1, thresh=1e16)
assert_allclose(res[0], np.array([4.]))
assert_allclose(res[1], 2.3999999999999999)

81 changes: 71 additions & 10 deletions scipy/cluster/vq.py
@@ -28,10 +28,10 @@
any other centroids. If v belongs to i, we say centroid i is the
dominating centroid of v. The k-means algorithm tries to
minimize distortion, which is defined as the sum of the squared distances
between each observation vector and its dominating centroid.
The minimization is achieved by iteratively reclassifying
the observations into clusters and recalculating the centroids until
a configuration is reached in which the centroids are stable. One can
also define a maximum number of iterations.
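
To make the definition above concrete, here is a minimal NumPy sketch (the numbers are invented for illustration and appear nowhere in this commit):

    import numpy as np

    obs = np.array([[1.0, 1.0], [1.5, 2.0], [9.0, 9.0]])
    centroids = np.array([[1.25, 1.5], [9.0, 9.0]])
    # distance from every observation to every centroid
    dists = np.linalg.norm(obs[:, None, :] - centroids[None, :, :], axis=2)
    nearest = dists.min(axis=1)                # distance to the dominating centroid
    distortion_classic = (nearest ** 2).sum()  # sum of squared distances, as defined above
    distortion_scipy = nearest.mean()          # the mean non-squared variant kmeans() reports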
Since vector quantization is a natural application for k-means,
@@ -323,10 +323,10 @@ def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True):
The k-means algorithm adjusts the classification of the observations
into clusters and updates the cluster centroids until the position of
the centroids is stable over successive iterations. In this
implementation of the algorithm, the stability of the centroids is
determined by comparing the absolute value of the change in the average
Euclidean distance between the observations and their corresponding
centroids against a threshold. This yields
a code book mapping centroids to codes and vice versa.
@@ -373,9 +373,9 @@ def kmeans(obs, k_or_guess, iter=20, thresh=1e-5, check_finite=True):
not necessarily the globally minimal distortion.
distortion : float
The mean (non-squared) Euclidean distance between the observations
passed and the centroids generated. Note the difference to the standard
definition of distortion in the context of the K-means algorithm, which
is the sum of the squared distances.
See Also
@@ -474,6 +474,11 @@ def _kpoints(data, k):
k : int
Number of samples to generate.
Returns
-------
x : ndarray
A 'k' by 'N' array containing the initial centroids
"""
idx = np.random.choice(data.shape[0], size=k, replace=False)
return data[idx]
@@ -494,6 +499,11 @@ def _krandinit(data, k):
k : int
Number of samples to generate.
Returns
-------
x : ndarray
A 'k' by 'N' array containing the initial centroids
"""
mu = data.mean(axis=0)

@@ -519,7 +529,50 @@ def _krandinit(data, k):
return x
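
The elided body above completes this by estimating a covariance and sampling; a minimal sketch of the idea, assuming only NumPy (the helper name and details are ours, not the code in this file):

    import numpy as np

    def gaussian_init_sketch(data, k):
        # estimate first and second moments of the observations
        mu = np.atleast_1d(data.mean(axis=0))
        cov = np.atleast_2d(np.cov(data, rowvar=False))
        # draw k pseudo-observations from N(mu, cov)
        return np.random.multivariate_normal(mu, cov, size=k)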


-_valid_init_meth = {'random': _krandinit, 'points': _kpoints}
def _kpp(data, k):
""" Picks k points in data based on the kmeans++ method
Parameters
----------
data : ndarray
Expect a rank 1 or 2 array. Rank 1 arrays are assumed to describe
one-dimensional data, rank 2 multidimensional data, in which case one
row is one observation.
k : int
Number of samples to generate.
Returns
-------
init : ndarray
A 'k' by 'N' array containing the initial centroids
References
----------
.. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
on Discrete Algorithms, 2007.
"""

dims = data.shape[1] if len(data.shape) > 1 else 1
init = np.ndarray((k, dims))

for i in range(k):
    if i == 0:
        # the first centroid is a uniformly random observation; indexing
        # with data.shape[0] (not dims) draws it from all rows of the data
        init[i, :] = data[np.random.randint(data.shape[0])]
    else:
        # squared distance from each observation to its nearest centroid so far
        D2 = np.array([min(
            [np.inner(init[j]-x, init[j]-x) for j in range(i)]
        ) for x in data])
        # sample the next centroid with probability proportional to D2
        probs = D2/D2.sum()
        cumprobs = probs.cumsum()
        r = np.random.rand()
        init[i, :] = data[np.searchsorted(cumprobs, r)]

return init
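
A side note on the D2 step above: it recomputes every observation-to-centroid distance in Python-level loops for each new centroid. An equivalent vectorized formulation, assuming scipy.spatial.distance.cdist (illustrative only; this is not the code the commit adds):

    import numpy as np
    from scipy.spatial.distance import cdist

    def kpp_vectorized_sketch(data, k):
        data = data.reshape(len(data), -1)  # treat 1-D input as a single column
        init = np.empty((k, data.shape[1]))
        init[0] = data[np.random.randint(data.shape[0])]
        for i in range(1, k):
            # squared distance from each observation to its nearest chosen centroid
            D2 = cdist(data, init[:i], metric='sqeuclidean').min(axis=1)
            # pick the next centroid with probability proportional to D2
            r = np.random.rand()
            init[i] = data[np.searchsorted((D2 / D2.sum()).cumsum(), r)]
        return init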


+_valid_init_meth = {'random': _krandinit, 'points': _kpoints, '++': _kpp}
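
kmeans2 resolves the minit string through this table. The call site is outside the hunks shown here, so the following is a reconstruction from context rather than a quote of vq.py:

    # assumed shape of the dispatch inside kmeans2
    try:
        init_meth = _valid_init_meth[minit]
    except KeyError:
        raise ValueError("Unknown init method %r" % minit)
    code_book = init_meth(data, k)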


def _missing_warn():
Expand Down Expand Up @@ -564,14 +617,17 @@ def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
(not used yet)
minit : str, optional
Method for initialization. Available methods are 'random',
-'points', and 'matrix':
+'points', '++' and 'matrix':
'random': generate k centroids from a Gaussian with mean and
variance estimated from the data.
'points': choose k observations (rows) at random from data for
the initial centroids.
'++': choose k observations according to the kmeans++ method
(careful seeding)
'matrix': interpret the k parameter as a k by M (or length k
array for one-dimensional data) array of initial centroids.
missing : str, optional
@@ -596,6 +652,11 @@
label[i] is the code or index of the centroid the
i'th observation is closest to.
References
----------
.. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
on Discrete Algorithms, 2007.
"""
if int(iter) < 1:
raise ValueError("Invalid iter (%s), "
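A small usage sketch of the new option (the data are invented for illustration):

    import numpy as np
    from scipy.cluster.vq import kmeans2

    pts = np.vstack([np.random.randn(100, 2),
                     np.random.randn(100, 2) + 5.0])
    # seed three centroids with the kmeans++ method, then run k-means
    centroid, label = kmeans2(pts, 3, minit='++')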
