|
6 | 6 |
|
7 | 7 | import numpy as np
|
8 | 8 | import pytest
|
9 |
| -import scipy.sparse as sp |
10 | 9 |
|
11 | 10 | from sklearn import config_context
|
12 | 11 | from sklearn.utils import (
|
|
35 | 34 | assert_array_equal,
|
36 | 35 | assert_no_warnings,
|
37 | 36 | )
|
| 37 | +from sklearn.utils.fixes import CSC_CONTAINERS, CSR_CONTAINERS |
38 | 38 |
|
39 | 39 | # toy array
|
40 | 40 | X_toy = np.arange(9).reshape((3, 3))
|
@@ -160,21 +160,23 @@ def test_resample_stratify_2dy():
|
160 | 160 | assert y.ndim == 2
|
161 | 161 |
|
162 | 162 |
|
163 |
| -def test_resample_stratify_sparse_error(): |
| 163 | +@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) |
| 164 | +def test_resample_stratify_sparse_error(csr_container): |
164 | 165 | # resample must be ndarray
|
165 | 166 | rng = np.random.RandomState(0)
|
166 | 167 | n_samples = 100
|
167 | 168 | X = rng.normal(size=(n_samples, 2))
|
168 | 169 | y = rng.randint(0, 2, size=n_samples)
|
169 |
| - stratify = sp.csr_matrix(y) |
| 170 | + stratify = csr_container(y) |
170 | 171 | with pytest.raises(TypeError, match="A sparse matrix was passed"):
|
171 | 172 | X, y = resample(X, y, n_samples=50, random_state=rng, stratify=stratify)
|
172 | 173 |
|
173 | 174 |
|
174 |
| -def test_safe_mask(): |
| 175 | +@pytest.mark.parametrize("csr_container", CSR_CONTAINERS) |
| 176 | +def test_safe_mask(csr_container): |
175 | 177 | random_state = check_random_state(0)
|
176 | 178 | X = random_state.rand(5, 4)
|
177 |
| - X_csr = sp.csr_matrix(X) |
| 179 | + X_csr = csr_container(X) |
178 | 180 | mask = [False, False, True, True, True]
|
179 | 181 |
|
180 | 182 | mask = safe_mask(X, mask)
|
@@ -514,14 +516,15 @@ def to_tuple(A): # to make the inner arrays hashable
|
514 | 516 | assert set(to_tuple(A)) == S
|
515 | 517 |
|
516 | 518 |
|
517 |
| -def test_shuffle_dont_convert_to_array(): |
| 519 | +@pytest.mark.parametrize("csc_container", CSC_CONTAINERS) |
| 520 | +def test_shuffle_dont_convert_to_array(csc_container): |
518 | 521 | # Check that shuffle does not try to convert to numpy arrays with float
|
519 | 522 | # dtypes can let any indexable datastructure pass-through.
|
520 | 523 | a = ["a", "b", "c"]
|
521 | 524 | b = np.array(["a", "b", "c"], dtype=object)
|
522 | 525 | c = [1, 2, 3]
|
523 | 526 | d = MockDataFrame(np.array([["a", 0], ["b", 1], ["c", 2]], dtype=object))
|
524 |
| - e = sp.csc_matrix(np.arange(6).reshape(3, 2)) |
| 527 | + e = csc_container(np.arange(6).reshape(3, 2)) |
525 | 528 | a_s, b_s, c_s, d_s, e_s = shuffle(a, b, c, d, e, random_state=0)
|
526 | 529 |
|
527 | 530 | assert a_s == ["c", "b", "a"]
|
|
0 commit comments