Skip to content

Commit cfbab77

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/utils/tests/test_utils.py (scikit-learn#27201)
1 parent af6c35d commit cfbab77

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

sklearn/utils/tests/test_utils.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66

77
import numpy as np
88
import pytest
9-
import scipy.sparse as sp
109

1110
from sklearn import config_context
1211
from sklearn.utils import (
@@ -35,6 +34,7 @@
3534
assert_array_equal,
3635
assert_no_warnings,
3736
)
37+
from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS
3838

3939
# toy array
4040
X_toy = np.arange(9).reshape((3, 3))
@@ -160,21 +160,23 @@ def test_resample_stratify_2dy():
160160
assert y.ndim == 2
161161

162162

163-
def test_resample_stratify_sparse_error():
163+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
164+
def test_resample_stratify_sparse_error(csr_container):
164165
# resample must be ndarray
165166
rng = np.random.RandomState(0)
166167
n_samples = 100
167168
X = rng.normal(size=(n_samples, 2))
168169
y = rng.randint(0, 2, size=n_samples)
169-
stratify = sp.csr_matrix(y)
170+
stratify = csr_container(y)
170171
with pytest.raises(TypeError, match="A sparse matrix was passed"):
171172
X, y = resample(X, y, n_samples=50, random_state=rng, stratify=stratify)
172173

173174

174-
def test_safe_mask():
175+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
176+
def test_safe_mask(csr_container):
175177
random_state = check_random_state(0)
176178
X = random_state.rand(5, 4)
177-
X_csr = sp.csr_matrix(X)
179+
X_csr = csr_container(X)
178180
mask = [False, False, True, True, True]
179181

180182
mask = safe_mask(X, mask)
@@ -514,14 +516,15 @@ def to_tuple(A): # to make the inner arrays hashable
514516
assert set(to_tuple(A)) == S
515517

516518

517-
def test_shuffle_dont_convert_to_array():
519+
@pytest.mark.parametrize("csc_container", CSC_CONTAINERS)
520+
def test_shuffle_dont_convert_to_array(csc_container):
518521
# Check that shuffle does not try to convert to numpy arrays with float
519522
# dtypes can let any indexable datastructure pass-through.
520523
a = ["a", "b", "c"]
521524
b = np.array(["a", "b", "c"], dtype=object)
522525
c = [1, 2, 3]
523526
d = MockDataFrame(np.array([["a", 0], ["b", 1], ["c", 2]], dtype=object))
524-
e = sp.csc_matrix(np.arange(6).reshape(3, 2))
527+
e = csc_container(np.arange(6).reshape(3, 2))
525528
a_s, b_s, c_s, d_s, e_s = shuffle(a, b, c, d, e, random_state=0)
526529

527530
assert a_s == ["c", "b", "a"]

0 commit comments

Comments
 (0)