Skip to content

Commit 4bc61a0

Browse files
authored
Fix performance regression in ColumnTransformer (scikit-learn#29330)
1 parent d840f36 commit 4bc61a0

File tree

3 files changed

+12 -22 lines changed

doc/whats_new/v1.5.rst

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,13 @@ Changes impacting many modules
3333
Changelog
3434
---------
3535

36+
:mod:`sklearn.compose`
37+
......................
38+
39+
- |Efficiency| Fix a performance regression in :class:`compose.ColumnTransformer`
40+
where the full input data was copied for each transformer when `n_jobs > 1`.
41+
:pr:`29330` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
42+
3643
:mod:`sklearn.metrics`
3744
......................
3845

sklearn/compose/_column_transformer.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@
1919
from ..preprocessing import FunctionTransformer
2020
from ..utils import Bunch
2121
from ..utils._estimator_html_repr import _VisualBlock
22-
from ..utils._indexing import _determine_key_type, _get_column_indices
22+
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_indexing
2323
from ..utils._metadata_requests import METHODS
2424
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
2525
from ..utils._set_output import (
@@ -873,10 +873,9 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
873873
jobs.append(
874874
delayed(func)(
875875
transformer=clone(trans) if not fitted else trans,
876-
X=X,
876+
X=_safe_indexing(X, columns, axis=1),
877877
y=y,
878878
weight=weight,
879-
columns=columns,
880879
**extra_args,
881880
params=routed_params[name],
882881
)

sklearn/pipeline.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -12,7 +12,7 @@
1212
from .base import TransformerMixin, _fit_context, clone
1313
from .exceptions import NotFittedError
1414
from .preprocessing import FunctionTransformer
15-
from .utils import Bunch, _safe_indexing
15+
from .utils import Bunch
1616
from .utils._estimator_html_repr import _VisualBlock
1717
from .utils._metadata_requests import METHODS
1818
from .utils._param_validation import HasMethods, Hidden
@@ -1261,7 +1261,7 @@ def make_pipeline(*steps, memory=None, verbose=False):
12611261
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)
12621262

12631263

1264-
def _transform_one(transformer, X, y, weight, columns=None, params=None):
1264+
def _transform_one(transformer, X, y, weight, params=None):
12651265
"""Call transform and apply weight to output.
12661266
12671267
Parameters
@@ -1278,17 +1278,11 @@ def _transform_one(transformer, X, y, weight, columns=None, params=None):
12781278
weight : float
12791279
Weight to be applied to the output of the transformation.
12801280
1281-
columns : str, array-like of str, int, array-like of int, array-like of bool, slice
1282-
Columns to select before transforming.
1283-
12841281
params : dict
12851282
Parameters to be passed to the transformer's ``transform`` method.
12861283
12871284
This should be of the form ``process_routing()["step_name"]``.
12881285
"""
1289-
if columns is not None:
1290-
X = _safe_indexing(X, columns, axis=1)
1291-
12921286
res = transformer.transform(X, **params.transform)
12931287
# if we have a weight for this transformer, multiply output
12941288
if weight is None:
@@ -1297,14 +1291,7 @@ def _transform_one(transformer, X, y, weight, columns=None, params=None):
12971291

12981292

12991293
def _fit_transform_one(
1300-
transformer,
1301-
X,
1302-
y,
1303-
weight,
1304-
columns=None,
1305-
message_clsname="",
1306-
message=None,
1307-
params=None,
1294+
transformer, X, y, weight, message_clsname="", message=None, params=None
13081295
):
13091296
"""
13101297
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
@@ -1313,9 +1300,6 @@ def _fit_transform_one(
13131300
13141301
``params`` needs to be of the form ``process_routing()["step_name"]``.
13151302
"""
1316-
if columns is not None:
1317-
X = _safe_indexing(X, columns, axis=1)
1318-
13191303
params = params or {}
13201304
with _print_elapsed_time(message_clsname, message):
13211305
if hasattr(transformer, "fit_transform"):

0 commit comments

Comments
 (0)