Skip to content

Commit 74a9756

Browse files
antoinewdgamueller
authored andcommitted
[MRG+2] Norm inconsistency between RFE and SelectFromModel (was _LearntSelectorMixin) scikit-learn#2121 (scikit-learn#6181)
* Norm inconsistency between RFE and SelectFromModel (was _LearntSelectorMixin) scikit-learn#2121 * safe_pwr utility * Norm fix * Removed safe_pwr * 1D arrays support for norm fix * Test case for 2d coef in SelectFromModel * Fix numpy version requirement for norm fix * Implement fixes suggested by @jnothman * Add numpy version requiring the fix.
1 parent 177ac84 commit 74a9756

File tree

5 files changed

+96
-4
lines changed

5 files changed

+96
-4
lines changed

doc/whats_new.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ Enhancements
5252
(`#7506` <https://github.com/scikit-learn/scikit-learn/pull/7506>_) by
5353
`Narine Kokhlikyan`_.
5454

55+
- Added ``norm_order`` parameter to :class:`feature_selection.SelectFromModel`
56+
to enable selection of the norm order when ``coef_`` is more than 1D
57+
5558
Bug fixes
5659
.........
5760

sklearn/feature_selection/from_model.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,10 @@
1010
from ..utils import safe_mask, check_array, deprecated
1111
from ..utils.validation import check_is_fitted
1212
from ..exceptions import NotFittedError
13+
from ..utils.fixes import norm
1314

1415

15-
def _get_feature_importances(estimator):
16+
def _get_feature_importances(estimator, norm_order=1):
1617
"""Retrieve or aggregate feature importances from estimator"""
1718
importances = getattr(estimator, "feature_importances_", None)
1819

@@ -21,7 +22,7 @@ def _get_feature_importances(estimator):
2122
importances = np.abs(estimator.coef_)
2223

2324
else:
24-
importances = np.sum(np.abs(estimator.coef_), axis=0)
25+
importances = norm(estimator.coef_, axis=0, ord=norm_order)
2526

2627
elif importances is None:
2728
raise ValueError(
@@ -172,6 +173,11 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
172173
Otherwise train the model using ``fit`` and then ``transform`` to do
173174
feature selection.
174175
176+
norm_order : non-zero int, inf, -inf, default 1
177+
Order of the norm used to filter the vectors of coefficients below
178+
``threshold`` in the case where the ``coef_`` attribute of the
179+
estimator is of dimension 2.
180+
175181
Attributes
176182
----------
177183
`estimator_`: an estimator
@@ -182,10 +188,12 @@ class SelectFromModel(BaseEstimator, SelectorMixin):
182188
`threshold_`: float
183189
The threshold value used for feature selection.
184190
"""
185-
def __init__(self, estimator, threshold=None, prefit=False):
191+
192+
def __init__(self, estimator, threshold=None, prefit=False, norm_order=1):
186193
self.estimator = estimator
187194
self.threshold = threshold
188195
self.prefit = prefit
196+
self.norm_order = norm_order
189197

190198
def _get_support_mask(self):
191199
# SelectFromModel can directly call on transform.
@@ -197,7 +205,7 @@ def _get_support_mask(self):
197205
raise ValueError(
198206
'Either fit the model before transform or set "prefit=True"'
199207
' while passing the fitted estimator to the constructor.')
200-
scores = _get_feature_importances(estimator)
208+
scores = _get_feature_importances(estimator, self.norm_order)
201209
self.threshold_ = _calculate_threshold(estimator, scores,
202210
self.threshold)
203211
return scores >= self.threshold_

sklearn/feature_selection/tests/test_from_model.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
from sklearn.feature_selection import SelectFromModel
1818
from sklearn.ensemble import RandomForestClassifier
1919
from sklearn.linear_model import PassiveAggressiveClassifier
20+
from sklearn.utils.fixes import norm
2021

2122
iris = datasets.load_iris()
2223
data, y = iris.data, iris.target
@@ -102,6 +103,31 @@ def test_feature_importances():
102103
assert_array_equal(X_new, X[:, mask])
103104

104105

106+
@skip_if_32bit
107+
def test_feature_importances_2d_coef():
108+
X, y = datasets.make_classification(
109+
n_samples=1000, n_features=10, n_informative=3, n_redundant=0,
110+
n_repeated=0, shuffle=False, random_state=0, n_classes=4)
111+
112+
est = LogisticRegression()
113+
for threshold, func in zip(["mean", "median"], [np.mean, np.median]):
114+
for order in [1, 2, np.inf]:
115+
# Fit SelectFromModel a multi-class problem
116+
transformer = SelectFromModel(estimator=LogisticRegression(),
117+
threshold=threshold,
118+
norm_order=order)
119+
transformer.fit(X, y)
120+
assert_true(hasattr(transformer.estimator_, 'coef_'))
121+
X_new = transformer.transform(X)
122+
assert_less(X_new.shape[1], X.shape[1])
123+
124+
# Manually check that the norm is correctly performed
125+
est.fit(X, y)
126+
importances = norm(est.coef_, axis=0, ord=order)
127+
feature_mask = importances > func(importances)
128+
assert_array_equal(X_new, X[:, feature_mask])
129+
130+
105131
def test_partial_fit():
106132
est = PassiveAggressiveClassifier(random_state=0, shuffle=False)
107133
transformer = SelectFromModel(estimator=est)

sklearn/utils/fixes.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,3 +419,33 @@ def __getstate__(self):
419419
self._fill_value)
420420
else:
421421
from numpy.ma import MaskedArray # noqa
422+
423+
if 'axis' not in signature(np.linalg.norm).parameters:
424+
425+
def norm(X, ord=None, axis=None):
426+
"""
427+
Handles the axis parameter for the norm function
428+
in old versions of numpy (useless for numpy >= 1.8).
429+
"""
430+
431+
if axis is None or X.ndim == 1:
432+
result = np.linalg.norm(X, ord=ord)
433+
return result
434+
435+
if axis not in (0, 1):
436+
raise NotImplementedError("""
437+
The fix that adds axis parameter to the old numpy
438+
norm only works for 1D or 2D arrays.
439+
""")
440+
441+
if axis == 0:
442+
X = X.T
443+
444+
result = np.zeros(X.shape[0])
445+
for i in range(len(result)):
446+
result[i] = np.linalg.norm(X[i], ord=ord)
447+
448+
return result
449+
450+
else:
451+
norm = np.linalg.norm

sklearn/utils/tests/test_fixes.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import pickle
77
import numpy as np
8+
import math
89

910
from sklearn.utils.testing import assert_equal
1011
from sklearn.utils.testing import assert_false
@@ -16,6 +17,7 @@
1617
from sklearn.utils.fixes import divide, expit
1718
from sklearn.utils.fixes import astype
1819
from sklearn.utils.fixes import MaskedArray
20+
from sklearn.utils.fixes import norm
1921

2022

2123
def test_expit():
@@ -66,3 +68,26 @@ def test_masked_array_obj_dtype_pickleable():
6668
marr_pickled = pickle.loads(pickle.dumps(marr))
6769
assert_array_equal(marr.data, marr_pickled.data)
6870
assert_array_equal(marr.mask, marr_pickled.mask)
71+
72+
73+
def test_norm():
74+
X = np.array([[-2, 4, 5],
75+
[1, 3, -4],
76+
[0, 0, 8],
77+
[0, 0, 0]]).astype(float)
78+
79+
# Test various axis and order
80+
assert_equal(math.sqrt(135), norm(X))
81+
assert_array_equal(
82+
np.array([math.sqrt(5), math.sqrt(25), math.sqrt(105)]),
83+
norm(X, axis=0)
84+
)
85+
assert_array_equal(np.array([3, 7, 17]), norm(X, axis=0, ord=1))
86+
assert_array_equal(np.array([2, 4, 8]), norm(X, axis=0, ord=np.inf))
87+
assert_array_equal(np.array([0, 0, 0]), norm(X, axis=0, ord=-np.inf))
88+
assert_array_equal(np.array([11, 8, 8, 0]), norm(X, axis=1, ord=1))
89+
90+
# Test shapes
91+
assert_equal((), norm(X).shape)
92+
assert_equal((3,), norm(X, axis=0).shape)
93+
assert_equal((4,), norm(X, axis=1).shape)

0 commit comments

Comments
 (0)