Skip to content

Commit 6413feb

Browse files
MrMjauhlesteve
authored andcommitted
[MRG+1] Fix inverse_transform in deprecated GridSearchCV (scikit-learn#8860)
1 parent 3de7da3 commit 6413feb

File tree

3 files changed

+17
-2
lines changed

3 files changed

+17
-2
lines changed

doc/whats_new.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,8 @@ Bug fixes
189189
- Fixed a bug where :func:`sklearn.model_selection.BaseSearchCV.inverse_transform`
190190
returns self.best_estimator_.transform() instead of self.best_estimator_.inverse_transform()
191191
:issue:`8344` by :user:`Akshay Gupta <Akshay0724>`
192+
- Fixed same issue in :func:`sklearn.grid_search.BaseSearchCV.inverse_transform`
193+
:issue:`8846` by :user:`Rasmus Eriksson <MrMjauh>`
192194

193195
- Fixed a bug where :class:`sklearn.linear_model.RandomizedLasso` and
194196
:class:`sklearn.linear_model.RandomizedLogisticRegression` breaks for

sklearn/grid_search.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -540,7 +540,7 @@ def inverse_transform(self, Xt):
540540
underlying estimator.
541541
542542
"""
543-
return self.best_estimator_.transform(Xt)
543+
return self.best_estimator_.inverse_transform(Xt)
544544

545545
def _fit(self, X, y, parameter_iterable):
546546
"""Actual fitting, performing the search over parameters."""

sklearn/tests/test_grid_search.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,14 @@ def fit(self, X, Y):
7171
def predict(self, T):
7272
return T.shape[0]
7373

74+
def transform(self, X):
75+
return X - self.foo_param
76+
77+
def inverse_transform(self, X):
78+
return X + self.foo_param
79+
7480
predict_proba = predict
7581
decision_function = predict
76-
transform = predict
7782

7883
def score(self, X=None, Y=None):
7984
if self.foo_param > 1:
@@ -166,6 +171,14 @@ def test_grid_search():
166171
assert_raises(ValueError, grid_search.fit, X, y)
167172

168173

174+
def test_transform_inverse_transform_round_trip():
175+
clf = MockClassifier()
176+
grid_search = GridSearchCV(clf, {'foo_param': [1, 2, 3]}, verbose=3)
177+
grid_search.fit(X, y)
178+
X_round_trip = grid_search.inverse_transform(grid_search.transform(X))
179+
assert_array_equal(X, X_round_trip)
180+
181+
169182
@ignore_warnings
170183
def test_grid_search_no_score():
171184
# Test grid-search on classifier that has no score function.

0 commit comments

Comments
 (0)