Skip to content

Commit 2579841

Browse files
authored
FIX Using custom init for KMeans does a single init (scikit-learn#26657)
1 parent 4e88150 commit 2579841

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

doc/whats_new/v1.3.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,11 @@ Changelog
231231
:user:`Jérémie du Boisberranger <jeremiedbb>`,
232232
:user:`Guillaume Lemaitre <glemaitre>`.
233233

234+
- |Fix| :class:`cluster.KMeans`, :class:`cluster.MiniBatchKMeans` and
235+
:func:`cluster.k_means` now correctly handle the combination of `n_init="auto"`
236+
and an array-like `init`, running a single initialization in that case.
237+
:pr:`26657` by :user:`Binesh Bannerjee <bnsh>`.
238+
234239
- |API| The `sample_weight` parameter in `predict` for
235240
:meth:`cluster.KMeans.predict` and :meth:`cluster.MiniBatchKMeans.predict`
236241
is now deprecated and will be removed in v1.5.

sklearn/cluster/_kmeans.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -354,7 +354,8 @@ def k_means(
354354
n_init consecutive runs in terms of inertia.
355355
356356
When `n_init='auto'`, the number of runs depends on the value of init:
357-
10 if using `init='random'`, 1 if using `init='k-means++'`.
357+
10 if using `init='random'` or `init` is a callable;
358+
1 if using `init='k-means++'` or `init` is an array-like.
358359
359360
.. versionadded:: 1.2
360361
Added 'auto' option for `n_init`.
@@ -884,10 +885,14 @@ def _check_params_vs_input(self, X, default_n_init=None):
884885
)
885886
self._n_init = default_n_init
886887
if self._n_init == "auto":
887-
if self.init == "k-means++":
888+
if isinstance(self.init, str) and self.init == "k-means++":
888889
self._n_init = 1
889-
else:
890+
elif isinstance(self.init, str) and self.init == "random":
891+
self._n_init = default_n_init
892+
elif callable(self.init):
890893
self._n_init = default_n_init
894+
else: # array-like
895+
self._n_init = 1
891896

892897
if _is_arraylike_not_scalar(self.init) and self._n_init != 1:
893898
warnings.warn(
@@ -1241,7 +1246,8 @@ class KMeans(_BaseKMeans):
12411246
high-dimensional problems (see :ref:`kmeans_sparse_high_dim`).
12421247
12431248
When `n_init='auto'`, the number of runs depends on the value of init:
1244-
10 if using `init='random'`, 1 if using `init='k-means++'`.
1249+
10 if using `init='random'` or `init` is a callable;
1250+
1 if using `init='k-means++'` or `init` is an array-like.
12451251
12461252
.. versionadded:: 1.2
12471253
Added 'auto' option for `n_init`.
@@ -1777,7 +1783,8 @@ class MiniBatchKMeans(_BaseKMeans):
17771783
:ref:`kmeans_sparse_high_dim`).
17781784
17791785
When `n_init='auto'`, the number of runs depends on the value of init:
1780-
3 if using `init='random'`, 1 if using `init='k-means++'`.
1786+
3 if using `init='random'` or `init` is a callable;
1787+
1 if using `init='k-means++'` or `init` is an array-like.
17811788
17821789
.. versionadded:: 1.2
17831790
Added 'auto' option for `n_init`.

sklearn/cluster/tests/test_k_means.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,37 @@ def test_minibatch_kmeans_partial_fit_init(init):
348348
_check_fitted_model(km)
349349

350350

351+
@pytest.mark.parametrize(
352+
"init, expected_n_init",
353+
[
354+
("k-means++", 1),
355+
("random", "default"),
356+
(
357+
lambda X, n_clusters, random_state: random_state.uniform(
358+
size=(n_clusters, X.shape[1])
359+
),
360+
"default",
361+
),
362+
("array-like", 1),
363+
],
364+
)
365+
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
366+
def test_kmeans_init_auto_with_initial_centroids(Estimator, init, expected_n_init):
367+
"""Check that `n_init="auto"` chooses the right number of initializations.
368+
Non-regression test for #26657:
369+
https://github.com/scikit-learn/scikit-learn/pull/26657
370+
"""
371+
n_sample, n_features, n_clusters = 100, 10, 5
372+
X = np.random.randn(n_sample, n_features)
373+
if init == "array-like":
374+
init = np.random.randn(n_clusters, n_features)
375+
if expected_n_init == "default":
376+
expected_n_init = 3 if Estimator is MiniBatchKMeans else 10
377+
378+
kmeans = Estimator(n_clusters=n_clusters, init=init, n_init="auto").fit(X)
379+
assert kmeans._n_init == expected_n_init
380+
381+
351382
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
352383
def test_fortran_aligned_data(Estimator, global_random_seed):
353384
# Check that KMeans works with fortran-aligned data.

0 commit comments

Comments
 (0)