Skip to content

Commit 28a0d56

Browse files
AishwaryaRKTomDLT
authored andcommitted
Improve error message for MeanShift.estimate_bandwidth when X is sparse scikit-learn#8627 (scikit-learn#8771)
* raise error on sparse matrix in estimate_bandwidth of mean_shift * Remove added newline
1 parent 398ffed commit 28a0d56

File tree

2 files changed

+11
-0
lines changed

2 files changed

+11
-0
lines changed

sklearn/cluster/mean_shift_.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ def estimate_bandwidth(X, quantile=0.3, n_samples=None, random_state=0,
6262
bandwidth : float
6363
The bandwidth parameter.
6464
"""
65+
X = check_array(X)
66+
6567
random_state = check_random_state(random_state)
6668
if n_samples is not None:
6769
idx = random_state.permutation(X.shape[0])[:n_samples]

sklearn/cluster/tests/test_mean_shift.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import numpy as np
77
import warnings
88

9+
from scipy import sparse
10+
911
from sklearn.utils.testing import assert_equal
1012
from sklearn.utils.testing import assert_false
1113
from sklearn.utils.testing import assert_true
@@ -47,6 +49,13 @@ def test_mean_shift():
4749
assert_equal(n_clusters_, n_clusters)
4850

4951

52+
def test_estimate_bandwidth_with_sparse_matrix():
53+
# Test estimate_bandwidth with sparse matrix
54+
X = sparse.lil_matrix((1000, 1000))
55+
msg = "A sparse matrix was passed, but dense data is required."
56+
assert_raise_message(TypeError, msg, estimate_bandwidth, X, 200)
57+
58+
5059
def test_parallel():
5160
ms1 = MeanShift(n_jobs=2)
5261
ms1.fit(X)

0 commit comments

Comments
 (0)