Fix performance regression in ColumnTransformer (#29330) · scikit-learn/scikit-learn@5a74cc0 · GitHub
[go: up one dir, main page]

Skip to content

Commit 5a74cc0

Browse files
committed
Fix performance regression in ColumnTransformer (#29330)
1 parent c3d69b2 commit 5a74cc0

File tree

3 files changed

+12
-22
lines changed

3 files changed

+12
-22
lines changed

doc/whats_new/v1.5.rst

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -33,6 +33,13 @@ Changes impacting many modules
3333
Changelog
3434
---------
3535

36+
:mod:`sklearn.compose`
37+
......................
38+
39+
- |Efficiency| Fix a performance regression in :class:`compose.ColumnTransformer`
40+
where the full input data was copied for each transformer when `n_jobs > 1`.
41+
:pr:`29330` by :user:`Jérémie du Boisberranger <jeremiedbb>`.
42+
3643
:mod:`sklearn.metrics`
3744
......................
3845

sklearn/compose/_column_transformer.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,7 @@
2020
from ..preprocessing import FunctionTransformer
2121
from ..utils import Bunch
2222
from ..utils._estimator_html_repr import _VisualBlock
23-
from ..utils._indexing import _determine_key_type, _get_column_indices
23+
from ..utils._indexing import _determine_key_type, _get_column_indices, _safe_indexing
2424
from ..utils._metadata_requests import METHODS
2525
from ..utils._param_validation import HasMethods, Hidden, Interval, StrOptions
2626
from ..utils._set_output import (
@@ -874,10 +874,9 @@ def _call_func_on_transformers(self, X, y, func, column_as_labels, routed_params
874874
jobs.append(
875875
delayed(func)(
876876
transformer=clone(trans) if not fitted else trans,
877-
X=X,
877+
X=_safe_indexing(X, columns, axis=1),
878878
y=y,
879879
weight=weight,
880-
columns=columns,
881880
**extra_args,
882881
params=routed_params[name],
883882
)

sklearn/pipeline.py

Lines changed: 3 additions & 19 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
from .base import TransformerMixin, _fit_context, clone
1717
from .exceptions import NotFittedError
1818
from .preprocessing import FunctionTransformer
19-
from .utils import Bunch, _safe_indexing
19+
from .utils import Bunch
2020
from .utils._estimator_html_repr import _VisualBlock
2121
from .utils._metadata_requests import METHODS
2222
from .utils._param_validation import HasMethods, Hidden
@@ -1265,7 +1265,7 @@ def make_pipeline(*steps, memory=None, verbose=False):
12651265
return Pipeline(_name_estimators(steps), memory=memory, verbose=verbose)
12661266

12671267

1268-
def _transform_one(transformer, X, y, weight, columns=None, params=None):
1268+
def _transform_one(transformer, X, y, weight, params=None):
12691269
"""Call transform and apply weight to output.
12701270
12711271
Parameters
@@ -1282,17 +1282,11 @@ def _transform_one(transformer, X, y, weight, columns=None, params=None):
12821282
weight : float
12831283
Weight to be applied to the output of the transformation.
12841284
1285-
columns : str, array-like of str, int, array-like of int, array-like of bool, slice
1286-
Columns to select before transforming.
1287-
12881285
params : dict
12891286
Parameters to be passed to the transformer's ``transform`` method.
12901287
12911288
This should be of the form ``process_routing()["step_name"]``.
12921289
"""
1293-
if columns is not None:
1294-
X = _safe_indexing(X, columns, axis=1)
1295-
12961290
res = transformer.transform(X, **params.transform)
12971291
# if we have a weight for this transformer, multiply output
12981292
if weight is None:
@@ -1301,14 +1295,7 @@ def _transform_one(transformer, X, y, weight, columns=None, params=None):
13011295

13021296

13031297
def _fit_transform_one(
1304-
transformer,
1305-
X,
1306-
y,
1307-
weight,
1308-
columns=None,
1309-
message_clsname="",
1310-
message=None,
1311-
params=None,
1298+
transformer, X, y, weight, message_clsname="", message=None, params=None
13121299
):
13131300
"""
13141301
Fits ``transformer`` to ``X`` and ``y``. The transformed result is returned
@@ -1317,9 +1304,6 @@ def _fit_transform_one(
13171304
13181305
``params`` needs to be of the form ``process_routing()["step_name"]``.
13191306
"""
1320-
if columns is not None:
1321-
X = _safe_indexing(X, columns, axis=1)
1322-
13231307
params = params or {}
13241308
with _print_elapsed_time(message_clsname, message):
13251309
if hasattr(transformer, "fit_transform"):

0 commit comments

Comments
 (0)
0