Skip to content

Commit 0a0bf24

Browse files
authored
Revert "ENH: Add indicator features to imputer output (scikit-learn#6607)" (scikit-learn#7292)
This reverts commit 18396be as it was merged when incomplete.
1 parent 4566251 commit 0a0bf24

File tree

4 files changed

+28
-173
lines changed

4 files changed

+28
-173
lines changed

doc/modules/preprocessing.rst

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -445,49 +445,28 @@ values, either using the mean, the median or the most frequent value of
445445
the row or column in which the missing values are located. This class
446446
also allows for different missing values encodings.
447447

448-
Imputing missing values ordinarily discards the information of which values
449-
were missing. Setting ``add_indicator_features=True`` allows the knowledge of
450-
which features were imputed to be exploited by a downstream estimator
451-
by adding features that indicate which elements have been imputed.
452-
453448
The following snippet demonstrates how to replace missing values,
454449
encoded as ``np.nan``, using the mean value of the columns (axis 0)
455-
that contain the missing values. In case there is a feature which has
456-
all missing features, it is discarded when transformed. Also if the
457-
indicator matrix is requested (``add_indicator_features=True``),
458-
then the shape of the transformed input is
459-
``(n_samples, n_features_new + len(imputed_features_))`` ::
450+
that contain the missing values::
460451

461452
>>> import numpy as np
462453
>>> from sklearn.preprocessing import Imputer
463454
>>> imp = Imputer(missing_values='NaN', strategy='mean', axis=0)
464-
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]]) # doctest: +NORMALIZE_WHITESPACE
465-
Imputer(add_indicator_features=False, axis=0, copy=True, missing_values='NaN',
466-
strategy='mean', verbose=0)
455+
>>> imp.fit([[1, 2], [np.nan, 3], [7, 6]])
456+
Imputer(axis=0, copy=True, missing_values='NaN', strategy='mean', verbose=0)
467457
>>> X = [[np.nan, 2], [6, np.nan], [7, 6]]
468458
>>> print(imp.transform(X)) # doctest: +ELLIPSIS
469459
[[ 4. 2. ]
470460
[ 6. 3.666...]
471461
[ 7. 6. ]]
472-
>>> imp_with_in = Imputer(missing_values='NaN', strategy='mean', axis=0,add_indicator_features=True)
473-
>>> imp_with_in.fit([[1, 2], [np.nan, 3], [7, 6]])
474-
Imputer(add_indicator_features=True, axis=0, copy=True, missing_values='NaN',
475-
strategy='mean', verbose=0)
476-
>>> print(imp_with_in.transform(X)) # doctest: +ELLIPSIS
477-
[[ 4. 2. 1. 0. ]
478-
[ 6. 3.66666667 0. 1. ]
479-
[ 7. 6. 0. 0. ]]
480-
>>> print(imp_with_in.imputed_features_)
481-
[0 1]
482462

483463
The :class:`Imputer` class also supports sparse matrices::
484464

485465
>>> import scipy.sparse as sp
486466
>>> X = sp.csc_matrix([[1, 2], [0, 3], [7, 6]])
487467
>>> imp = Imputer(missing_values=0, strategy='mean', axis=0)
488-
>>> imp.fit(X) # doctest: +NORMALIZE_WHITESPACE
489-
Imputer(add_indicator_features=False, axis=0, copy=True, missing_values=0,
490-
strategy='mean', verbose=0)
468+
>>> imp.fit(X)
469+
Imputer(axis=0, copy=True, missing_values=0, strategy='mean', verbose=0)
491470
>>> X_test = sp.csc_matrix([[0, 2], [6, 0], [7, 6]])
492471
>>> print(imp.transform(X_test)) # doctest: +ELLIPSIS
493472
[[ 4. 2. ]

examples/missing_values.py

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,16 @@
88
Imputing does not always improve the predictions, so please check via cross-validation.
99
Sometimes dropping rows or using marker values is more effective.
1010
11-
In this example, we artificially mark some of the elements in complete
12-
dataset as missing. Then we estimate performance using the complete dataset,
13-
dataset without the missing samples, after imputation without the indicator
14-
matrix and imputation with the indicator matrix for the missing values.
15-
1611
Missing values can be replaced by the mean, the median or the most frequent
1712
value using the ``strategy`` hyper-parameter.
1813
The median is a more robust estimator for data with high magnitude variables
1914
which could dominate results (otherwise known as a 'long tail').
2015
2116
Script output::
2217
23-
Score with the complete dataset = 0.56
18+
Score with the entire dataset = 0.56
2419
Score without the samples containing missing values = 0.48
2520
Score after imputation of the missing values = 0.55
26-
Score after imputation with indicator features = 0.57
2721
2822
In this case, imputing helps the classifier get close to the original score.
2923
@@ -46,11 +40,11 @@
4640
# Estimate the score on the entire dataset, with no missing values
4741
estimator = RandomForestRegressor(random_state=0, n_estimators=100)
4842
score = cross_val_score(estimator, X_full, y_full).mean()
49-
print("Score with the complete dataset = %.2f" % score)
43+
print("Score with the entire dataset = %.2f" % score)
5044

5145
# Add missing values in 75% of the lines
5246
missing_rate = 0.75
53-
n_missing_samples = int(n_samples * missing_rate)
47+
n_missing_samples = np.floor(n_samples * missing_rate)
5448
missing_samples = np.hstack((np.zeros(n_samples - n_missing_samples,
5549
dtype=np.bool),
5650
np.ones(n_missing_samples,
@@ -76,12 +70,3 @@
7670
n_estimators=100))])
7771
score = cross_val_score(estimator, X_missing, y_missing).mean()
7872
print("Score after imputation of the missing values = %.2f" % score)
79-
80-
# Estimate score after imputation of the missing values with indicator matrix
81-
estimator = Pipeline([("imputer", Imputer(missing_values=0,
82-
strategy="mean",
83-
axis=0, add_indicator_features=True)),
84-
("forest", RandomForestRegressor(random_state=0,
85-
n_estimators=100))])
86-
score = cross_val_score(estimator, X_missing, y_missing).mean()
87-
print("Score after imputation with indicator features = %.2f" % score)

sklearn/preprocessing/imputation.py

Lines changed: 20 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@
1010

1111
from ..base import BaseEstimator, TransformerMixin
1212
from ..utils import check_array
13-
from ..utils import safe_mask
1413
from ..utils.fixes import astype
1514
from ..utils.sparsefuncs import _get_median
1615
from ..utils.validation import check_is_fitted
@@ -103,21 +102,11 @@ class Imputer(BaseEstimator, TransformerMixin):
103102
- If `axis=0` and X is encoded as a CSR matrix;
104103
- If `axis=1` and X is encoded as a CSC matrix.
105104
106-
add_indicator_features : boolean, optional (default=False)
107-
If True, the transformed ``X`` will have binary indicator features
108-
appended. These correspond to input features with at least one
109-
missing value marking which elements have been imputed.
110-
111105
Attributes
112106
----------
113107
statistics_ : array of shape (n_features,)
114108
The imputation fill value for each feature if axis == 0.
115109
116-
imputed_features_ : array of shape (n_features_with_missing, )
117-
The input features which have been imputed during transform.
118-
The size of this attribute will be the number of features with
119-
at least one missing value (and fewer than all in the axis=0 case).
120-
121110
Notes
122111
-----
123112
- When ``axis=0``, columns which only contained missing values at `fit`
@@ -127,13 +116,12 @@ class Imputer(BaseEstimator, TransformerMixin):
127116
contain missing values).
128117
"""
129118
def __init__(self, missing_values="NaN", strategy="mean",
130-
axis=0, verbose=0, copy=True, add_indicator_features=False):
119+
axis=0, verbose=0, copy=True):
131120
self.missing_values = missing_values
132121
self.strategy = strategy
133122
self.axis = axis
134123
self.verbose = verbose
135124
self.copy = copy
136-
self.add_indicator_features = add_indicator_features
137125

138126
def fit(self, X, y=None):
139127
"""Fit the imputer on X.
@@ -311,74 +299,13 @@ def _dense_fit(self, X, strategy, missing_values, axis):
311299

312300
return most_frequent
313301

314-
def _sparse_transform(self, X, valid_stats, valid_idx):
315-
"""transformer on sparse data."""
316-
mask = _get_mask(X.data, self.missing_values)
317-
indexes = np.repeat(np.arange(len(X.indptr) - 1, dtype=np.int),
318-
np.diff(X.indptr))[mask]
319-
320-
X.data[mask] = astype(valid_stats[indexes], X.dtype,
321-
copy=False)
322-
323-
mask_matrix = X.__class__((mask, X.indices.copy(),
324-
X.indptr.copy()), shape=X.shape,
325-
dtype=X.dtype)
326-
mask_matrix.eliminate_zeros() # removes explicit False entries
327-
features_with_missing_values = mask_matrix.sum(axis=0).A.nonzero()[1]
328-
features_mask = safe_mask(mask_matrix, features_with_missing_values)
329-
imputed_mask = mask_matrix[:, features_mask]
330-
if self.axis == 0:
331-
self.imputed_features_ = valid_idx[features_with_missing_values]
332-
else:
333-
self.imputed_features_ = features_with_missing_values
334-
335-
if self.add_indicator_features:
336-
X = sparse.hstack((X, imputed_mask))
337-
338-
return X
339-
340-
def _dense_transform(self, X, valid_stats, valid_idx):
341-
"""transformer on dense data."""
342-
mask = _get_mask(X, self.missing_values)
343-
n_missing = np.sum(mask, axis=self.axis)
344-
values = np.repeat(valid_stats, n_missing)
345-
346-
if self.axis == 0:
347-
coordinates = np.where(mask.transpose())[::-1]
348-
else:
349-
coordinates = mask
350-
351-
X[coordinates] = values
352-
353-
features_with_missing_values = np.where(np.any
354-
(mask, axis=0))[0]
355-
imputed_mask = mask[:, features_with_missing_values]
356-
if self.axis == 0:
357-
self.imputed_features_ = valid_idx[features_with_missing_values]
358-
else:
359-
self.imputed_features_ = features_with_missing_values
360-
361-
if self.add_indicator_features:
362-
X = np.hstack((X, imputed_mask))
363-
364-
return X
365-
366302
def transform(self, X):
367303
"""Impute all missing values in X.
368304
369305
Parameters
370306
----------
371-
X : {array-like, sparse matrix}, shape = (n_samples, n_features)
307+
X : {array-like, sparse matrix}, shape = [n_samples, n_features]
372308
The input data to complete.
373-
374-
Return
375-
------
376-
X_new : {array-like, sparse matrix},
377-
Transformed array.
378-
shape (n_samples, n_features_new) when
379-
``add_indicator_features`` is False,
380-
shape (n_samples, n_features_new + len(imputed_features_)
381-
when ``add_indicator_features`` is True.
382309
"""
383310
if self.axis == 0:
384311
check_is_fitted(self, 'statistics_')
@@ -410,27 +337,39 @@ def transform(self, X):
410337
invalid_mask = np.isnan(statistics)
411338
valid_mask = np.logical_not(invalid_mask)
412339
valid_statistics = statistics[valid_mask]
413-
valid_idx = np.where(valid_mask)[0]
340+
valid_statistics_indexes = np.where(valid_mask)[0]
414341
missing = np.arange(X.shape[not self.axis])[invalid_mask]
415342

416343
if self.axis == 0 and invalid_mask.any():
417344
if self.verbose:
418345
warnings.warn("Deleting features without "
419346
"observed values: %s" % missing)
420-
X = X[:, valid_idx]
347+
X = X[:, valid_statistics_indexes]
421348
elif self.axis == 1 and invalid_mask.any():
422349
raise ValueError("Some rows only contain "
423350
"missing values: %s" % missing)
424351

425352
# Do actual imputation
426353
if sparse.issparse(X) and self.missing_values != 0:
427-
# sparse matrix and missing values is not zero
428-
X = self._sparse_transform(X, valid_statistics, valid_idx)
354+
mask = _get_mask(X.data, self.missing_values)
355+
indexes = np.repeat(np.arange(len(X.indptr) - 1, dtype=np.int),
356+
np.diff(X.indptr))[mask]
357+
358+
X.data[mask] = astype(valid_statistics[indexes], X.dtype,
359+
copy=False)
429360
else:
430-
# sparse with zero as missing value and dense matrix
431361
if sparse.issparse(X):
432362
X = X.toarray()
433363

434-
X = self._dense_transform(X, valid_statistics, valid_idx)
364+
mask = _get_mask(X, self.missing_values)
365+
n_missing = np.sum(mask, axis=self.axis)
366+
values = np.repeat(valid_statistics, n_missing)
367+
368+
if self.axis == 0:
369+
coordinates = np.where(mask.transpose())[::-1]
370+
else:
371+
coordinates = mask
372+
373+
X[coordinates] = values
435374

436375
return X

sklearn/preprocessing/tests/test_imputation.py

Lines changed: 0 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
import numpy as np
33
from scipy import sparse
44

5-
from sklearn.base import clone
65
from sklearn.utils.testing import assert_equal
76
from sklearn.utils.testing import assert_array_equal
87
from sklearn.utils.testing import assert_raises
@@ -359,50 +358,3 @@ def test_imputation_copy():
359358

360359
# Note: If X is sparse and if missing_values=0, then a (dense) copy of X is
361360
# made, even if copy=False.
362-
363-
364-
def check_indicator(X, expected_imputed_features, axis):
365-
n_samples, n_features = X.shape
366-
imputer = Imputer(missing_values=-1, strategy='mean', axis=axis)
367-
imputer_with_in = clone(imputer).set_params(add_indicator_features=True)
368-
Xt = imputer.fit_transform(X)
369-
Xt_with_in = imputer_with_in.fit_transform(X)
370-
imputed_features_mask = X[:, expected_imputed_features] == -1
371-
n_features_new = Xt.shape[1]
372-
n_imputed_features = len(imputer_with_in.imputed_features_)
373-
assert_array_equal(imputer.imputed_features_, expected_imputed_features)
374-
assert_array_equal(imputer_with_in.imputed_features_,
375-
expected_imputed_features)
376-
assert_equal(Xt_with_in.shape,
377-
(n_samples, n_features_new + n_imputed_features))
378-
assert_array_equal(Xt_with_in, np.hstack((Xt, imputed_features_mask)))
379-
imputer_with_in = clone(imputer).set_params(add_indicator_features=True)
380-
assert_array_equal(Xt_with_in,
381-
imputer_with_in.fit_transform(sparse.csc_matrix(X)).A)
382-
assert_array_equal(Xt_with_in,
383-
imputer_with_in.fit_transform(sparse.csr_matrix(X)).A)
384-
385-
386-
def test_indicator_features():
387-
# one feature with all missng values
388-
X = np.array([
389-
[-1, -1, 2, 3],
390-
[4, -1, 6, -1],
391-
[8, -1, 10, 11],
392-
[12, -1, -1, 15],
393-
[16, -1, 18, 19]
394-
])
395-
check_indicator(X, np.array([0, 2, 3]), axis=0)
396-
check_indicator(X, np.array([0, 1, 2, 3]), axis=1)
397-
398-
# one feature with all missing values and one with no missing value
399-
# when axis=0 the feature gets discarded
400-
X = np.array([
401-
[-1, -1, 1, 3],
402-
[4, -1, 0, -1],
403-
[8, -1, 1, 0],
404-
[0, -1, 0, 15],
405-
[16, -1, 1, 19]
406-
])
407-
check_indicator(X, np.array([0, 3]), axis=0)
408-
check_indicator(X, np.array([0, 1, 3]), axis=1)

0 commit comments

Comments
 (0)