8000 MNT SLEP6: raise NotImplementedError for meta-estimators not supporti… · lesteve/scikit-learn@db09a76 · GitHub
[go: up one dir, main page]

Skip to content

Commit db09a76

Browse files
adrinjalalilesteve
authored andcommitted
MNT SLEP6: raise NotImplementedError for meta-estimators not supporting metadata routing (scikit-learn#27389)
1 parent 17c8401 commit db09a76

23 files changed

+367
-35
lines changed

doc/metadata_routing.rst

Lines changed: 74 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,11 @@ Metadata Routing
1010

1111
.. note::
1212
The Metadata Routing API is experimental, and is not implemented yet for many
13-
estimators. It may change without the usual deprecation cycle. By default
14-
this feature is not enabled. You can enable this feature by setting the
15-
``enable_metadata_routing`` flag to ``True``:
13+
estimators. Please refer to the :ref:`list of supported and unsupported
14+
models <metadata_routing_models>` for more information. It may change without
15+
the usual deprecation cycle. By default this feature is not enabled. You can
16+
enable this feature by setting the ``enable_metadata_routing`` flag to
17+
``True``::
1618

1719
>>> import sklearn
1820
>>> sklearn.set_config(enable_metadata_routing=True)
@@ -230,3 +232,72 @@ The issue can be fixed by explicitly setting the request value::
230232
>>> lr = LogisticRegression().set_fit_request(
231233
... sample_weight=True
232234
... ).set_score_request(sample_weight=False)
235+
236+
At the end we disable the configuration flag for metadata routing::
237+
238+
>>> sklearn.set_config(enable_metadata_routing=False)
239+
240+
.. _metadata_routing_models:
241+
242+
Metadata Routing Support Status
243+
*******************************
244+
All consumers (i.e. simple estimators which only consume metadata and don't
245+
route them) support metadata routing, meaning they can be used inside
246+
meta-estimators which support metadata routing. However, development of support
247+
for metadata routing for meta-estimators is in progress, and here is a list of
248+
meta-estimators and tools which support and don't yet support metadata routing.
249+
250+
251+
Meta-estimators and functions supporting metadata routing:
252+
253+
- :class:`sklearn.calibration.CalibratedClassifierCV`
254+
- :class:`sklearn.compose.ColumnTransformer`
255+
- :class:`sklearn.linear_model.LogisticRegressionCV`
256+
- :class:`sklearn.model_selection.GridSearchCV`
257+
- :class:`sklearn.model_selection.HalvingGridSearchCV`
258+
- :class:`sklearn.model_selection.HalvingRandomSearchCV`
259+
- :class:`sklearn.model_selection.RandomizedSearchCV`
260+
- :func:`sklearn.model_selection.cross_validate`
261+
- :func:`sklearn.model_selection.cross_val_score`
262+
- :func:`sklearn.model_selection.cross_val_predict`
263+
- :class:`sklearn.multioutput.ClassifierChain`
264+
- :class:`sklearn.multioutput.MultiOutputClassifier`
265+
- :class:`sklearn.multioutput.MultiOutputRegressor`
266+
- :class:`sklearn.multioutput.RegressorChain`
267+
- :class:`sklearn.pipeline.Pipeline`
268+
269+
Meta-estimators and tools not supporting metadata routing yet:
270+
271+
- :class:`sklearn.compose.TransformedTargetRegressor`
272+
- :class:`sklearn.covariance.GraphicalLassoCV`
273+
- :class:`sklearn.ensemble.AdaBoostClassifier`
274+
- :class:`sklearn.ensemble.AdaBoostRegressor`
275+
- :class:`sklearn.ensemble.BaggingClassifier`
276+
- :class:`sklearn.ensemble.BaggingRegressor`
277+
- :class:`sklearn.ensemble.StackingClassifier`
278+
- :class:`sklearn.ensemble.StackingRegressor`
279+
- :class:`sklearn.ensemble.VotingClassifier`
280+
- :class:`sklearn.ensemble.VotingRegressor`
281+
- :class:`sklearn.feature_selection.RFE`
282+
- :class:`sklearn.feature_selection.RFECV`
283+
- :class:`sklearn.feature_selection.SelectFromModel`
284+
- :class:`sklearn.feature_selection.SequentialFeatureSelector`
285+
- :class:`sklearn.impute.IterativeImputer`
286+
- :class:`sklearn.linear_model.ElasticNetCV`
287+
- :class:`sklearn.linear_model.LarsCV`
288+
- :class:`sklearn.linear_model.LassoCV`
289+
- :class:`sklearn.linear_model.LassoLarsCV`
290+
- :class:`sklearn.linear_model.MultiTaskElasticNetCV`
291+
- :class:`sklearn.linear_model.MultiTaskLassoCV`
292+
- :class:`sklearn.linear_model.OrthogonalMatchingPursuitCV`
293+
- :class:`sklearn.linear_model.RANSACRegressor`
294+
- :class:`sklearn.linear_model.RidgeClassifierCV`
295+
- :class:`sklearn.linear_model.RidgeCV`
296+
- :class:`sklearn.model_selection.learning_curve`
297+
- :class:`sklearn.model_selection.permutation_test_score`
298+
- :class:`sklearn.model_selection.validation_curve`
299+
- :class:`sklearn.multiclass.OneVsOneClassifier`
300+
- :class:`sklearn.multiclass.OneVsRestClassifier`
301+
- :class:`sklearn.multiclass.OutputCodeClassifier`
302+
- :class:`sklearn.pipeline.FeatureUnion`
303+
- :class:`sklearn.semi_supervised.SelfTrainingClassifier`

doc/whats_new/v1.4.rst

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,17 +58,22 @@ more details.
5858
estimator's ``fit``, the CV splitter, and the scorer. :pr:`27058` by `Adrin
5959
Jalali`_.
6060

61-
- |Enhancement| :class:`~compose.ColumnTransformer` now supports metadata routing
61+
- |Feature| :class:`~compose.ColumnTransformer` now supports metadata routing
6262
according to :ref:`metadata routing user guide <metadata_routing>`. :pr:`27005`
6363
by `Adrin Jalali`_.
6464

65-
- |Enhancement| :class:`linear_model.LogisticRegressionCV` now supports
65+
- |Feature| :class:`linear_model.LogisticRegressionCV` now supports
6666
metadata routing. :meth:`linear_model.LogisticRegressionCV.fit` now
6767
accepts ``**params`` which are passed to the underlying splitter and
6868
scorer. :meth:`linear_model.LogisticRegressionCV.score` now accepts
6969
``**score_params`` which are passed to the underlying scorer.
7070
:pr:`26525` by :user:`Omar Salman <OmarManzoor>`.
7171

72+
- |Fix| All meta-estimators for which metadata routing is not yet implemented
73+
now raise a `NotImplementedError` on `get_metadata_routing` and on `fit` if
74+
metadata routing is enabled and any metadata is passed to them. :pr:`27389`
75+
by `Adrin Jalali`_.
76+
7277
Support for SciPy sparse arrays
7378
-------------------------------
7479

sklearn/compose/_target.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,18 @@
1212
from ..utils import _safe_indexing, check_array
1313
from ..utils._param_validation import HasMethods
1414
from ..utils._tags import _safe_tags
15+
from ..utils.metadata_routing import (
16+
_raise_for_unsupported_routing,
17+
_RoutingNotSupportedMixin,
18+
)
1519
from ..utils.validation import check_is_fitted
1620

1721
__all__ = ["TransformedTargetRegressor"]
1822

1923

20-
class TransformedTargetRegressor(RegressorMixin, BaseEstimator):
24+
class TransformedTargetRegressor(
25+
_RoutingNotSupportedMixin, RegressorMixin, BaseEstimator
26+
):
2127
"""Meta-estimator to regress on a transformed target.
2228
2329
Useful for applying a non-linear transformation to the target `y` in
@@ -222,6 +228,7 @@ def fit(self, X, y, **fit_params):
222228
self : object
223229
Fitted estimator.
224230
"""
231+
_raise_for_unsupported_routing(self, "fit", **fit_params)
225232
if y is None:
226233
raise ValueError(
227234
f"This {self.__class__.__name__} estimator "

sklearn/covariance/_graph_lasso.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
from ..linear_model import lars_path_gram
2323
from ..model_selection import check_cv, cross_val_score
2424
from ..utils._param_validation import Interval, StrOptions, validate_params
25+
from ..utils.metadata_routing import _RoutingNotSupportedMixin
2526
from ..utils.parallel import Parallel, delayed
2627
from ..utils.validation import (
2728
_is_arraylike_not_scalar,
@@ -705,7 +706,7 @@ def graphical_lasso_path(
705706
return covariances_, precisions_
706707

707708

708-
class GraphicalLassoCV(BaseGraphicalLasso):
709+
class GraphicalLassoCV(_RoutingNotSupportedMixin, BaseGraphicalLasso):
709710
"""Sparse inverse covariance w/ cross-validated choice of the l1 penalty.
710711
711712
See glossary entry for :term:`cross-validation estimator`.

sklearn/ensem F438 ble/_bagging.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@
1919
from ..utils import check_random_state, column_or_1d, indices_to_mask
2020
from ..utils._param_validation import HasMethods, Interval, RealNotInt, StrOptions
2121
from ..utils._tags import _safe_tags
22+
from ..utils.metadata_routing import (
23+
_raise_for_unsupported_routing,
24+
_RoutingNotSupportedMixin,
25+
)
2226
from ..utils.metaestimators import available_if
2327
from ..utils.multiclass import check_classification_targets
2428
from ..utils.parallel import Parallel, delayed
@@ -326,6 +330,7 @@ def fit(self, X, y, sample_weight=None):
326330
self : object
327331
Fitted estimator.
328332
"""
333+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
329334
# Convert data (X is required to be 2d and indexable)
330335
X, y = self._validate_data(
331336
X,
@@ -542,7 +547,7 @@ def estimators_samples_(self):
542547
return [sample_indices for _, sample_indices in self._get_estimators_indices()]
543548

544549

545-
class BaggingClassifier(ClassifierMixin, BaseBagging):
550+
class BaggingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, BaseBagging):
546551
"""A Bagging classifier.
547552
548553
A Bagging classifier is an ensemble meta-estimator that fits base
@@ -990,7 +995,7 @@ def _more_tags(self):
990995
return {"allow_nan": _safe_tags(estimator, "allow_nan")}
991996

992997

993-
class BaggingRegressor(RegressorMixin, BaseBagging):
998+
class BaggingRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseBagging):
994999
"""A Bagging regressor.
9951000
9961001
A Bagging regressor is an ensemble meta-estimator that fits base

sklearn/ensemble/_stacking.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@
2626
from ..utils import Bunch
2727
from ..utils._estimator_html_repr import _VisualBlock
2828
from ..utils._param_validation import HasMethods, StrOptions
29+
from ..utils.metadata_routing import (
30+
_raise_for_unsupported_routing,
31+
_RoutingNotSupportedMixin,
32+
)
2933
from ..utils.metaestimators import available_if
3034
from ..utils.multiclass import check_classification_targets, type_of_target
3135
from ..utils.parallel import Parallel, delayed
@@ -380,7 +384,7 @@ def _sk_visual_block_with_final_estimator(self, final_estimator):
380384
return _VisualBlock("serial", (parallel, final_block), dash_wrapped=False)
381385

382386

383-
class StackingClassifier(ClassifierMixin, _BaseStacking):
387+
class StackingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseStacking):
384388
"""Stack of estimators with a final classifier.
385389
386390
Stacked generalization consists in stacking the output of individual
@@ -641,6 +645,7 @@ def fit(self, X, y, sample_weight=None):
641645
self : object
642646
Returns a fitted instance of estimator.
643647
"""
648+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
644649
check_classification_targets(y)
645650
if type_of_target(y) == "multilabel-indicator":
646651
self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T]
@@ -761,7 +766,7 @@ def _sk_visual_block_(self):
761766
return super()._sk_visual_block_with_final_estimator(final_estimator)
762767

763768

764-
class StackingRegressor(RegressorMixin, _BaseStacking):
769+
class StackingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseStacking):
765770
"""Stack of estimators with a final regressor.
766771
767772
Stacked generalization consists in stacking the output of individual
@@ -952,6 +957,7 @@ def fit(self, X, y, sample_weight=None):
952957
self : object
953958
Returns a fitted instance.
954959
"""
960+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
955961
y = column_or_1d(y, warn=True)
956962
return super().fit(X, y, sample_weight)
957963

sklearn/ensemble/_voting.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@
3030
from ..utils import Bunch
3131
from ..utils._estimator_html_repr import _VisualBlock
3232
from ..utils._param_validation import StrOptions
33+
from ..utils.metadata_routing import (
34+
_raise_for_unsupported_routing,
35+
_RoutingNotSupportedMixin,
36+
)
3337
from ..utils.metaestimators import available_if
3438
from ..utils.multiclass import check_classification_targets
3539
from ..utils.parallel import Parallel, delayed
@@ -152,7 +156,7 @@ def _more_tags(self):
152156
return {"preserves_dtype": []}
153157

154158

155-
class VotingClassifier(ClassifierMixin, _BaseVoting):
159+
class VotingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseVoting):
156160
"""Soft Voting/Majority Rule classifier for unfitted estimators.
157161
158162
Read more in the :ref:`User Guide <voting_classifier>`.
@@ -336,6 +340,7 @@ def fit(self, X, y, sample_weight=None):
336340
self : object
337341
Returns the instance itself.
338342
"""
343+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
339344
check_classification_targets(y)
340345
if isinstance(y, np.ndarray) and len(y.shape) > 1 and y.shape[1] > 1:
341346
raise NotImplementedError(
@@ -478,7 +483,7 @@ def get_feature_names_out(self, input_features=None):
478483
return np.asarray(names_out, dtype=object)
479484

480485

481-
class VotingRegressor(RegressorMixin, _BaseVoting):
486+
class VotingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseVoting):
482487
"""Prediction voting regressor for unfitted estimators.
483488
484489
A voting regressor is an ensemble meta-estimator that fits several base
@@ -601,6 +606,7 @@ def fit(self, X, y, sample_weight=None):
601606
self : object
602607
Fitted estimator.
603608
"""
609+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
604610
y = column_or_1d(y, warn=True)
605611
return super().fit(X, y, sample_weight)
606612

sklearn/ensemble/_weight_boosting.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,10 @@
4242
from ..utils import _safe_indexing, check_random_state
4343
from ..utils._param_validation import HasMethods, Interval, StrOptions
4444
from ..utils.extmath import softmax, stable_cumsum
45+
from ..utils.metadata_routing import (
46+
_raise_for_unsupported_routing,
47+
_RoutingNotSupportedMixin,
48+
)
4549
from ..utils.validation import (
4650
_check_sample_weight,
4751
_num_samples,
@@ -132,6 +136,7 @@ def fit(self, X, y, sample_weight=None):
132136
self : object
133137
Fitted estimator.
134138
"""
139+
_raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight)
135140
X, y = self._validate_data(
136141
X,
137142
y,
@@ -338,7 +343,9 @@ def _samme_proba(estimator, n_classes, X):
338343
)
339344

340345

341-
class AdaBoostClassifier(ClassifierMixin, BaseWeightBoosting):
346+
class AdaBoostClassifier(
347+
_RoutingNotSupportedMixin, ClassifierMixin, BaseWeightBoosting
348+
):
342349
"""An AdaBoost classifier.
343350
344351
An AdaBoost [1]_ classifier is a meta-estimator that begins by fitting a
@@ -980,7 +987,7 @@ def predict_log_proba(self, X):
980987
return np.log(self.predict_proba(X))
981988

982989

983-
class AdaBoostRegressor(RegressorMixin, BaseWeightBoosting):
990+
class AdaBoostRegressor(_RoutingNotSupportedMixin, RegressorMixin, BaseWeightBoosting):
984991
"""An AdaBoost regressor.
985992
986993
An AdaBoost [1] regressor is a meta-estimator that begins by fitting a

sklearn/feature_selection/_from_model.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
from ..exceptions import NotFittedError
1111
from ..utils._param_validation import HasMethods, Interval, Options
1212
from ..utils._tags import _safe_tags
13+
from ..utils.metadata_routing import (
14+
_raise_for_unsupported_routing,
15+
_RoutingNotSupportedMixin,
16+
)
1317
from ..utils.metaestimators import available_if
1418
from ..utils.validation import _num_features, check_is_fitted, check_scalar
1519
from ._base import SelectorMixin, _get_feature_importances
@@ -78,7 +82,9 @@ def _estimator_has(attr):
7882
)
7983

8084

81-
class SelectFromModel(MetaEstimatorMixin, SelectorMixin, BaseEstimator):
85+
class SelectFromModel(
86+
_RoutingNotSupportedMixin, MetaEstimatorMixin, SelectorMixin, BaseEstimator
87+
):
8288
"""Meta-transformer for selecting features based on importance weights.
8389
8490
.. versionadded:: 0.17
@@ -342,6 +348,7 @@ def fit(self, X, y=None, **fit_params):
342348
self : object
343349
Fitted estimator.
344350
"""
351+
_raise_for_unsupported_routing(self, "fit", **fit_params)
345352
self._check_max_features(X)
346353

347354
if self.prefit:

sklearn/feature_selection/_rfe.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,10 @@
1717
from ..model_selection._validation import _score
1818
from ..utils._param_validation import HasMethods, Interval, RealNotInt
1919
from ..utils._tags import _safe_tags
20+
from ..utils.metadata_routing import (
21+
_raise_for_unsupported_routing,
22+
_RoutingNotSupportedMixin,
23+
)
2024
from ..utils.metaestimators import _safe_split, available_if
2125
from ..utils.parallel import Parallel, delayed
2226
from ..utils.validation import check_is_fitted
@@ -56,7 +60,7 @@ def _estimator_has(attr):
5660
)
5761

5862

59-
class RFE(SelectorMixin, MetaEstimatorMixin, BaseEstimator):
63+
class RFE(_RoutingNotSupportedMixin, SelectorMixin, MetaEstimatorMixin, BaseEstimator):
6064
"""Feature ranking with recursive feature elimination.
6165
6266
Given an external estimator that assigns weights to features (e.g., the
@@ -251,6 +255,7 @@ def fit(self, X, y, **fit_params):
251255
self : object
252256
Fitted estimator.
253257
"""
258+
_raise_for_unsupported_routing(self, "fit", **fit_params)
254259
return self._fit(X, y, **fit_params)
255260

256261
def _fit(self, X, y, step_score=None, **fit_params):
@@ -680,6 +685,7 @@ def fit(self, X, y, groups=None):
680685
self : object
681686
Fitted estimator.
682687
"""
688+
_raise_for_unsupported_routing(self, "fit", groups=groups)
683689
tags = self._get_tags()
684690
X, y = self._validate_data(
685691
X,

0 commit comments

Comments
 (0)
0