10000 SLEP006: ClassifierChain and RegressorChain routing (#24443) · scikit-learn/scikit-learn@1a4f9bd · GitHub
[go: up one dir, main page]

Skip to content

Commit 1a4f9bd

Browse files
authored
SLEP006: ClassifierChain and RegressorChain routing (#24443)
1 parent 2e62f3b commit 1a4f9bd

File tree

3 files changed

+91
-5
lines changed

3 files changed

+91
-5
lines changed

sklearn/multioutput.py

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,8 @@ def fit(self, X, Y, **fit_params):
660660

661661
del Y_pred_chain
662662

663+
routed_params = process_routing(obj=self, method="fit", other_params=fit_params)
664+
663665
for chain_idx, estimator in enumerate(self.estimators_):
664666
message = self._log_message(
665667
estimator_idx=chain_idx + 1,
@@ -668,7 +670,12 @@ def fit(self, X, Y, **fit_params):
668670
)
669671
y = Y[:, self.order_[chain_idx]]
670672
with _print_elapsed_time("Chain", message):
671-
estimator.fit(X_aug[:, : (X.shape[1] + chain_idx)], y, **fit_params)
673+
estimator.fit(
674+
X_aug[:, : (X.shape[1] + chain_idx)],
675+
y,
676+
**routed_params.estimator.fit,
677+
)
678+
672679
if self.cv is not None and chain_idx < len(self.estimators_) - 1:
673680
col_idx = X.shape[1] + chain_idx
674681
cv_result = cross_val_predict(
@@ -831,7 +838,7 @@ class labels for each estimator in the chain.
831838
[0.0321..., 0.9935..., 0.0625...]])
832839
"""
833840

834-
def fit(self, X, Y):
841+
def fit(self, X, Y, **fit_params):
835842
"""Fit the model to data matrix X and targets Y.
836843
837844
Parameters
@@ -842,14 +849,19 @@ def fit(self, X, Y):
842849
Y : array-like of shape (n_samples, n_classes)
843850
The target values.
844851
852+
**fit_params : dict of string -> object
853+
Parameters passed to the `fit` method of each step.
854+
855+
.. versionadded:: 1.2
856+
845857
Returns
846858
-------
847859
self : object
848860
Class instance.
849861
"""
850862
self._validate_params()
851863

852-
super().fit(X, Y)
864+
super().fit(X, Y, **fit_params)
853865
self.classes_ = [
854866
estimator.classes_ for chain_idx, estimator in enumerate(self.estimators_)
855867
]
@@ -919,6 +931,24 @@ def decision_function(self, X):
919931

920932
return Y_decision
921933

934+
def get_metadata_routing(self):
935+
"""Get metadata routing of this object.
936+
937+
Please check :ref:`User Guide <metadata_routing>` on how the routing
938+
mechanism works.
939+
940+
Returns
941+
-------
942+
routing : MetadataRouter
943+
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
944+
routing information.
945+
"""
946+
router = MetadataRouter(owner=self.__class__.__name__).add(
947+
estimator=self.base_estimator,
948+
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
949+
)
950+
return router
951+
922952
def _more_tags(self):
923953
return {"_skip_test": True, "multioutput_only": True}
924954

@@ -1046,5 +1076,27 @@ def fit(self, X, Y, **fit_params):
10461076
super().fit(X, Y, **fit_params)
10471077
return self
10481078

1079+
def get_metadata_routing(self):
1080+
"""Get metadata routing of this object.
1081+
1082+
Please check :ref:`User Guide <metadata_routing>` on how the routing
1083+
mechanism works.
1084+
1085+
Returns
1086+
-------
1087+
routing : MetadataRouter
1088+
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
1089+
routing information.
1090+
"""
1091+
router = (
1092+
MetadataRouter(owner=self.__class__.__name__)
1093+
.add(
1094+
estimator=self.base_estimator,
1095+
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
1096+
)
1097+
.warn_on(child="estimator", method="fit", params=None)
1098+
)
1099+
return router
1100+
10491101
def _more_tags(self):
10501102
return {"multioutput_only": True}

sklearn/tests/test_metaestimators_metadata_routing.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,12 @@
77
from sklearn.base import RegressorMixin, ClassifierMixin, BaseEstimator
88
from sklearn.calibration import CalibratedClassifierCV
99
from sklearn.exceptions import UnsetMetadataPassedError
10-
from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
10+
from sklearn.multioutput import (
11+
MultiOutputRegressor,
12+
MultiOutputClassifier,
13+
ClassifierChain,
14+
RegressorChain,
15+
)
1116
from sklearn.utils.metadata_routing import MetadataRouter
1217
from sklearn.tests.test_metadata_routing import (
1318
record_metadata,
@@ -181,6 +186,24 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"):
181186
"warns_on": {"fit": ["sample_weight", "metadata"]},
182187
"preserves_metadata": False,
183188
},
189+
{
190+
"metaestimator": ClassifierChain,
191+
"estimator_name": "base_estimator",
192+
"estimator": ConsumingClassifier,
193+
"X": X,
194+
"y": y_multi,
195+
"routing_methods": ["fit"],
196+
"warns_on": {},
197+
},
198+
{
199+
"metaestimator": RegressorChain,
200+
"estimator_name": "base_estimator",
201+
"estimator": ConsumingRegressor,
202+
"X": X,
203+
"y": y_multi,
204+
"routing_methods": ["fit"],
205+
"warns_on": {"fit": ["sample_weight", "metadata"]},
206+
},
184207
]
185208
"""List containing all metaestimators to be tested and their settings
186209

sklearn/tests/test_multioutput.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
Ridge,
2929
SGDClassifier,
3030
SGDRegressor,
31+
QuantileRegressor,
3132
)
323 D634 3
from sklearn.metrics import jaccard_score, mean_squared_error
3334
from sklearn.model_selection import GridSearchCV, train_test_split
@@ -646,7 +647,7 @@ def fit(self, X, y, **fit_params):
646647
self.sample_weight_ = fit_params["sample_weight"]
647648
super().fit(X, y, **fit_params)
648649

649-
model = RegressorChain(MySGD())
650+
model = RegressorChain(MySGD().set_fit_request(sample_weight=True))
650651

651652
# Fitting with params
652653
fit_param = {"sample_weight": weight}
@@ -655,6 +656,16 @@ def fit(self, X, y, **fit_params):
655656
for est in model.estimators_:
656657
assert est.sample_weight_ is weight
657658

659+
# TODO(1.4): Remove check for FutureWarning
660+
# Test that the existing behavior works and raises a FutureWarning
661+
# when the underlying estimator used has a sample_weight parameter
662+
# defined in it's fit method.
663+
model = RegressorChain(QuantileRegressor())
664+
fit_param = {"sample_weight": weight}
665+
666+
with pytest.warns(FutureWarning):
667+
model.fit(X, y, **fit_param)
668+
658669

659670
@pytest.mark.parametrize(
660671
"MultiOutputEstimator, Estimator",

0 commit comments

Comments
 (0)
0