Skip to content

Commit

Permalink
DOC: cluster: Add 'See Also' and 'Examples' for kmeans2.
Browse files Browse the repository at this point in the history
  • Loading branch information
WarrenWeckesser committed Sep 1, 2019
1 parent 6cc0225 commit c31120a
Showing 1 changed file with 47 additions and 0 deletions.
47 changes: 47 additions & 0 deletions scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -651,11 +651,58 @@ def kmeans2(data, k, iter=10, thresh=1e-5, minit='random',
label[i] is the code or index of the centroid the
i'th observation is closest to.
See Also
--------
kmeans
References
----------
.. [1] D. Arthur and S. Vassilvitskii, "k-means++: the advantages of
careful seeding", Proceedings of the Eighteenth Annual ACM-SIAM Symposium
on Discrete Algorithms, 2007.
Examples
--------
>>> from scipy.cluster.vq import kmeans2
>>> import matplotlib.pyplot as plt
Create z, an array with shape (100, 2) containing a mixture of samples
from three multivariate normal distributions.
>>> np.random.seed(12345678)
>>> a = np.random.multivariate_normal([0, 6], [[2, 1], [1, 1.5]], size=45)
>>> b = np.random.multivariate_normal([2, 0], [[1, -1], [-1, 3]], size=30)
>>> c = np.random.multivariate_normal([6, 4], [[5, 0], [0, 1.2]], size=25)
>>> z = np.concatenate((a, b, c))
>>> np.random.shuffle(z)
Compute three clusters.
>>> centroid, label = kmeans2(z, 3, minit='points')
>>> centroid
array([[-0.35770296, 5.31342524],
[ 2.32210289, -0.50551972],
[ 6.17653859, 4.16719247]])
How many points are in each cluster?
>>> counts = np.bincount(label)
>>> counts
array([52, 27, 21])
Plot the clusters.
>>> w0 = z[label == 0]
>>> w1 = z[label == 1]
>>> w2 = z[label == 2]
>>> plt.plot(w0[:, 0], w0[:, 1], 'o', alpha=0.5, label='cluster 0')
>>> plt.plot(w1[:, 0], w1[:, 1], 'd', alpha=0.5, label='cluster 1')
>>> plt.plot(w2[:, 0], w2[:, 1], 's', alpha=0.5, label='cluster 2')
>>> plt.plot(centroid[:, 0], centroid[:, 1], 'k*', label='centroids')
>>> plt.axis('equal')
>>> plt.legend(shadow=True)
>>> plt.show()
"""
if int(iter) < 1:
raise ValueError("Invalid iter (%s), "
Expand Down

0 comments on commit c31120a

Please sign in to comment.