Skip to content

Commit 8faa920

Browse files
TST Extend tests for scipy.sparse/*array in sklearn/impute/tests/test_common (scikit-learn#27277)
Co-authored-by: Guillaume Lemaitre <g.lemaitre58@gmail.com>
1 parent 2a21025 commit 8faa920

File tree

3 files changed

+21
-5
lines changed

3 files changed

+21
-5
lines changed

doc/whats_new/v1.4.rst

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,14 @@ Changelog
177177
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
178178
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.
179179

180+
:mod:`sklearn.impute`
181+
.....................
182+
183+
- |Enhancement| In :class:`impute.SimpleImputer`, :class:`impute.IterativeImputer`, and
184+
:class:`impute.KNNImputer` with ``add_indicator=True``, using sparse arrays now
185+
behaves consistently with using sparse matrices in the `transform` and
186+
`fit_transform` methods. :pr:`27277` by :user:`Yao Xiao <Charlie-XIAO>`.
187+
180188
:mod:`sklearn.kernel_approximation`
181189
...................................
182190

sklearn/impute/_base.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import numbers
66
import warnings
77
from collections import Counter
8+
from functools import partial
89

910
import numpy as np
1011
import numpy.ma as ma
@@ -115,7 +116,13 @@ def _concatenate_indicator(self, X_imputed, X_indicator):
115116
if not self.add_indicator:
116117
return X_imputed
117118

118-
hstack = sp.hstack if sp.issparse(X_imputed) else np.hstack
119+
if sp.issparse(X_imputed):
120+
# sp.hstack may result in different formats between sparse arrays and
121+
# matrices; specify the format to keep consistent behavior
122+
hstack = partial(sp.hstack, format=X_imputed.format)
123+
else:
124+
hstack = np.hstack
125+
119126
if X_indicator is None:
120127
raise ValueError(
121128
"Data from the missing indicator are not provided. Call "

sklearn/impute/tests/test_common.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import numpy as np
22
import pytest
3-
from scipy import sparse
43

54
from sklearn.experimental import enable_iterative_imputer # noqa
65
from sklearn.impute import IterativeImputer, KNNImputer, SimpleImputer
@@ -9,6 +8,7 @@
98
assert_allclose_dense_sparse,
109
assert_array_equal,
1110
)
11+
from sklearn.utils.fixes import CSR_CONTAINERS
1212

1313

1414
def imputers():
@@ -69,16 +69,17 @@ def test_imputers_add_indicator(marker, imputer):
6969
@pytest.mark.parametrize(
7070
"imputer", sparse_imputers(), ids=lambda x: x.__class__.__name__
7171
)
72-
def test_imputers_add_indicator_sparse(imputer, marker):
73-
X = sparse.csr_matrix(
72+
@pytest.mark.parametrize("csr_container", CSR_CONTAINERS)
73+
def test_imputers_add_indicator_sparse(imputer, marker, csr_container):
74+
X = csr_container(
7475
[
7576
[marker, 1, 5, marker, 1],
7677
[2, marker, 1, marker, 2],
7778
[6, 3, marker, marker, 3],
7879
[1, 2, 9, marker, 4],
7980
]
8081
)
81-
X_true_indicator = sparse.csr_matrix(
82+
X_true_indicator = csr_container(
8283
[
8384
[1.0, 0.0, 0.0, 1.0],
8485
[0.0, 1.0, 0.0, 1.0],

0 commit comments

Comments
 (0)