Skip to content

Commit 6c18131

Browse files
TST Extend tests for scipy.sparse.*array in sklearn/neighbors/tests/test_nearest_centroid.py (scikit-learn#27132)
Co-authored-by: Omar Salman <omar.salman@arbisoft.com>
1 parent 474cd98 commit 6c18131

File tree

1 file changed

+9
-5
lines changed

1 file changed

+9
-5
lines changed

sklearn/neighbors/tests/test_nearest_centroid.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,15 @@
44
import numpy as np
55
import pytest
66
from numpy.testing import assert_array_equal
7-
from scipy import sparse as sp
87

98
from sklearn import datasets
109
from sklearn.neighbors import NearestCentroid
10+
from sklearn.utils.fixes import CSR_CONTAINERS
1111

1212
# toy sample
1313
X = [[-2, -1], [-1, -1], [-1, -2], [1, 1], [1, 2], [2, 1]]
14-
X_csr = sp.csr_matrix(X) # Sparse matrix
1514
y = [-1, -1, -1, 1, 1, 1]
1615
T = [[-1, -1], [2, 2], [3, 2]]
17-
T_csr = sp.csr_matrix(T)
1816
true_result = [-1, 1, 1]
1917

2018
# also load the iris dataset
@@ -26,8 +24,12 @@
2624
iris.target = iris.target[perm]
2725

2826

29-
def test_classification_toy():
27+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
28+
def test_classification_toy(csr_container):
3029
# Check classification on a toy dataset, including sparse versions.
30+
X_csr = csr_container(X)
31+
T_csr = csr_container(T)
32+
3133
clf = NearestCentroid()
3234
clf.fit(X, y)
3335
assert_array_equal(clf.predict(T), true_result)
@@ -135,8 +137,10 @@ def test_predict_translated_data():
135137
assert_array_equal(y_init, y_translate)
136138

137139

138-
def test_manhattan_metric():
140+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
141+
def test_manhattan_metric(csr_container):
139142
# Test the manhattan metric.
143+
X_csr = csr_container(X)
140144

141145
clf = NearestCentroid(metric="manhattan")
142146
clf.fit(X, y)

0 commit comments

Comments
 (0)