SLEP006: ClassifierChain and RegressorChain routing by OmarManzoor · Pull Request #24443 · scikit-learn/scikit-learn · GitHub

Merged
58 changes: 55 additions & 3 deletions sklearn/multioutput.py
@@ -660,6 +660,8 @@ def fit(self, X, Y, **fit_params):

del Y_pred_chain

routed_params = process_routing(obj=self, method="fit", other_params=fit_params)

for chain_idx, estimator in enumerate(self.estimators_):
message = self._log_message(
estimator_idx=chain_idx + 1,
@@ -668,7 +670,12 @@
)
y = Y[:, self.order_[chain_idx]]
with _print_elapsed_time("Chain", message):
estimator.fit(X_aug[:, : (X.shape[1] + chain_idx)], y, **fit_params)
estimator.fit(
X_aug[:, : (X.shape[1] + chain_idx)],
y,
**routed_params.estimator.fit,
)

if self.cv is not None and chain_idx < len(self.estimators_) - 1:
col_idx = X.shape[1] + chain_idx
cv_result = cross_val_predict(
@@ -831,7 +838,7 @@ class labels for each estimator in the chain.
[0.0321..., 0.9935..., 0.0625...]])
"""

def fit(self, X, Y):
def fit(self, X, Y, **fit_params):
"""Fit the model to data matrix X and targets Y.

Parameters
Expand All @@ -842,14 +849,19 @@ def fit(self, X, Y):
Y : array-like of shape (n_samples, n_classes)
The target values.

**fit_params : dict of string -> object
Parameters passed to the `fit` method of each step.

.. versionadded:: 1.2

Returns
-------
self : object
Class instance.
"""
self._validate_params()

super().fit(X, Y)
super().fit(X, Y, **fit_params)
self.classes_ = [
estimator.classes_ for chain_idx, estimator in enumerate(self.estimators_)
]
@@ -919,6 +931,24 @@ def decision_function(self, X):

return Y_decision

def get_metadata_routing(self):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Returns
-------
routing : MetadataRouter
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = MetadataRouter(owner=self.__class__.__name__).add(
estimator=self.base_estimator,
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
)
return router

def _more_tags(self):
return {"_skip_test": True, "multioutput_only": True}

@@ -1046,5 +1076,27 @@ def fit(self, X, Y, **fit_params):
super().fit(X, Y, **fit_params)
return self

def get_metadata_routing(self):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

Returns
-------
routing : MetadataRouter
A :class:`~utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = (
MetadataRouter(owner=self.__class__.__name__)
.add(
estimator=self.base_estimator,
method_mapping=MethodMapping().add(callee="fit", caller="fit"),
)
.warn_on(child="estimator", method="fit", params=None)
)
return router

def _more_tags(self):
return {"multioutput_only": True}
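
For readers following the diff above, here is a minimal usage sketch of what the new routing enables. It is not part of the PR: the `set_config(enable_metadata_routing=True)` opt-in and the choice of `LogisticRegression` with random data are assumptions for illustration; the key point, taken from the diff, is that the base estimator must explicitly request `sample_weight` for `ClassifierChain.fit` to route it.

```python
import numpy as np

from sklearn import set_config
from sklearn.linear_model import LogisticRegression
from sklearn.multioutput import ClassifierChain

set_config(enable_metadata_routing=True)  # opt in to SLEP006 metadata routing

rng = np.random.RandomState(0)
X = rng.rand(30, 4)
Y = rng.randint(0, 2, size=(30, 3))
sample_weight = rng.rand(30)

# The base estimator requests sample_weight; ClassifierChain.fit then calls
# process_routing and forwards the weights to each per-output clone it fits.
base = LogisticRegression().set_fit_request(sample_weight=True)
chain = ClassifierChain(base).fit(X, Y, sample_weight=sample_weight)
```

`RegressorChain` works the same way; without an explicit request, the `warn_on` entry in its router keeps the old implicit pass-through working but emits a `FutureWarning`, which the updated `test_multioutput.py` below exercises.
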
25 changes: 24 additions & 1 deletion sklearn/tests/test_metaestimators_metadata_routing.py
@@ -7,7 +7,12 @@
from sklearn.base import RegressorMixin, ClassifierMixin, BaseEstimator
from sklearn.calibration import CalibratedClassifierCV
from sklearn.exceptions import UnsetMetadataPassedError
from sklearn.multioutput import MultiOutputRegressor, MultiOutputClassifier
from sklearn.multioutput import (
MultiOutputRegressor,
MultiOutputClassifier,
ClassifierChain,
RegressorChain,
)
from sklearn.utils.metadata_routing import MetadataRouter
from sklearn.tests.test_metadata_routing import (
record_metadata,
@@ -181,6 +186,24 @@ def predict_log_proba(self, X, sample_weight="default", metadata="default"):
"warns_on": {"fit": ["sample_weight", "metadata"]},
"preserves_metadata": False,
},
{
"metaestimator": ClassifierChain,
"estimator_name": "base_estimator",
"estimator": ConsumingClassifier,
"X": X,
"y": y_multi,
"routing_methods": ["fit"],
"warns_on": {},
},
{
"metaestimator": RegressorChain,
"estimator_name": "base_estimator",
"estimator": ConsumingRegressor,
"X": X,
"y": y_multi,
"routing_methods": ["fit"],
"warns_on": {"fit": ["sample_weight", "metadata"]},
},
]
"""List containing all metaestimators to be tested and their settings

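To show how a registry entry like the ones added above can drive a test, here is a hedged sketch of a parametrized check. It is not the PR's actual common test: the registry name `METAESTIMATORS` and the test body are assumptions for illustration, while `set_fit_request` and the consuming estimators that record the metadata passed to `fit` come from the surrounding test module.

```python
import numpy as np
import pytest


@pytest.mark.parametrize("case", METAESTIMATORS)  # the registry defined above
def test_fit_accepts_requested_metadata(case):
    # Once the sub-estimator explicitly requests both metadata keys, passing
    # them to the metaestimator's fit() should route through cleanly instead
    # of raising UnsetMetadataPassedError.
    sub_estimator = case["estimator"]().set_fit_request(
        sample_weight=True, metadata=True
    )
    metaestimator = case["metaestimator"](**{case["estimator_name"]: sub_estimator})
    sample_weight = np.ones(len(case["X"]))
    metaestimator.fit(
        case["X"], case["y"], sample_weight=sample_weight, metadata="my_metadata"
    )
```
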
13 changes: 12 additions & 1 deletion sklearn/tests/test_multioutput.py
@@ -28,6 +28,7 @@
Ridge,
SGDClassifier,
SGDRegressor,
QuantileRegressor,
)
from sklearn.metrics import jaccard_score, mean_squared_error
from sklearn.model_selection import GridSearchCV, train_test_split
@@ -646,7 +647,7 @@ def fit(self, X, y, **fit_params):
self.sample_weight_ = fit_params["sample_weight"]
super().fit(X, y, **fit_params)

model = RegressorChain(MySGD())
model = RegressorChain(MySGD().set_fit_request(sample_weight=True))

Review comment (Member):
We need a test that checks that the existing behaviour still works but issues a DeprecationWarning.

Reply (Member):
I think the common tests are supposed to do that.


# Fitting with params
fit_param = {"sample_weight": weight}
@@ -655,6 +656,16 @@
for est in model.estimators_:
assert est.sample_weight_ is weight

# TODO(1.4): Remove check for FutureWarning
# Test that the existing behavior works and raises a FutureWarning

Review comment (Member): To make it easier to search for when we need to remove the warning:

Suggested change:
- # Test that the existing behavior works and raises a FutureWarning
+ # TODO(1.4): Remove check for FutureWarning
+ # Test that the existing behavior works and raises a FutureWarning

# when the underlying estimator used has a sample_weight parameter
# defined in its fit method.
model = RegressorChain(QuantileRegressor())
fit_param = {"sample_weight": weight}

with pytest.warns(FutureWarning):
model.fit(X, y, **fit_param)


@pytest.mark.parametrize(
"MultiOutputEstimator, Estimator",