8000 ENH implement metadata routing in Pipeline (#26789) · punndcoder28/scikit-learn@a8c83ff · GitHub
[go: up one dir, main page]

Skip to content

Commit a8c83ff

Browse files
adrinjalalipunndcoder28
authored andcommitted
ENH implement metadata routing in Pipeline (scikit-learn#26789)
1 parent 92bc3cf commit a8c83ff

File tree

5 files changed

+596
-92
lines changed

5 files changed

+596
-92
lines changed

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,13 @@ Changelog
7878
- |Fix| :func:`feature_selection.mutual_info_regression` now correctly computes the
7979
result when `X` is of integer dtype. :pr:`26748` by :user:`Yao Xiao <Charlie-XIAO>`.
8080

81+
:mod:`sklearn.pipeline`
82+
.......................
83+
84+
- |Feature| :class:`pipeline.Pipeline` now supports metadata routing according
85+
to :ref:`metadata routing user guide <metadata_routing>`. :pr:`26789` by
86+
`Adrin Jalali`_.
87+
8188
:mod:`sklearn.tree`
8289
...................
8390

examples/miscellaneous/plot_metadata_routing.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,17 @@ def transform(self, X, groups=None):
492492
check_metadata(self, groups=groups)
493493
return X
494494

495+
def fit_transform(self, X, y, sample_weight=None, groups=None):
496+
return self.fit(X, y, sample_weight).transform(X, groups)
497+
495498

496499
# %%
500+
# Note that in the above example, we have implemented ``fit_transform`` which
501+
# calls ``fit`` and ``transform`` with the appropriate metadata. This is only
502+
# required if ``transform`` accepts metadata, since the default ``fit_transform``
503+
# implementation in :class:`~base.TransformerMixin` doesn't pass metadata to
504+
# ``transform``.
505+
#
497506
# Now we can test our pipeline, and see if metadata is correctly passed around.
498507
# This example uses our simple pipeline, and our transformer, and our
499508
# consumer+router estimator which uses our simple classifier.

sklearn/conftest.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from _pytest.doctest import DoctestItem
1313
from threadpoolctl import threadpool_limits
1414

15+
from sklearn import config_context
1516
from sklearn._min_dependencies import PYTEST_MIN_VERSION
1617
from sklearn.datasets import (
1718
fetch_20newsgroups,
@@ -35,6 +36,13 @@
3536
scipy_datasets_require_network = sp_version >= parse_version("1.10")
3637

3738

39+
@pytest.fixture
40+
def enable_slep006():
41+
"""Enable SLEP006 for all tests."""
42+
with config_context(enable_metadata_routing=True):
43+
yield
44+
45+
3846
def raccoon_face_or_skip():
3947
# SciPy >= 1.10 requires network to access to get data
4048
if scipy_datasets_require_network:

0 commit comments

Comments
 (0)
0