diff --git a/sklearn/multioutput.py b/sklearn/multioutput.py index fbacdc79aa140..fe846492bddda 100644 --- a/sklearn/multioutput.py +++ b/sklearn/multioutput.py @@ -660,6 +660,8 @@ def fit(self, X, Y, **fit_params): del Y_pred_chain + routed_params = process_routing(obj=self, method="fit", other_params=fit_params) + for chain_idx, estimator in enumerate(self.estimators_): message = self._log_message( estimator_idx=chain_idx + 1, @@ -668,7 +670,12 @@ def fit(self, X, Y, **fit_params): ) y = Y[:, self.order_[chain_idx]] with _print_elapsed_time("Chain", message): - estimator.fit(X_aug[:, : (X.shape[1] + chain_idx)], y, **fit_params) + estimator.fit( + X_aug[:, : (X.shape[1] + chain_idx)], + y, + **routed_params.estimator.fit, + ) + if self.cv is not None and chain_idx < len(self.estimators_) - 1: col_idx = X.shape[1] + chain_idx cv_result = cross_val_predict( @@ -831,7 +838,7 @@ class labels for each estimator in the chain. [0.0321..., 0.9935..., 0.0625...]]) """ - def fit(self, X, Y): + def fit(self, X, Y, **fit_params): """Fit the model to data matrix X and targets Y. Parameters @@ -842,6 +849,11 @@ def fit(self, X, Y): Y : array-like of shape (n_samples, n_classes) The target values. + **fit_params : dict of string -> object + Parameters passed to the `fit` method of each step. + + .. versionadded:: 1.2 + Returns ------- self : object @@ -849,7 +861,7 @@ def fit(self, X, Y): """ self._validate_params() - super().fit(X, Y) + super().fit(X, Y, **fit_params) self.classes_ = [ estimator.classes_ for chain_idx, estimator in enumerate(self.estimators_) ] @@ -919,6 +931,24 @@ def decision_function(self, X): return Y_decision + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide <metadata_routing>` on how the routing + mechanism works. + + Returns + ------- + routing : MetadataRouter + A :class:`~utils.metadata_routing.MetadataRouter` encapsulating + routing information.
+ """ + router = MetadataRouter(owner=self.__class__.__name__).add( + estimator=self.base_estimator, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + return router + def _more_tags(self): return {"_skip_test": True, "multioutput_only": True} @@ -1046,5 +1076,27 @@ def fit(self, X, Y, **fit_params): super().fit(X, Y, **fit_params) return self + def get_metadata_routing(self): + """Get metadata routing of this object. + + Please check :ref:`User Guide <metadata_routing>` on how the routing + mechanism works. + + Returns + ------- + routing : MetadataRouter + A :class:`~utils.metadata_routing.MetadataRouter` encapsulating + routing information. + """ + router = ( + MetadataRouter(owner=self.__class__.__name__) + .add( + estimator=self.base_estimator, + method_mapping=MethodMapping().add(callee="fit", caller="fit"), + ) + .warn_on(child="estimator", method="fit", params=None) + ) + return router + def _more_tags(self): return {"multioutput_only": True} diff --git a/sklearn/tests/test_metaestimators_metadata_routing.py b/sklearn/tests/test_metaestimators_metadata_routing.py index 603e3ae1e8d99..9cc21135cf225 100644 --- a/sklearn/tests/test_metaestimators_metadata_routing.py +++ b/sklearn/tests/test_metaestimators_metadata_routing.py @@ -7,7 +7,12 @@ from sklearn.base import RegressorMixin, ClassifierMixin, BaseEstimator from sklearn.calibration import CalibratedClassifierCV from sklearn.exceptions import UnsetMetadataPassedError -from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier +from sklearn.multioutput import ( + MultiOutputRegressor, + MultiOutputClassifier, + ClassifierChain, + RegressorChain, +) from sklearn.utils.metadata_routing import MetadataRouter from sklearn.tests.test_metadata_routing import ( record_metadata, @@ -181,6 +186,24 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"): "warns_on": {"fit": ["sample_weight", "metadata"]}, "preserves_metadata": False, }, + { + "metaestimator":
ClassifierChain, + "estimator_name": "base_estimator", + "estimator": ConsumingClassifier, + "X": X, + "y": y_multi, + "routing_methods": ["fit"], + "warns_on": {}, + }, + { + "metaestimator": RegressorChain, + "estimator_name": "base_estimator", + "estimator": ConsumingRegressor, + "X": X, + "y": y_multi, + "routing_methods": ["fit"], + "warns_on": {"fit": ["sample_weight", "metadata"]}, + }, ] """List containing all metaestimators to be tested and their settings diff --git a/sklearn/tests/test_multioutput.py b/sklearn/tests/test_multioutput.py index f89734833d019..416d82a1b47e9 100644 --- a/sklearn/tests/test_multioutput.py +++ b/sklearn/tests/test_multioutput.py @@ -28,6 +28,7 @@ Ridge, SGDClassifier, SGDRegressor, + QuantileRegressor, ) from sklearn.metrics import jaccard_score, mean_squared_error from sklearn.model_selection import GridSearchCV, train_test_split @@ -646,7 +647,7 @@ def fit(self, X, y, **fit_params): self.sample_weight_ = fit_params["sample_weight"] super().fit(X, y, **fit_params) - model = RegressorChain(MySGD()) + model = RegressorChain(MySGD().set_fit_request(sample_weight=True)) # Fitting with params fit_param = {"sample_weight": weight} @@ -655,6 +656,16 @@ def fit(self, X, y, **fit_params): for est in model.estimators_: assert est.sample_weight_ is weight + # TODO(1.4): Remove check for FutureWarning + # Test that the existing behavior works and raises a FutureWarning + # when the underlying estimator used has a sample_weight parameter + # defined in its fit method. + model = RegressorChain(QuantileRegressor()) + fit_param = {"sample_weight": weight} + + with pytest.warns(FutureWarning): + model.fit(X, y, **fit_param) + @pytest.mark.parametrize( "MultiOutputEstimator, Estimator",