diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst index d319b311dddd7..0ada6ef6c4dbe 100644 --- a/doc/metadata_routing.rst +++ b/doc/metadata_routing.rst @@ -277,6 +277,8 @@ Meta-estimators and functions supporting metadata routing: - :class:`sklearn.calibration.CalibratedClassifierCV` - :class:`sklearn.compose.ColumnTransformer` - :class:`sklearn.covariance.GraphicalLassoCV` +- :class:`sklearn.ensemble.StackingClassifier` +- :class:`sklearn.ensemble.StackingRegressor` - :class:`sklearn.ensemble.VotingClassifier` - :class:`sklearn.ensemble.VotingRegressor` - :class:`sklearn.ensemble.BaggingClassifier` @@ -316,13 +318,9 @@ Meta-estimators and tools not supporting metadata routing yet: - :class:`sklearn.compose.TransformedTargetRegressor` - :class:`sklearn.ensemble.AdaBoostClassifier` - :class:`sklearn.ensemble.AdaBoostRegressor` -- :class:`sklearn.ensemble.StackingClassifier` -- :class:`sklearn.ensemble.StackingRegressor` - :class:`sklearn.feature_selection.RFE` - :class:`sklearn.feature_selection.RFECV` - :class:`sklearn.feature_selection.SequentialFeatureSelector` -- :class:`sklearn.impute.IterativeImputer` -- :class:`sklearn.linear_model.RANSACRegressor` - :class:`sklearn.model_selection.learning_curve` - :class:`sklearn.model_selection.permutation_test_score` - :class:`sklearn.model_selection.validation_curve` diff --git a/doc/modules/ensemble.rst b/doc/modules/ensemble.rst index 4237d023973f7..58c9127850f6a 100644 --- a/doc/modules/ensemble.rst +++ b/doc/modules/ensemble.rst @@ -1581,8 +1581,8 @@ availability, tested in the order of preference: `predict_proba`, `decision_function` and `predict`. A :class:`StackingRegressor` and :class:`StackingClassifier` can be used as -any other regressor or classifier, exposing a `predict`, `predict_proba`, and -`decision_function` methods, e.g.:: +any other regressor or classifier, exposing a `predict`, `predict_proba`, or +`decision_function` method, e.g.:: >>> y_pred = reg.predict(X_test) >>> from sklearn.metrics import r2_score diff --git a/doc/whats_new/v1.6.rst b/doc/whats_new/v1.6.rst index 6eda6717b3d1b..5000866b59c03 100644 --- a/doc/whats_new/v1.6.rst +++ b/doc/whats_new/v1.6.rst @@ -38,7 +38,19 @@ See :ref:`array_api` for more details. **Classes:** -- +- + +Metadata Routing +---------------- + +The following models now support metadata routing in one or more of their +methods. Refer to the :ref:`Metadata Routing User Guide ` for +more details. + +- |Feature| :class:`ensemble.StackingClassifier` and + :class:`ensemble.StackingRegressor` now support metadata routing and pass + ``**fit_params`` to the underlying estimators via their `fit` methods. + :pr:`28701` by :user:`Stefanie Senger `. Changelog --------- diff --git a/sklearn/ensemble/_base.py b/sklearn/ensemble/_base.py index 5483206de51d5..18079b02c49f1 100644 --- a/sklearn/ensemble/_base.py +++ b/sklearn/ensemble/_base.py @@ -21,7 +21,7 @@ def _fit_single_estimator( estimator, X, y, fit_params, message_clsname=None, message=None ): """Private function used to fit an estimator within a job.""" - # TODO(SLEP6): remove if condition for unrouted sample_weight when metadata + # TODO(SLEP6): remove if-condition for unrouted sample_weight when metadata # routing can't be disabled. if not _routing_enabled() and "sample_weight" in fit_params: try: diff --git a/sklearn/ensemble/_stacking.py b/sklearn/ensemble/_stacking.py index a18803d507ffa..9dc93b6c35975 100644 --- a/sklearn/ensemble/_stacking.py +++ b/sklearn/ensemble/_stacking.py @@ -27,8 +27,11 @@ from ..utils._estimator_html_repr import _VisualBlock from ..utils._param_validation import HasMethods, StrOptions from ..utils.metadata_routing import ( - _raise_for_unsupported_routing, - _RoutingNotSupportedMixin, + MetadataRouter, + MethodMapping, + _raise_for_params, + _routing_enabled, + process_routing, ) from ..utils.metaestimators import available_if from ..utils.multiclass import check_classification_targets, type_of_target @@ -36,6 +39,7 @@ from ..utils.validation import ( _check_feature_names_in, _check_response_method, + _deprecate_positional_args, check_is_fitted, column_or_1d, ) @@ -171,7 +175,7 @@ def _method_name(name, estimator, method): # estimators in Stacking*.estimators are not validated yet prefer_skip_nested_validation=False ) - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, **fit_params): """Fit the estimators. Parameters @@ -183,14 +187,13 @@ def fit(self, X, y, sample_weight=None): y : array-like of shape (n_samples,) Target values. - sample_weight : array-like of shape (n_samples,) or default=None - Sample weights. If None, then samples are equally weighted. - Note that this is supported only if all underlying estimators - support sample weights. + **fit_params : dict + Dict of metadata, potentially containing sample_weight as a + key-value pair. If sample_weight is not present, then samples are + equally weighted. Note that sample_weight is supported only if all + underlying estimators support sample weights. - .. versionchanged:: 0.23 - when not None, `sample_weight` is passed to all underlying - estimators + .. versionadded:: 1.6 Returns ------- @@ -201,16 +204,19 @@ def fit(self, X, y, sample_weight=None): names, all_estimators = self._validate_estimators() self._validate_final_estimator() - # FIXME: when adding support for metadata routing in Stacking*. - # This is a hotfix to make StackingClassifier and StackingRegressor - # pass the tests despite not supporting metadata routing but sharing - # the same base class with VotingClassifier and VotingRegressor. - fit_params = dict() - if sample_weight is not None: - fit_params["sample_weight"] = sample_weight - stack_method = [self.stack_method] * len(all_estimators) + if _routing_enabled(): + routed_params = process_routing(self, "fit", **fit_params) + else: + routed_params = Bunch() + for name in names: + routed_params[name] = Bunch(fit={}) + if "sample_weight" in fit_params: + routed_params[name].fit["sample_weight"] = fit_params[ + "sample_weight" + ] + if self.cv == "prefit": self.estimators_ = [] for estimator in all_estimators: @@ -222,8 +228,10 @@ def fit(self, X, y, sample_weight=None): # base estimators will be used in transform, predict, and # predict_proba. They are exposed publicly. self.estimators_ = Parallel(n_jobs=self.n_jobs)( - delayed(_fit_single_estimator)(clone(est), X, y, fit_params) - for est in all_estimators + delayed(_fit_single_estimator)( + clone(est), X, y, routed_params[name]["fit"] + ) + for name, est in zip(names, all_estimators) if est != "drop" ) @@ -269,10 +277,10 @@ def fit(self, X, y, sample_weight=None): cv=deepcopy(cv), method=meth, n_jobs=self.n_jobs, - params=fit_params, + params=routed_params[name]["fit"], verbose=self.verbose, ) - for est, meth in zip(all_estimators, self.stack_method_) + for name, est, meth in zip(names, all_estimators, self.stack_method_) if est != "drop" ) @@ -370,7 +378,7 @@ def predict(self, X, **predict_params): Parameters to the `predict` called by the `final_estimator`. Note that this may be used to return uncertainties from some estimators with `return_std` or `return_cov`. Be aware that it will only - accounts for uncertainty in the final estimator. + account for uncertainty in the final estimator. Returns ------- @@ -392,8 +400,43 @@ def _sk_visual_block_with_final_estimator(self, final_estimator): ) return _VisualBlock("serial", (parallel, final_block), dash_wrapped=False) + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide ` on how the routing + mechanism works. + + .. versionadded:: 1.6 + + Returns + ------- + routing : MetadataRouter + A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = MetadataRouter(owner=self.__class__.__name__) + + # `self.estimators` is a list of (name, est) tuples + for name, estimator in self.estimators: + router.add( + **{name: estimator}, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + + try: + final_estimator_ = self.final_estimator_ + except AttributeError: + final_estimator_ = self.final_estimator + + router.add( + final_estimator_=final_estimator_, + method_mapping=MethodMapping().add(caller="predict", callee="predict"), + ) + + return router + -class StackingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseStacking): +class StackingClassifier(ClassifierMixin, _BaseStacking): """Stack of estimators with a final classifier. Stacked generalization consists in stacking the output of individual @@ -528,7 +571,7 @@ class StackingClassifier(_RoutingNotSupportedMixin, ClassifierMixin, _BaseStacki ----- When `predict_proba` is used by each estimator (i.e. most of the time for `stack_method='auto'` or specifically for `stack_method='predict_proba'`), - The first column predicted by each estimator will be dropped in the case + the first column predicted by each estimator will be dropped in the case of a binary classification problem. Indeed, both feature will be perfectly collinear. @@ -629,7 +672,11 @@ def _validate_estimators(self): return names, estimators - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -649,12 +696,22 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.6 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Returns a fitted instance of estimator. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") check_classification_targets(y) if type_of_target(y) == "multilabel-indicator": self._label_encoder = [LabelEncoder().fit(yk) for yk in y.T] @@ -669,7 +726,10 @@ def fit(self, X, y, sample_weight=None): self._label_encoder = LabelEncoder().fit(y) self.classes_ = self._label_encoder.classes_ y_encoded = self._label_encoder.transform(y) - return super().fit(X, y_encoded, sample_weight) + + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit(X, y_encoded, **fit_params) @available_if(_estimator_has("predict")) def predict(self, X, **predict_params): @@ -685,14 +745,33 @@ def predict(self, X, **predict_params): Parameters to the `predict` called by the `final_estimator`. Note that this may be used to return uncertainties from some estimators with `return_std` or `return_cov`. Be aware that it will only - accounts for uncertainty in the final estimator. + account for uncertainty in the final estimator. + + - If `enable_metadata_routing=False` (default): + Parameters directly passed to the `predict` method of the + `final_estimator`. + + - If `enable_metadata_routing=True`: Parameters safely routed to + the `predict` method of the `final_estimator`. See :ref:`Metadata + Routing User Guide ` for more details. + + .. versionchanged:: 1.6 + `**predict_params` can be routed via metadata routing API. Returns ------- y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) Predicted targets. """ - y_pred = super().predict(X, **predict_params) + if _routing_enabled(): + routed_params = process_routing(self, "predict", **predict_params) + else: + # TODO(SLEP6): remove when metadata routing cannot be disabled. + routed_params = Bunch() + routed_params.final_estimator_ = Bunch(predict={}) + routed_params.final_estimator_.predict = predict_params + + y_pred = super().predict(X, **routed_params.final_estimator_["predict"]) if isinstance(self._label_encoder, list): # Handle the multilabel-indicator case y_pred = np.array( @@ -775,7 +854,7 @@ def _sk_visual_block_(self): return super()._sk_visual_block_with_final_estimator(final_estimator) -class StackingRegressor(_RoutingNotSupportedMixin, RegressorMixin, _BaseStacking): +class StackingRegressor(RegressorMixin, _BaseStacking): """Stack of estimators with a final regressor. Stacked generalization consists in stacking the output of individual @@ -944,7 +1023,11 @@ def _validate_final_estimator(self): ) ) - def fit(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators. Parameters @@ -961,14 +1044,26 @@ def fit(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.6 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- self : object Returns a fitted instance. """ - _raise_for_unsupported_routing(self, "fit", sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") y = column_or_1d(y, warn=True) - return super().fit(X, y, sample_weight) + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit(X, y, **fit_params) def transform(self, X): """Return the predictions for X for each estimator. @@ -986,7 +1081,11 @@ def transform(self, X): """ return self._transform(X) - def fit_transform(self, X, y, sample_weight=None): + # TODO(1.7): remove `sample_weight` from the signature after deprecation + # cycle; pop it from `fit_params` before the `_raise_for_params` check and + # reinsert afterwards, for backwards compatibility + @_deprecate_positional_args(version="1.7") + def fit_transform(self, X, y, *, sample_weight=None, **fit_params): """Fit the estimators and return the predictions for X for each estimator. Parameters @@ -1003,12 +1102,69 @@ def fit_transform(self, X, y, sample_weight=None): Note that this is supported only if all underlying estimators support sample weights. + **fit_params : dict + Parameters to pass to the underlying estimators. + + .. versionadded:: 1.6 + + Only available if `enable_metadata_routing=True`, which can be + set by using ``sklearn.set_config(enable_metadata_routing=True)``. + See :ref:`Metadata Routing User Guide ` for + more details. + Returns ------- y_preds : ndarray of shape (n_samples, n_estimators) Prediction outputs for each estimator. """ - return super().fit_transform(X, y, sample_weight=sample_weight) + _raise_for_params(fit_params, self, "fit") + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + return super().fit_transform(X, y, **fit_params) + + @available_if(_estimator_has("predict")) + def predict(self, X, **predict_params): + """Predict target for X. + + Parameters + ---------- + X : {array-like, sparse matrix} of shape (n_samples, n_features) + Training vectors, where `n_samples` is the number of samples and + `n_features` is the number of features. + + **predict_params : dict of str -> obj + Parameters to the `predict` called by the `final_estimator`. Note + that this may be used to return uncertainties from some estimators + with `return_std` or `return_cov`. Be aware that it will only + account for uncertainty in the final estimator. + + - If `enable_metadata_routing=False` (default): + Parameters directly passed to the `predict` method of the + `final_estimator`. + + - If `enable_metadata_routing=True`: Parameters safely routed to + the `predict` method of the `final_estimator`. See :ref:`Metadata + Routing User Guide ` for more details. + + .. versionchanged:: 1.6 + `**predict_params` can be routed via metadata routing API. + + Returns + ------- + y_pred : ndarray of shape (n_samples,) or (n_samples, n_output) + Predicted targets. + """ + if _routing_enabled(): + routed_params = process_routing(self, "predict", **predict_params) + else: + # TODO(SLEP6): remove when metadata routing cannot be disabled. + routed_params = Bunch() + routed_params.final_estimator_ = Bunch(predict={}) + routed_params.final_estimator_.predict = predict_params + + y_pred = super().predict(X, **routed_params.final_estimator_["predict"]) + + return y_pred def _sk_visual_block_(self): # If final_estimator's default changes then this should be diff --git a/sklearn/ensemble/tests/test_stacking.py b/sklearn/ensemble/tests/test_stacking.py index 300b011f661d4..1c038cd469216 100644 --- a/sklearn/ensemble/tests/test_stacking.py +++ b/sklearn/ensemble/tests/test_stacking.py @@ -3,6 +3,7 @@ # Authors: Guillaume Lemaitre # License: BSD 3 clause +import re from unittest.mock import Mock import numpy as np @@ -38,6 +39,12 @@ from sklearn.neural_network import MLPClassifier from sklearn.preprocessing import scale from sklearn.svm import SVC, LinearSVC, LinearSVR +from sklearn.tests.metadata_routing_common import ( + ConsumingClassifier, + ConsumingRegressor, + _Registry, + check_recorded_metadata, +) from sklearn.utils._mocking import CheckingClassifier from sklearn.utils._testing import ( assert_allclose, @@ -888,3 +895,116 @@ def test_stacking_final_estimator_attribute_error(): clf.fit(X, y).decision_function(X) assert isinstance(exec_info.value.__cause__, AttributeError) assert inner_msg in str(exec_info.value.__cause__) + + +# Metadata Routing Tests +# ====================== + + +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_routing_passed_metadata_not_supported(Estimator, Child): + """Test that the right error message is raised when metadata is passed while + not supported when `enable_metadata_routing=False`.""" + + with pytest.raises( + ValueError, match="is only supported if enable_metadata_routing=True" + ): + Estimator(["clf", Child()]).fit( + X_iris, y_iris, sample_weight=[1, 1, 1, 1, 1], metadata="a" + ) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_get_metadata_routing_without_fit(Estimator, Child): + # Test that metadata_routing() doesn't raise when called before fit. + est = Estimator([("sub_est", Child())]) + est.get_metadata_routing() + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +@pytest.mark.parametrize( + "prop, prop_value", [("sample_weight", np.ones(X_iris.shape[0])), ("metadata", "a")] +) +def test_metadata_routing_for_stacking_estimators(Estimator, Child, prop, prop_value): + """Test that metadata is routed correctly for Stacking*.""" + + est = Estimator( + [ + ( + "sub_est1", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ( + "sub_est2", + Child(registry=_Registry()).set_fit_request(**{prop: True}), + ), + ], + final_estimator=Child(registry=_Registry()).set_predict_request(**{prop: True}), + ) + + est.fit(X_iris, y_iris, **{prop: prop_value}) + est.fit_transform(X_iris, y_iris, **{prop: prop_value}) + + est.predict(X_iris, **{prop: prop_value}) + + for estimator in est.estimators: + # access sub-estimator in (name, est) with estimator[1]: + registry = estimator[1].registry + assert len(registry) + for sub_est in registry: + check_recorded_metadata( + obj=sub_est, method="fit", split_params=(prop), **{prop: prop_value} + ) + # access final_estimator: + registry = est.final_estimator_.registry + assert len(registry) + check_recorded_metadata( + obj=registry[-1], method="predict", split_params=(prop), **{prop: prop_value} + ) + + +@pytest.mark.usefixtures("enable_slep006") +@pytest.mark.parametrize( + "Estimator, Child", + [ + (StackingClassifier, ConsumingClassifier), + (StackingRegressor, ConsumingRegressor), + ], +) +def test_metadata_routing_error_for_stacking_estimators(Estimator, Child): + """Test that the right error is raised when metadata is not requested.""" + sample_weight, metadata = np.ones(X_iris.shape[0]), "a" + + est = Estimator([("sub_est", Child())]) + + error_message = ( + "[sample_weight, metadata] are passed but are not explicitly set as requested" + f" or not requested for {Child.__name__}.fit" + ) + + with pytest.raises(ValueError, match=re.escape(error_message)): + est.fit(X_iris, y_iris, sample_weight=sample_weight, metadata=metadata) + + +# End of Metadata Routing Tests +# ============================= diff --git a/sklearn/tests/metadata_routing_common.py b/sklearn/tests/metadata_routing_common.py index 889524bc05ddb..5091569e434a3 100644 --- a/sklearn/tests/metadata_routing_common.py +++ b/sklearn/tests/metadata_routing_common.py @@ -257,16 +257,13 @@ def predict(self, X, sample_weight="default", metadata="default"): record_metadata_not_default( self, "predict", sample_weight=sample_weight, metadata=metadata ) - return np.zeros(shape=(len(X),)) + return np.zeros(shape=(len(X),), dtype="int8") def predict_proba(self, X, sample_weight="default", metadata="default"): - pass # pragma: no cover - - # uncomment when needed - # record_metadata_not_default( - # self, "predict_proba", sample_weight=sample_weight, metadata=metadata - # ) - # return np.asarray([[0.0, 1.0]] * len(X)) + record_metadata_not_default( + self, "predict_proba", sample_weight=sample_weight, metadata=metadata + ) + return np.asarray([[0.0, 1.0]] * len(X)) def predict_log_proba(self, X, sample_weight="default", metadata="default"): pass # pragma: no cover diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index aa6af5bd09aac..38168f3f0261f 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -14,8 +14,6 @@ AdaBoostRegressor, BaggingClassifier, BaggingRegressor, - StackingClassifier, - StackingRegressor, ) from sklearn.exceptions import UnsetMetadataPassedError from sklearn.experimental import ( @@ -408,8 +406,6 @@ def enable_slep006(): RFECV(ConsumingClassifier()), SelfTrainingClassifier(ConsumingClassifier()), SequentialFeatureSelector(ConsumingClassifier()), - StackingClassifier(ConsumingClassifier()), - StackingRegressor(ConsumingRegressor()), TransformedTargetRegressor(), ]