Skip to content

Commit 335c2d2

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/cluster/tests/test_k_means.py (scikit-learn#27179)
1 parent 9f6592f commit 335c2d2

File tree

3 files changed: +87 -62 lines changed

doc/whats_new/v1.4.rst

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -88,6 +88,7 @@ and classes are impacted:
8888

8989
- :func:`cluster.compute_optics_graph` in :pr:`27250` by
9090
:user:`Yao Xiao <Charlie-XIAO>`;
91+
- :func:`cluster.kmeans_plusplus` in :pr:`27179` by :user:`Nurseit Kamchyev <Bncer>`;
9192
- :func:`decomposition.non_negative_factorization` in :pr:`27100` by
9293
:user:`Isaac Virshup <ivirshup>`;
9394
- :func:`manifold.trustworthiness` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
@@ -104,6 +105,8 @@ and classes are impacted:
104105
**Classes:**
105106

106107
- :class:`cluster.HDBSCAN` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
108+
- :class:`cluster.KMeans` in :pr:`27179` by :user:`Nurseit Kamchyev <Bncer>`;
109+
- :class:`cluster.MiniBatchKMeans` in :pr:`27179` by :user:`Nurseit Kamchyev <Bncer>`;
107110
- :class:`cluster.OPTICS` in :pr:`27250` by :user:`Yao Xiao <Charlie-XIAO>`;
108111
- :class:`decomposition.NMF` in :pr:`27100` by :user:`Isaac Virshup <ivirshup>`;
109112
- :class:`decomposition.MiniBatchNMF` in :pr:`27100` by

sklearn/cluster/_kmeans.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -229,7 +229,7 @@ def _kmeans_plusplus(
229229
center_id = random_state.choice(n_samples, p=sample_weight / sample_weight.sum())
230230
indices = np.full(n_clusters, -1, dtype=int)
231231
if sp.issparse(X):
232-
centers[0] = X[center_id].toarray()
232+
centers[0] = X[[center_id]].toarray()
233233
else:
234234
centers[0] = X[center_id]
235235
indices[0] = center_id
@@ -268,7 +268,7 @@ def _kmeans_plusplus(
268268

269269
# Permanently add best center candidate found in local tries
270270
if sp.issparse(X):
271-
centers[c] = X[best_candidate].toarray()
271+
centers[c] = X[[best_candidate]].toarray()
272272
else:
273273
centers[c] = X[best_candidate]
274274
indices[c] = best_candidate

sklearn/cluster/tests/test_k_means.py

Lines changed: 82 additions & 60 deletions
Original file line number | Diff line number | Diff line change
@@ -31,7 +31,7 @@
3131
create_memmap_backed_data,
3232
)
3333
from sklearn.utils.extmath import row_norms
34-
from sklearn.utils.fixes import threadpool_limits
34+
from sklearn.utils.fixes import CSR_CONTAINERS, threadpool_limits
3535

3636
# TODO(1.4): Remove
3737
msg = (
@@ -53,12 +53,16 @@
5353
X, true_labels = make_blobs(
5454
n_samples=n_samples, centers=centers, cluster_std=1.0, random_state=42
5555
)
56-
X_csr = sp.csr_matrix(X)
56+
X_as_any_csr = [container(X) for container in CSR_CONTAINERS]
57+
data_containers = [np.array] + CSR_CONTAINERS
58+
data_containers_ids = (
59+
["dense", "sparse_matrix", "sparse_array"]
60+
if len(X_as_any_csr) == 2
61+
else ["dense", "sparse_matrix"]
62+
)
5763

5864

59-
@pytest.mark.parametrize(
60-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
61-
)
65+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
6266
@pytest.mark.parametrize("algo", ["lloyd", "elkan"])
6367
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
6468
def test_kmeans_results(array_constr, algo, dtype):
@@ -82,9 +86,7 @@ def test_kmeans_results(array_constr, algo, dtype):
8286
assert kmeans.n_iter_ == expected_n_iter
8387

8488

85-
@pytest.mark.parametrize(
86-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
87-
)
89+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
8890
@pytest.mark.parametrize("algo", ["lloyd", "elkan"])
8991
def test_kmeans_relocated_clusters(array_constr, algo):
9092
# check that empty clusters are relocated as expected
@@ -115,9 +117,7 @@ def test_kmeans_relocated_clusters(array_constr, algo):
115117
assert_allclose(kmeans.cluster_centers_, expected_centers)
116118

117119

118-
@pytest.mark.parametrize(
119-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
120-
)
120+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
121121
def test_relocate_empty_clusters(array_constr):
122122
# test for the _relocate_empty_clusters_(dense/sparse) helpers
123123

@@ -160,9 +160,7 @@ def test_relocate_empty_clusters(array_constr):
160160

161161

162162
@pytest.mark.parametrize("distribution", ["normal", "blobs"])
163-
@pytest.mark.parametrize(
164-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
165-
)
163+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
166164
@pytest.mark.parametrize("tol", [1e-2, 1e-8, 1e-100, 0])
167165
def test_kmeans_elkan_results(distribution, array_constr, tol, global_random_seed):
168166
# Check that results are identical between lloyd and elkan algorithms
@@ -238,7 +236,8 @@ def test_predict_sample_weight_deprecation_warning(Estimator):
238236
kmeans.predict(X, sample_weight=sample_weight)
239237

240238

241-
def test_minibatch_update_consistency(global_random_seed):
239+
@pytest.mark.parametrize("X_csr", X_as_any_csr)
240+
def test_minibatch_update_consistency(X_csr, global_random_seed):
242241
# Check that dense and sparse minibatch update give the same results
243242
rng = np.random.RandomState(global_random_seed)
244243

@@ -315,19 +314,23 @@ def _check_fitted_model(km):
315314
assert km.inertia_ > 0.0
316315

317316

318-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
317+
@pytest.mark.parametrize(
318+
"input_data",
319+
[X] + X_as_any_csr,
320+
ids=data_containers_ids,
321+
)
319322
@pytest.mark.parametrize(
320323
"init",
321324
["random", "k-means++", centers, lambda X, k, random_state: centers],
322325
ids=["random", "k-means++", "ndarray", "callable"],
323326
)
324327
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
325-
def test_all_init(Estimator, data, init):
328+
def test_all_init(Estimator, input_data, init):
326329
# Check KMeans and MiniBatchKMeans with all possible init.
327330
n_init = 10 if isinstance(init, str) else 1
328331
km = Estimator(
329332
init=init, n_clusters=n_clusters, random_state=42, n_init=n_init
330-
).fit(data)
333+
).fit(input_data)
331334
_check_fitted_model(km)
332335

333336

@@ -485,8 +488,12 @@ def test_minibatch_sensible_reassign(global_random_seed):
485488
assert km.cluster_centers_.any(axis=1).sum() > 10
486489

487490

488-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
489-
def test_minibatch_reassign(data, global_random_seed):
491+
@pytest.mark.parametrize(
492+
"input_data",
493+
[X] + X_as_any_csr,
494+
ids=data_containers_ids,
495+
)
496+
def test_minibatch_reassign(input_data, global_random_seed):
490497
# Check the reassignment part of the minibatch step with very high or very
491498
# low reassignment ratio.
492499
perfect_centers = np.empty((n_clusters, n_features))
@@ -499,10 +506,10 @@ def test_minibatch_reassign(data, global_random_seed):
499506
# Give a perfect initialization, but a large reassignment_ratio, as a
500507
# result many centers should be reassigned and the model should no longer
501508
# be good
502-
score_before = -_labels_inertia(data, sample_weight, perfect_centers, 1)[1]
509+
score_before = -_labels_inertia(input_data, sample_weight, perfect_centers, 1)[1]
503510

504511
_mini_batch_step(
505-
data,
512+
input_data,
506513
sample_weight,
507514
perfect_centers,
508515
centers_new,
@@ -512,14 +519,14 @@ def test_minibatch_reassign(data, global_random_seed):
512519
reassignment_ratio=1,
513520
)
514521

515-
score_after = -_labels_inertia(data, sample_weight, centers_new, 1)[1]
522+
score_after = -_labels_inertia(input_data, sample_weight, centers_new, 1)[1]
516523

517524
assert score_before > score_after
518525

519526
# Give a perfect initialization, with a small reassignment_ratio,
520527
# no center should be reassigned.
521528
_mini_batch_step(
522-
data,
529+
input_data,
523530
sample_weight,
524531
perfect_centers,
525532
centers_new,
@@ -641,9 +648,7 @@ def test_score_max_iter(Estimator, global_random_seed):
641648
assert s2 > s1
642649

643650

644-
@pytest.mark.parametrize(
645-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
646-
)
651+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
647652
@pytest.mark.parametrize(
648653
"Estimator, algorithm",
649654
[(KMeans, "lloyd"), (KMeans, "elkan"), (MiniBatchKMeans, None)],
@@ -684,8 +689,9 @@ def test_kmeans_predict(
684689
assert_array_equal(pred, np.arange(10))
685690

686691

692+
@pytest.mark.parametrize("X_csr", X_as_any_csr)
687693
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
688-
def test_dense_sparse(Estimator, global_random_seed):
694+
def test_dense_sparse(Estimator, X_csr, global_random_seed):
689695
# Check that the results are the same for dense and sparse input.
690696
sample_weight = np.random.RandomState(global_random_seed).random_sample(
691697
(n_samples,)
@@ -703,11 +709,12 @@ def test_dense_sparse(Estimator, global_random_seed):
703709
assert_allclose(km_dense.cluster_centers_, km_sparse.cluster_centers_)
704710

705711

712+
@pytest.mark.parametrize("X_csr", X_as_any_csr)
706713
@pytest.mark.parametrize(
707714
"init", ["random", "k-means++", centers], ids=["random", "k-means++", "ndarray"]
708715
)
709716
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
710-
def test_predict_dense_sparse(Estimator, init):
717+
def test_predict_dense_sparse(Estimator, init, X_csr):
711718
# check that models trained on sparse input also works for dense input at
712719
# predict time and vice versa.
713720
n_init = 10 if isinstance(init, str) else 1
@@ -720,9 +727,7 @@ def test_predict_dense_sparse(Estimator, init):
720727
assert_array_equal(km.predict(X_csr), km.labels_)
721728

722729

723-
@pytest.mark.parametrize(
724-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
725-
)
730+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
726731
@pytest.mark.parametrize("dtype", [np.int32, np.int64])
727732
@pytest.mark.parametrize("init", ["k-means++", "ndarray"])
728733
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
@@ -810,9 +815,13 @@ def test_k_means_function(global_random_seed):
810815
assert inertia > 0.0
811816

812817

813-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
818+
@pytest.mark.parametrize(
819+
"input_data",
820+
[X] + X_as_any_csr,
821+
ids=data_containers_ids,
822+
)
814823
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
815-
def test_float_precision(Estimator, data, global_random_seed):
824+
def test_float_precision(Estimator, input_data, global_random_seed):
816825
# Check that the results are the same for single and double precision.
817826
km = Estimator(n_init=1, random_state=global_random_seed)
818827

@@ -822,7 +831,7 @@ def test_float_precision(Estimator, data, global_random_seed):
822831
labels = {}
823832

824833
for dtype in [np.float64, np.float32]:
825-
X = data.astype(dtype, copy=False)
834+
X = input_data.astype(dtype, copy=False)
826835
km.fit(X)
827836

828837
inertia[dtype] = km.inertia_
@@ -863,12 +872,18 @@ def test_centers_not_mutated(Estimator, dtype):
863872
assert not np.may_share_memory(km.cluster_centers_, centers_new_type)
864873

865874

866-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
867-
def test_kmeans_init_fitted_centers(data):
875+
@pytest.mark.parametrize(
876+
"input_data",
877+
[X] + X_as_any_csr,
878+
ids=data_containers_ids,
879+
)
880+
def test_kmeans_init_fitted_centers(input_data):
868881
# Check that starting fitting from a local optimum shouldn't change the
869882
# solution
870-
km1 = KMeans(n_clusters=n_clusters).fit(data)
871-
km2 = KMeans(n_clusters=n_clusters, init=km1.cluster_centers_, n_init=1).fit(data)
883+
km1 = KMeans(n_clusters=n_clusters).fit(input_data)
884+
km2 = KMeans(n_clusters=n_clusters, init=km1.cluster_centers_, n_init=1).fit(
885+
input_data
886+
)
872887

873888
assert_allclose(km1.cluster_centers_, km2.cluster_centers_)
874889

@@ -920,31 +935,39 @@ def test_weighted_vs_repeated(global_random_seed):
920935
)
921936

922937

923-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
938+
@pytest.mark.parametrize(
939+
"input_data",
940+
[X] + X_as_any_csr,
941+
ids=data_containers_ids,
942+
)
924943
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
925-
def test_unit_weights_vs_no_weights(Estimator, data, global_random_seed):
944+
def test_unit_weights_vs_no_weights(Estimator, input_data, global_random_seed):
926945
# Check that not passing sample weights should be equivalent to passing
927946
# sample weights all equal to one.
928947
sample_weight = np.ones(n_samples)
929948

930949
km = Estimator(n_clusters=n_clusters, random_state=global_random_seed, n_init=1)
931-
km_none = clone(km).fit(data, sample_weight=None)
932-
km_ones = clone(km).fit(data, sample_weight=sample_weight)
950+
km_none = clone(km).fit(input_data, sample_weight=None)
951+
km_ones = clone(km).fit(input_data, sample_weight=sample_weight)
933952

934953
assert_array_equal(km_none.labels_, km_ones.labels_)
935954
assert_allclose(km_none.cluster_centers_, km_ones.cluster_centers_)
936955

937956

938-
@pytest.mark.parametrize("data", [X, X_csr], ids=["dense", "sparse"])
957+
@pytest.mark.parametrize(
958+
"input_data",
959+
[X] + X_as_any_csr,
960+
ids=data_containers_ids,
961+
)
939962
@pytest.mark.parametrize("Estimator", [KMeans, MiniBatchKMeans])
940-
def test_scaled_weights(Estimator, data, global_random_seed):
963+
def test_scaled_weights(Estimator, input_data, global_random_seed):
941964
# Check that scaling all sample weights by a common factor
942965
# shouldn't change the result
943966
sample_weight = np.random.RandomState(global_random_seed).uniform(size=n_samples)
944967

945968
km = Estimator(n_clusters=n_clusters, random_state=global_random_seed, n_init=1)
946-
km_orig = clone(km).fit(data, sample_weight=sample_weight)
947-
km_scaled = clone(km).fit(data, sample_weight=0.5 * sample_weight)
969+
km_orig = clone(km).fit(input_data, sample_weight=sample_weight)
970+
km_scaled = clone(km).fit(input_data, sample_weight=0.5 * sample_weight)
948971

949972
assert_array_equal(km_orig.labels_, km_scaled.labels_)
950973
assert_allclose(km_orig.cluster_centers_, km_scaled.cluster_centers_)
@@ -957,9 +980,7 @@ def test_kmeans_elkan_iter_attribute():
957980
assert km.n_iter_ == 1
958981

959982

960-
@pytest.mark.parametrize(
961-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
962-
)
983+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
963984
def test_kmeans_empty_cluster_relocated(array_constr):
964985
# check that empty clusters are correctly relocated when using sample
965986
# weights (#13486)
@@ -1005,9 +1026,7 @@ def test_warning_elkan_1_cluster():
10051026
KMeans(n_clusters=1, algorithm="elkan").fit(X)
10061027

10071028

1008-
@pytest.mark.parametrize(
1009-
"array_constr", [np.array, sp.csr_matrix], ids=["dense", "sparse"]
1010-
)
1029+
@pytest.mark.parametrize("array_constr", data_containers, ids=data_containers_ids)
10111030
@pytest.mark.parametrize("algo", ["lloyd", "elkan"])
10121031
def test_k_means_1_iteration(array_constr, algo, global_random_seed):
10131032
# check the results after a single iteration (E-step M-step E-step) by
@@ -1196,11 +1215,14 @@ def test_kmeans_plusplus_wrong_params(param, match):
11961215
kmeans_plusplus(X, n_clusters, **param)
11971216

11981217

1199-
@pytest.mark.parametrize("data", [X, X_csr])
1218+
@pytest.mark.parametrize(
1219+
"input_data",
1220+
[X] + X_as_any_csr,
1221+
)
12001222
@pytest.mark.parametrize("dtype", [np.float64, np.float32])
1201-
def test_kmeans_plusplus_output(data, dtype, global_random_seed):
1223+
def test_kmeans_plusplus_output(input_data, dtype, global_random_seed):
12021224
# Check for the correct number of seeds and all positive values
1203-
data = data.astype(dtype)
1225+
data = input_data.astype(dtype)
12041226
centers, indices = kmeans_plusplus(
12051227
data, n_clusters, random_state=global_random_seed
12061228
)
@@ -1289,15 +1311,15 @@ def test_feature_names_out(Klass, method):
12891311
assert_array_equal([f"{class_name}{i}" for i in range(n_clusters)], names_out)
12901312

12911313

1292-
@pytest.mark.parametrize("is_sparse", [True, False])
1293-
def test_predict_does_not_change_cluster_centers(is_sparse):
1314+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS + [None])
1315+
def test_predict_does_not_change_cluster_centers(csr_container):
12941316
"""Check that predict does not change cluster centers.
12951317
12961318
Non-regression test for gh-24253.
12971319
"""
12981320
X, _ = make_blobs(n_samples=200, n_features=10, centers=10, random_state=0)
1299-
if is_sparse:
1300-
X = sp.csr_matrix(X)
1321+
if csr_container is not None:
1322+
X = csr_container(X)
13011323

13021324
kmeans = KMeans()
13031325
y_pred1 = kmeans.fit_predict(X)

0 commit comments

Comments
 (0)