Skip to content

Commit 8eef53f

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/feature_selection/tests/test_rfe.py (scikit-learn#27177)
1 parent 11cf0ee commit 8eef53f

File tree

1 file changed

+9
-7
lines changed

1 file changed

+9
-7
lines changed

sklearn/feature_selection/tests/test_rfe.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import pytest
99
from numpy.testing import assert_allclose, assert_array_almost_equal, assert_array_equal
10-
from scipy import sparse
1110

1211
from sklearn.base import BaseEstimator, ClassifierMixin
1312
from sklearn.compose import TransformedTargetRegressor
@@ -23,6 +22,7 @@
2322
from sklearn.svm import SVC, SVR, LinearSVR
2423
from sklearn.utils import check_random_state
2524
from sklearn.utils._testing import ignore_warnings
25+
from sklearn.utils.fixes import CSR_CONTAINERS
2626

2727

2828
class MockClassifier:
@@ -79,13 +79,14 @@ def test_rfe_features_importance():
7979
assert_array_equal(rfe.get_support(), rfe_svc.get_support())
8080

8181

82-
def test_rfe():
82+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
83+
def test_rfe(csr_container):
8384
generator = check_random_state(0)
8485
iris = load_iris()
8586
# Add some irrelevant features. Random seed is set to make sure that
8687
# irrelevant features are always irrelevant.
8788
X = np.c_[iris.data, generator.normal(size=(len(iris.data), 6))]
88-
X_sparse = sparse.csr_matrix(X)
89+
X_sparse = csr_container(X)
8990
y = iris.target
9091

9192
# dense model
@@ -173,7 +174,8 @@ def test_rfe_mockclassifier():
173174
assert X_r.shape == iris.data.shape
174175

175176

176-
def test_rfecv():
177+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
178+
def test_rfecv(csr_container):
177179
generator = check_random_state(0)
178180
iris = load_iris()
179181
# Add some irrelevant features. Random seed is set to make sure that
@@ -197,7 +199,7 @@ def test_rfecv():
197199

198200
# same in sparse
199201
rfecv_sparse = RFECV(estimator=SVC(kernel="linear"), step=1)
200-
X_sparse = sparse.csr_matrix(X)
202+
X_sparse = csr_container(X)
201203
rfecv_sparse.fit(X_sparse, y)
202204
X_r_sparse = rfecv_sparse.transform(X_sparse)
203205
assert_array_equal(X_r_sparse.toarray(), iris.data)
@@ -241,14 +243,14 @@ def test_scorer(estimator, X, y):
241243
assert_array_equal(X_r, iris.data)
242244

243245
rfecv_sparse = RFECV(estimator=SVC(kernel="linear"), step=2)
244-
X_sparse = sparse.csr_matrix(X)
246+
X_sparse = csr_container(X)
245247
rfecv_sparse.fit(X_sparse, y)
246248
X_r_sparse = rfecv_sparse.transform(X_sparse)
247249
assert_array_equal(X_r_sparse.toarray(), iris.data)
248250

249251
# Verifying that steps < 1 don't blow up.
250252
rfecv_sparse = RFECV(estimator=SVC(kernel="linear"), step=0.2)
251-
X_sparse = sparse.csr_matrix(X)
253+
X_sparse = csr_container(X)
252254
rfecv_sparse.fit(X_sparse, y)
253255
X_r_sparse = rfecv_sparse.transform(X_sparse)
254256
assert_array_equal(X_r_sparse.toarray(), iris.data)

0 commit comments

Comments
 (0)