Skip to content

Commit 89ea028

Browse files
thomasjpfanogrisel
andauthored
ENH Improve warnings if func returns a dataframe in FunctionTransformer (scikit-learn#26944)
Co-authored-by: Olivier Grisel <olivier.grisel@ensta.org>
1 parent 066375f commit 89ea028

File tree

3 files changed

+45
-15
lines changed

3 files changed

+45
-15
lines changed

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,13 @@ Changelog
221221
- |Enhancement| Added `neg_root_mean_squared_log_error_scorer` as scorer
222222
:pr:`26734` by :user:`Alejandro Martin Gil <101AlexMartin>`.
223223

224+
:mod:`sklearn.preprocessing`
225+
............................
226+
227+
- |Enhancement| Improves warnings in :class:`preprocessing.FunctionTransfomer` when
228+
`func` returns a pandas dataframe and the output is configured to be pandas.
229+
:pr:`26944` by `Thomas Fan`_.
230+
224231
:mod:`sklearn.model_selection`
225232
..............................
226233

sklearn/preprocessing/_function_transformer.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,12 @@
44

55
from ..base import BaseEstimator, TransformerMixin, _fit_context
66
from ..utils._param_validation import StrOptions
7+
from ..utils._set_output import _get_output_config
78
from ..utils.metaestimators import available_if
89
from ..utils.validation import (
910
_allclose_dense_sparse,
1011
_check_feature_names_in,
12+
_is_pandas_df,
1113
check_array,
1214
)
1315

@@ -237,7 +239,20 @@ def transform(self, X):
237239
Transformed input.
238240
"""
239241
X = self._check_input(X, reset=False)
240-
return self._transform(X, func=self.func, kw_args=self.kw_args)
242+
out = self._transform(X, func=self.func, kw_args=self.kw_args)
243+
244+
output_config = _get_output_config("transform", self)["dense"]
245+
if (
246+
output_config == "pandas"
247+
and self.feature_names_out is None
248+
and not _is_pandas_df(out)
249+
):
250+
warnings.warn(
251+
"When `set_output` is configured to be 'pandas', `func` should return "
252+
"a DataFrame to follow the `set_output` API or `feature_names_out` "
253+
"should be defined."
254+
)
255+
return out
241256

242257
def inverse_transform(self, X):
243258
"""Transform X using the inverse function.
@@ -338,13 +353,8 @@ def set_output(self, *, transform=None):
338353
self : estimator instance
339354
Estimator instance.
340355
"""
341-
if hasattr(super(), "set_output"):
342-
return super().set_output(transform=transform)
343-
344-
if transform == "pandas" and self.feature_names_out is None:
345-
warnings.warn(
346-
'With transform="pandas", `func` should return a DataFrame to follow'
347-
" the set_output API."
348-
)
356+
if not hasattr(self, "_sklearn_output_config"):
357+
self._sklearn_output_config = {}
349358

359+
self._sklearn_output_config["transform"] = transform
350360
return self

sklearn/preprocessing/tests/test_function_transformer.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -451,13 +451,26 @@ def test_set_output_func():
451451
assert isinstance(X_trans, pd.DataFrame)
452452
assert_array_equal(X_trans.columns, ["a", "b"])
453453

454-
# If feature_names_out is not defined, then a warning is raised in
455-
# `set_output`
456454
ft = FunctionTransformer(lambda x: 2 * x)
457-
msg = "should return a DataFrame to follow the set_output API"
458-
with pytest.warns(UserWarning, match=msg):
459-
ft.set_output(transform="pandas")
455+
ft.set_output(transform="pandas")
460456

461-
X_trans = ft.fit_transform(X)
457+
# no warning is raised when func returns a panda dataframe
458+
with warnings.catch_warnings():
459+
warnings.simplefilter("error", UserWarning)
460+
X_trans = ft.fit_transform(X)
462461
assert isinstance(X_trans, pd.DataFrame)
463462
assert_array_equal(X_trans.columns, ["a", "b"])
463+
464+
# Warning is raised when func returns a ndarray
465+
ft_np = FunctionTransformer(lambda x: np.asarray(x))
466+
ft_np.set_output(transform="pandas")
467+
468+
msg = "When `set_output` is configured to be 'pandas'"
469+
with pytest.warns(UserWarning, match=msg):
470+
ft_np.fit_transform(X)
471+
472+
# default transform does not warn
473+
ft_np.set_output(transform="default")
474+
with warnings.catch_warnings():
475+
warnings.simplefilter("error", UserWarning)
476+
ft_np.fit_transform(X)

0 commit comments

Comments
 (0)