Skip to content

Commit dfe01c2

Browse files
authored
TST Extend tests for scipy.sparse.*array in sklearn/utils/tests/test_set_output.py (scikit-learn#27202)
1 parent cfbab77 commit dfe01c2

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

sklearn/utils/tests/test_set_output.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import numpy as np
44
import pytest
55
from numpy.testing import assert_array_equal
6-
from scipy.sparse import csr_matrix
76

87
from sklearn._config import config_context, get_config
98
from sklearn.utils._set_output import (
@@ -12,6 +11,7 @@
1211
_SetOutputMixin,
1312
_wrap_in_pandas_container,
1413
)
14+
from sklearn.utils.fixes import CSR_CONTAINERS
1515

1616

1717
def test__wrap_in_pandas_container_dense():
@@ -41,10 +41,11 @@ def test__wrap_in_pandas_container_dense_update_columns_and_index():
4141
assert_array_equal(new_df.index, X_df.index)
4242

4343

44-
def test__wrap_in_pandas_container_error_validation():
44+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
45+
def test__wrap_in_pandas_container_error_validation(csr_container):
4546
"""Check errors in _wrap_in_pandas_container."""
4647
X = np.asarray([[1, 0, 3], [0, 0, 1]])
47-
X_csr = csr_matrix(X)
48+
X_csr = csr_container(X)
4849
match = "Pandas output does not support sparse data"
4950
with pytest.raises(ValueError, match=match):
5051
_wrap_in_pandas_container(X_csr, columns=["a", "b", "c"])

0 commit comments

Comments
 (0)