ENH Change default value of n_init in KMeans (scikit-learn#23038)
Co-authored-by: Muriel <[email protected]>
Co-authored-by: Thomas J. Fan <[email protected]>
Co-authored-by: Jérémie du Boisberranger <[email protected]>
4 people authored May 25, 2022
1 parent d6d1405 commit 0d923da
Showing 10 changed files with 119 additions and 25 deletions.
7 changes: 7 additions & 0 deletions doc/whats_new/v1.2.rst
@@ -44,6 +44,13 @@ Changelog
- |Enhancement| :class:`cluster.Birch` now preserves dtype for `numpy.float32`
inputs. :pr:`22968` by :user:`Meekail Zain <micky774>`.

- |Enhancement| :class:`cluster.KMeans` and :class:`cluster.MiniBatchKMeans`
now accept a new `'auto'` option for `n_init`, which reduces the number of
random initializations to one when `init='k-means++'` is used, for efficiency.
This starts a deprecation cycle for the current default values of `n_init` in
the two classes; both defaults will change to `n_init='auto'` in 1.4.
:pr:`23038` by :user:`Meekail Zain <micky774>`.
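A quick illustration of the new option (an illustrative snippet with toy data, not part of the changelog entry itself):

    from sklearn.cluster import KMeans
    import numpy as np

    X = np.array([[1.0, 2.0], [1.0, 4.0], [10.0, 2.0], [10.0, 4.0]])
    # With the default init='k-means++', n_init='auto' resolves to a single
    # initialization run; with init='random' it keeps the old default of 10.
    km = KMeans(n_clusters=2, init="k-means++", n_init="auto", random_state=0).fit(X)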

:mod:`sklearn.datasets`
.......................

83 changes: 65 additions & 18 deletions sklearn/cluster/_kmeans.py
@@ -272,7 +272,7 @@ def k_means(
*,
sample_weight=None,
init="k-means++",
n_init=10,
n_init="warn",
max_iter=300,
verbose=False,
tol=1e-4,
@@ -314,10 +314,19 @@ def k_means(
- If a callable is passed, it should take arguments `X`, `n_clusters` and a
random state and return an initialization.
n_init : int, default=10
n_init : 'auto' or int, default=10
Number of times the k-means algorithm will be run with different
centroid seeds. The final results will be the best output of
`n_init` consecutive runs in terms of inertia.
n_init consecutive runs in terms of inertia.
When `n_init='auto'`, the number of runs will be 10 if using
`init='random'`, and 1 if using `init='k-means++'`.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
.. versionchanged:: 1.4
Default value for `n_init` will change from 10 to `'auto'` in version 1.4.
max_iter : int, default=300
Maximum number of iterations of the k-means algorithm to run.
@@ -803,7 +812,10 @@ class _BaseKMeans(
_parameter_constraints = {
"n_clusters": [Interval(Integral, 1, None, closed="left")],
"init": [StrOptions({"k-means++", "random"}), callable, "array-like"],
"n_init": [Interval(Integral, 1, None, closed="left")],
"n_init": [
StrOptions({"auto", "warn"}),
Interval(Integral, 1, None, closed="left"),
],
"max_iter": [Interval(Integral, 1, None, closed="left")],
"tol": [Interval(Real, 0, None, closed="left")],
"verbose": [Interval(Integral, 0, None, closed="left"), bool],
@@ -829,7 +841,7 @@ def __init__(
self.verbose = verbose
self.random_state = random_state

def _check_params_vs_input(self, X):
def _check_params_vs_input(self, X, default_n_init=None):
# n_clusters
if X.shape[0] < self.n_clusters:
raise ValueError(
@@ -839,8 +851,23 @@ def _check_params_vs_input(self, X):
# tol
self._tol = _tolerance(X, self.tol)

# init
# n-init
# TODO(1.4): Remove
self._n_init = self.n_init
if self._n_init == "warn":
warnings.warn(
"The default value of `n_init` will change from "
f"{default_n_init} to 'auto' in 1.4. Set the value of `n_init`"
" explicitly to suppress the warning",
FutureWarning,
)
self._n_init = default_n_init
if self._n_init == "auto":
if self.init == "k-means++":
self._n_init = 1
else:
self._n_init = default_n_init

if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
warnings.warn(
"Explicit initial center position passed: performing only"
@@ -1150,11 +1177,20 @@ class KMeans(_BaseKMeans):
If a callable is passed, it should take arguments X, n_clusters and a
random state and return an initialization.
n_init : int, default=10
n_init : 'auto' or int, default=10
Number of times the k-means algorithm will be run with different
centroid seeds. The final results will be the best output of
n_init consecutive runs in terms of inertia.
When `n_init='auto'`, the number of runs will be 10 if using
`init='random'`, and 1 if using `init='k-means++'`.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
.. versionchanged:: 1.4
Default value for `n_init` will change from 10 to `'auto'` in version 1.4.
max_iter : int, default=300
Maximum number of iterations of the k-means algorithm for a
single run.
@@ -1263,7 +1299,7 @@ class KMeans(_BaseKMeans):
>>> import numpy as np
>>> X = np.array([[1, 2], [1, 4], [1, 0],
... [10, 2], [10, 4], [10, 0]])
>>> kmeans = KMeans(n_clusters=2, random_state=0).fit(X)
>>> kmeans = KMeans(n_clusters=2, random_state=0, n_init="auto").fit(X)
>>> kmeans.labels_
array([1, 1, 1, 0, 0, 0], dtype=int32)
>>> kmeans.predict([[0, 0], [12, 3]])
@@ -1286,7 +1322,7 @@ def __init__(
n_clusters=8,
*,
init="k-means++",
n_init=10,
n_init="warn",
max_iter=300,
tol=1e-4,
verbose=0,
@@ -1308,7 +1344,7 @@ def __init__(
self.algorithm = algorithm

def _check_params_vs_input(self, X):
super()._check_params_vs_input(X)
super()._check_params_vs_input(X, default_n_init=10)

self._algorithm = self.algorithm
if self._algorithm in ("auto", "full"):
@@ -1667,11 +1703,20 @@ class MiniBatchKMeans(_BaseKMeans):
If `None`, the heuristic is `init_size = 3 * batch_size` if
`3 * batch_size < n_clusters`, else `init_size = 3 * n_clusters`.
n_init : int, default=3
n_init : 'auto' or int, default=3
Number of random initializations that are tried.
In contrast to KMeans, the algorithm is only run once, using the
best of the ``n_init`` initializations as measured by inertia.
When `n_init='auto'`, the number of runs will be 3 if using
`init='random'`, and 1 if using `init='k-means++'`.
.. versionadded:: 1.2
Added 'auto' option for `n_init`.
.. versionchanged:: 1.4
Default value for `n_init` will change from 3 to `'auto'` in version 1.4.
reassignment_ratio : float, default=0.01
Control the fraction of the maximum number of counts for a center to
be reassigned. A higher value means that low count centers are more
@@ -1737,7 +1782,8 @@ class MiniBatchKMeans(_BaseKMeans):
>>> # manually fit on batches
>>> kmeans = MiniBatchKMeans(n_clusters=2,
... random_state=0,
... batch_size=6)
... batch_size=6,
... n_init="auto")
>>> kmeans = kmeans.partial_fit(X[0:6,:])
>>> kmeans = kmeans.partial_fit(X[6:12,:])
>>> kmeans.cluster_centers_
@@ -1749,12 +1795,13 @@ class MiniBatchKMeans(_BaseKMeans):
>>> kmeans = MiniBatchKMeans(n_clusters=2,
... random_state=0,
... batch_size=6,
... max_iter=10).fit(X)
... max_iter=10,
... n_init="auto").fit(X)
>>> kmeans.cluster_centers_
-array([[1.19..., 1.22...],
-       [4.03..., 2.46...]])
+array([[3.97727273, 2.43181818],
+       [1.125     , 1.6       ]])
>>> kmeans.predict([[0, 0], [4, 4]])
-array([0, 1], dtype=int32)
+array([1, 0], dtype=int32)
"""

_parameter_constraints = {
@@ -1779,7 +1826,7 @@ def __init__(
tol=0.0,
max_no_improvement=10,
init_size=None,
n_init=3,
n_init="warn",
reassignment_ratio=0.01,
):

@@ -1800,7 +1847,7 @@ def __init__(
self.reassignment_ratio = reassignment_ratio

def _check_params_vs_input(self, X):
super()._check_params_vs_input(X)
super()._check_params_vs_input(X, default_n_init=3)

self._batch_size = min(self.batch_size, X.shape[0])

36 changes: 36 additions & 0 deletions sklearn/cluster/tests/test_k_means.py
@@ -1,6 +1,7 @@
"""Testing for K-means"""
import re
import sys
import warnings

import numpy as np
from scipy import sparse as sp
@@ -31,6 +32,12 @@
from sklearn.datasets import make_blobs
from io import StringIO

# TODO(1.4): Remove
msg = (
r"The default value of `n_init` will change from \d* to 'auto' in 1.4. Set the"
r" value of `n_init` explicitly to suppress the warning:FutureWarning"
)
pytestmark = pytest.mark.filterwarnings("ignore:" + msg)
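The `"ignore:" + msg` string above follows pytest's `action:message:category` filter syntax, where the message part is a regular expression matched against the start of the warning text. An equivalent per-test form would look roughly like this (a sketch; the test name is hypothetical):

    @pytest.mark.filterwarnings(
        r"ignore:The default value of `n_init` will change:FutureWarning"
    )
    def test_kmeans_without_n_init_warning():
        # The FutureWarning raised by the "warn" default is filtered out here.
        KMeans(n_clusters=2).fit(X)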

# non centered, sparse centers to check the
centers = np.array(
@@ -1029,6 +1036,35 @@ def test_inertia(dtype):
assert_allclose(inertia_sparse, expected, rtol=1e-6)


# TODO(1.4): Remove
@pytest.mark.parametrize("Klass, default_n_init", [(KMeans, 10), (MiniBatchKMeans, 3)])
def test_change_n_init_future_warning(Klass, default_n_init):
est = Klass(n_init=1)
with warnings.catch_warnings():
warnings.simplefilter("error", FutureWarning)
est.fit(X)

msg = (
f"The default value of `n_init` will change from {default_n_init} to 'auto'"
" in 1.4"
)
est = Klass()
with pytest.warns(FutureWarning, match=msg):
est.fit(X)


@pytest.mark.parametrize("Klass, default_n_init", [(KMeans, 10), (MiniBatchKMeans, 3)])
def test_n_init_auto(Klass, default_n_init):
est = Klass(n_init="auto", init="k-means++")
est.fit(X)
assert est._n_init == 1

est = Klass(n_init="auto", init="random")
est.fit(X)
assert est._n_init == default_n_init


@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
def test_sample_weight_unchanged(Estimator):
# Check that sample_weight is not modified in place by KMeans (#17204)
2 changes: 1 addition & 1 deletion sklearn/inspection/tests/test_partial_dependence.py
@@ -427,7 +427,7 @@ def fit(self, X, y):
"estimator, params, err_msg",
[
(
KMeans(),
KMeans(random_state=0, n_init="auto"),
{"features": [0]},
"'estimator' must be a fitted regressor or classifier",
),
2 changes: 1 addition & 1 deletion sklearn/manifold/tests/test_spectral_embedding.py
@@ -329,7 +329,7 @@ def test_pipeline_spectral_clustering(seed=36):
random_state=random_state,
)
for se in [se_rbf, se_knn]:
km = KMeans(n_clusters=n_clusters, random_state=random_state)
km = KMeans(n_clusters=n_clusters, random_state=random_state, n_init="auto")
km.fit(se.fit_transform(S))
assert_array_almost_equal(
normalized_mutual_info_score(km.labels_, true_labels), 1.0, 2
2 changes: 1 addition & 1 deletion sklearn/metrics/tests/test_score_objects.py
@@ -577,7 +577,7 @@ def test_supervised_cluster_scorers():
# Test clustering scorers against gold standard labeling.
X, y = make_blobs(random_state=0, centers=2)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
km = KMeans(n_clusters=3)
km = KMeans(n_clusters=3, n_init="auto")
km.fit(X_train)
for name in CLUSTER_SCORERS:
score1 = get_scorer(name)(km, X_test, y_test)
2 changes: 1 addition & 1 deletion sklearn/model_selection/tests/test_validation.py
@@ -930,7 +930,7 @@ def test_cross_val_predict():
preds = cross_val_predict(est, Xsp, y)
assert_array_almost_equal(len(preds), len(y))

preds = cross_val_predict(KMeans(), X)
preds = cross_val_predict(KMeans(n_init="auto"), X)
assert len(preds) == len(y)

class BadCV:
2 changes: 1 addition & 1 deletion sklearn/tests/test_discriminant_analysis.py
@@ -143,7 +143,7 @@ def test_lda_predict():

# test bad covariance estimator
clf = LinearDiscriminantAnalysis(
solver="lsqr", covariance_estimator=KMeans(n_clusters=2)
solver="lsqr", covariance_estimator=KMeans(n_clusters=2, n_init="auto")
)
with pytest.raises(
ValueError, match="KMeans does not have a covariance_ attribute"
4 changes: 4 additions & 0 deletions sklearn/tests/test_docstring_parameters.py
@@ -262,6 +262,10 @@ def test_fit_docstring_attributes(name, Estimator):
if Estimator.__name__ == "MiniBatchDictionaryLearning":
est.set_params(batch_size=5)

# TODO(1.4): remove once the `n_init` default becomes 'auto' (avoids FutureWarning)
if Estimator.__name__ in ("KMeans", "MiniBatchKMeans"):
est.set_params(n_init="auto")

# In case we want to deprecate some attributes in the future
skipped_attributes = {}

4 changes: 2 additions & 2 deletions sklearn/tests/test_pipeline.py
@@ -422,11 +422,11 @@ def test_fit_predict_on_pipeline():
# test that the fit_predict on pipeline yields same results as applying
# transform and clustering steps separately
scaler = StandardScaler()
km = KMeans(random_state=0)
km = KMeans(random_state=0, n_init="auto")
# As pipeline doesn't clone estimators on construction,
# it must have its own estimators
scaler_for_pipeline = StandardScaler()
km_for_pipeline = KMeans(random_state=0)
km_for_pipeline = KMeans(random_state=0, n_init="auto")

# first compute the transform and clustering step separately
scaled = scaler.fit_transform(iris.data)
