8000 FEA Add metadata routing for TransformedTargetRegressor by OmarManzoor · Pull Request #29136 · scikit-learn/scikit-learn · GitHub
[go: up one dir, main page]

Skip to content

FEA Add metadata routing for TransformedTargetRegressor #29136

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jun 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion doc/metadata_routing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -276,6 +276,7 @@ Meta-estimators and functions supporting metadata routing:

- :class:`sklearn.calibration.CalibratedClassifierCV`
- :class:`sklearn.compose.ColumnTransformer`
- :class:`sklearn.compose.TransformedTargetRegressor`
- :class:`sklearn.covariance.GraphicalLassoCV`
- :class:`sklearn.ensemble.StackingClassifier`
- :class:`sklearn.ensemble.StackingRegressor`
Expand Down Expand Up @@ -316,7 +317,6 @@ Meta-estimators and functions supporting metadata routing:

Meta-estimators and tools not supporting metadata routing yet:

- :class:`sklearn.compose.TransformedTargetRegressor`
- :class:`sklearn.ensemble.AdaBoostClassifier`
- :class:`sklearn.ensemble.AdaBoostRegressor`
- :class:`sklearn.feature_selection.RFE`
Expand Down
5 changes: 5 additions & 0 deletions doc/whats_new/v1.6.rst
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,11 @@ more details.
``**fit_params`` to the underlying estimators via their `fit` methods.
:pr:`28701` by :user:`Stefanie Senger <StefanieSenger>`.

- |Feature| :class:`compose.TransformedTargetRegressor` now supports metadata
routing in its `fit` and `predict` methods and routes the corresponding
params to the underlying regressor.
:pr:`29136` by :user:`Omar Salman <OmarManzoor>`.

Dropping official support for PyPy
----------------------------------

Expand Down
93 changes: 74 additions & 19 deletions sklearn/compose/_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,18 @@

from ..base import BaseEstimator, RegressorMixin, _fit_context, clone
from ..exceptions import NotFittedError
from ..linear_model import LinearRegression
from ..preprocessing import FunctionTransformer
from ..utils import _safe_indexing, check_array
from ..utils import Bunch, _safe_indexing, check_array
from ..utils._metadata_requests import (
MetadataRouter,
MethodMapping,
_routing_enabled,
process_routing,
)
from ..utils._param_validation import HasMethods
from ..utils._tags import _safe_tags
from ..utils.metadata_routing import (
_raise_for_unsupported_routing,
_RoutingNotSupportedMixin,
)
from ..utils.validation import check_is_fitted
Expand Down Expand Up @@ -230,15 +236,25 @@ def fit(self, X, y, **fit_params):
Target values.

**fit_params : dict
Parameters passed to the `fit` method of the underlying
regressor.
- If `enable_metadata_routing=False` (default):

Parameters directly passed to the `fit` method of the
underlying regressor.

- If `enable_metadata_routing=True`:

Parameters safely routed to the `fit` method of the
underlying regressor.

.. versionchanged:: 1.6
See :ref:`Metadata Routing User Guide <metadata_routing>` for
more details.

Returns
-------
self : object
Fitted estimator.
"""
_raise_for_unsupported_routing(self, "fit", **fit_params)
if y is None:
raise ValueError(
f"This {self.__class__.__name__} estimator "
Expand Down Expand Up @@ -274,14 +290,13 @@ def fit(self, X, y, **fit_params):
if y_trans.ndim == 2 and y_trans.shape[1] == 1:
y_trans = y_trans.squeeze(axis=1)

if self.regressor is None:
from ..linear_model import LinearRegression
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. I would have think that we have a circular import but apparently this is not the case anymore.


self.regressor_ = LinearRegression()
self.regressor_ = self._get_regressor(get_clone=True)
if _routing_enabled():
routed_params = process_routing(self, "fit", **fit_params)
else:
self.regressor_ = clone(self.regressor)
routed_params = Bunch(regressor=Bunch(fit=fit_params))

self.regressor_.fit(X, y_trans, **fit_params)
self.regressor_.fit(X, y_trans, **routed_params.regressor.fit)

if hasattr(self.regressor_, "feature_names_in_"):
self.feature_names_in_ = self.regressor_.feature_names_in_
Expand All @@ -300,16 +315,32 @@ def predict(self, X, **predict_params):
Samples.

**predict_params : dict of str -> object
Parameters passed to the `predict` method of the underlying
regressor.
- If `enable_metadata_routing=False` (default):

Parameters directly passed to the `predict` method of the
underlying regressor.

- If `enable_metadata_routing=True`:

Parameters safely routed to the `predict` method of the
underlying regressor.

.. versionchanged:: 1.6
See :ref:`Metadata Routing User Guide <metadata_routing>`
for more details.

Returns
-------
y_hat : ndarray of shape (n_samples,)
Predicted values.
"""
check_is_fitted(self)
pred = self.regressor_.predict(X, **predict_params)
if _routing_enabled():
routed_params = process_routing(self, "predict", **predict_params)
else:
routed_params = Bunch(regressor=Bunch(predict=predict_params))

pred = self.regressor_.predict(X, **routed_params.regressor.predict)
if pred.ndim == 1:
pred_trans = self.transformer_.inverse_transform(pred.reshape(-1, 1))
else:
Expand All @@ -324,11 +355,7 @@ def predict(self, X, **predict_params):
return pred_trans

def _more_tags(self):
regressor = self.regressor
if regressor is None:
from ..linear_model import LinearRegression

regressor = LinearRegression()
regressor = self._get_regressor()

return {
"poor_score": True,
Expand All @@ -350,3 +377,31 @@ def n_features_in_(self):
) from nfe

return self.regressor_.n_features_in_

def get_metadata_routing(self):
"""Get metadata routing of this object.

Please check :ref:`User Guide <metadata_routing>` on how the routing
mechanism works.

.. versionadded:: 1.6

Returns
-------
routing : MetadataRouter
A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
routing information.
"""
router = MetadataRouter(owner=self.__class__.__name__).add(
regressor=self._get_regressor(),
method_mapping=MethodMapping()
.add(caller="fit", callee="fit")
.add(caller="predict", callee="predict"),
)
return router

def _get_regressor(self, get_clone=False):
if self.regressor is None:
return LinearRegression()

return clone(self.regressor) if get_clone else self.regressor
9 changes: 8 additions & 1 deletion sklearn/tests/test_metaestimators_metadata_routing.py
< 717B td id="diff-fb0682201360bf8493836d8fb3977b58e035dd084e07dec9a8522f64df94bb38R393" data-line-number="393" class="blob-num blob-num-context js-linkable-line-number js-blob-rnum">
Original file line number Diff line number Diff line change
Expand Up @@ -382,6 +382,14 @@ def enable_slep006():
"cv_name": "cv",
"cv_routing_methods": ["fit"],
},
{
"metaestimator": TransformedTargetRegressor,
"estimator": "regressor",
"estimator_name": "regressor",
"X": X,
"y": y,
"estimator_routing_methods": ["fit", "predict"],
},
]
"""List containing all metaestimators to be tested and their settings

Expand Down Expand Up @@ -427,7 +435,6 @@ def enable_slep006():
RFECV(ConsumingClassifier()),
SelfTrainingClassifier(ConsumingClassifier()),
SequentialFeatureSelector(ConsumingClassifier()),
TransformedTargetRegressor(),
]


Expand Down
Loading
0