Skip to content

Commit dea72cc

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/tests/test_multioutput.py (scikit-learn#27171)
1 parent 407070b commit dea72cc

File tree

1 file changed

+31
-18
lines changed

1 file changed

+31
-18
lines changed

sklearn/tests/test_multioutput.py

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
import numpy as np
44
import pytest
5-
import scipy.sparse as sp
65
from joblib import cpu_count
76

87
from sklearn import datasets
@@ -49,6 +48,14 @@
4948
assert_array_almost_equal,
5049
assert_array_equal,
5150
)
51+
from sklearn.utils.fixes import (
52+
BSR_CONTAINERS,
53+
COO_CONTAINERS,
54+
CSC_CONTAINERS,
55+
CSR_CONTAINERS,
56+
DOK_CONTAINERS,
57+
LIL_CONTAINERS,
58+
)
5259

5360

5461
def test_multi_target_regression():
@@ -101,25 +108,29 @@ def test_multi_target_regression_one_target():
101108
rgr.fit(X, y)
102109

103110

104-
def test_multi_target_sparse_regression():
111+
@pytest.mark.parametrize(
112+
"sparse_container",
113+
CSR_CONTAINERS
114+
+ CSC_CONTAINERS
115+
+ COO_CONTAINERS
116+
+ LIL_CONTAINERS
117+
+ DOK_CONTAINERS
118+
+ BSR_CONTAINERS,
119+
)
120+
def test_multi_target_sparse_regression(sparse_container):
105121
X, y = datasets.make_regression(n_targets=3, random_state=0)
106122
X_train, y_train = X[:50], y[:50]
107123
X_test = X[50:]
108124

109-
for sparse in [
110-
sp.csr_matrix,
111-
sp.csc_matrix,
112-
sp.coo_matrix,
113-
sp.dok_matrix,
114-
sp.lil_matrix,
115-
]:
116-
rgr = MultiOutputRegressor(Lasso(random_state=0))
117-
rgr_sparse = MultiOutputRegressor(Lasso(random_state=0))
125+
rgr = MultiOutputRegressor(Lasso(random_state=0))
126+
rgr_sparse = MultiOutputRegressor(Lasso(random_state=0))
118127

119-
rgr.fit(X_train, y_train)
120-
rgr_sparse.fit(sparse(X_train), y_train)
128+
rgr.fit(X_train, y_train)
129+
rgr_sparse.fit(sparse_container(X_train), y_train)
121130

122-
assert_almost_equal(rgr.predict(X_test), rgr_sparse.predict(sparse(X_test)))
131+
assert_almost_equal(
132+
rgr.predict(X_test), rgr_sparse.predict(sparse_container(X_test))
133+
)
123134

124135

125136
def test_multi_target_sample_weights_api():
@@ -497,10 +508,11 @@ def test_classifier_chain_fit_and_predict_with_linear_svc():
497508
assert not hasattr(classifier_chain, "predict_proba")
498509

499510

500-
def test_classifier_chain_fit_and_predict_with_sparse_data():
511+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
512+
def test_classifier_chain_fit_and_predict_with_sparse_data(csr_container):
501513
# Fit classifier chain with sparse data
502514
X, Y = generate_multilabel_dataset_with_correlations()
503-
X_sparse = sp.csr_matrix(X)
515+
X_sparse = csr_container(X)
504516

505517
classifier_chain = ClassifierChain(LogisticRegression())
506518
classifier_chain.fit(X_sparse, Y)
@@ -555,10 +567,11 @@ def test_base_chain_fit_and_predict():
555567
assert isinstance(chains[1], ClassifierMixin)
556568

557569

558-
def test_base_chain_fit_and_predict_with_sparse_data_and_cv():
570+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
571+
def test_base_chain_fit_and_predict_with_sparse_data_and_cv(csr_container):
559572
# Fit base chain with sparse data cross_val_predict
560573
X, Y = generate_multilabel_dataset_with_correlations()
561-
X_sparse = sp.csr_matrix(X)
574+
X_sparse = csr_container(X)
562575
base_chains = [
563576
ClassifierChain(LogisticRegression(), cv=3),
564577
RegressorChain(Ridge(), cv=3),

0 commit comments

Comments
 (0)