diff --git a/doc/metadata_routing.rst b/doc/metadata_routing.rst
new file mode 100644
index 0000000000000..d1cd385b53a4a
--- /dev/null
+++ b/doc/metadata_routing.rst
@@ -0,0 +1,187 @@
+
+.. _metadata_routing:
+
+Metadata Routing
+================
+
+This guide demonstrates how metadata such as ``sample_weight`` can be routed
+and passed along to estimators, scorers, and CV splitters through
+meta-estimators such as ``Pipeline`` and ``GridSearchCV``. In order to pass
+metadata to a method such as ``fit`` or ``score``, the object accepting the
+metadata must *request* it. For estimators and splitters this is done via
+``*_requests`` methods, e.g. ``fit_requests(...)``, and for scorers it is done
+via the scorer's ``score_requests`` method. For grouped splitters such as
+``GroupKFold``, a ``groups`` parameter is requested by default. This is best
+demonstrated by the following examples.
+
+Usage Examples
+**************
+Here we present a few examples to show different common use cases. The
+examples in this section require the following imports and data::
+
+    >>> import numpy as np
+    >>> from sklearn.metrics import make_scorer, accuracy_score
+    >>> from sklearn.linear_model import LogisticRegressionCV
+    >>> from sklearn.linear_model import LogisticRegression
+    >>> from sklearn.model_selection import cross_validate
+    >>> from sklearn.model_selection import GridSearchCV
+    >>> from sklearn.model_selection import GroupKFold
+    >>> from sklearn.feature_selection import SelectKBest
+    >>> from sklearn.pipeline import make_pipeline
+    >>> n_samples, n_features = 100, 4
+    >>> X = np.random.rand(n_samples, n_features)
+    >>> y = np.random.randint(0, 2, size=n_samples)
+    >>> my_groups = np.random.randint(0, 10, size=n_samples)
+    >>> my_weights = np.random.rand(n_samples)
+    >>> my_other_weights = np.random.rand(n_samples)
+
+Weighted scoring and fitting
+----------------------------
+
+Here ``GroupKFold`` requests ``groups`` by default. However, we need to
+explicitly request weights for ``make_scorer`` and for ``LogisticRegressionCV``.
+Both of these *consumers* understand the meaning of the key
+``"sample_weight"``::
+
+    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
+    ...     sample_weight=True
+    ... )
+    >>> lr = LogisticRegressionCV(
+    ...     cv=GroupKFold(), scoring=weighted_acc,
+    ... ).fit_requests(sample_weight=True)
+    >>> cv_results = cross_validate(
+    ...     lr,
+    ...     X,
+    ...     y,
+    ...     cv=GroupKFold(),
+    ...     props={"sample_weight": my_weights, "groups": my_groups},
+    ...     scoring=weighted_acc,
+    ... )
+
+Error handling: if ``props={'sample_weigh': my_weights, ...}`` were passed
+(note the typo), ``cross_validate`` would raise an error, since
+``'sample_weigh'`` was not requested by any of its children.
+
+Weighted scoring and unweighted fitting
+---------------------------------------
+
+Since ``LogisticRegressionCV``, like all scikit-learn estimators, requires
+weights to be requested explicitly, we need to state explicitly that
+``sample_weight`` is not used when fitting it, so that ``cross_validate``
+doesn't pass it along::
+
+    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
+    ...     sample_weight=True
+    ... )
+    >>> lr = LogisticRegressionCV(
+    ...     cv=GroupKFold(), scoring=weighted_acc,
+    ... ).fit_requests(sample_weight=False)
+    >>> cv_results = cross_validate(
+    ...     lr,
+    ...     X,
+    ...     y,
+    ...     cv=GroupKFold(),
+    ...     props={"sample_weight": my_weights, "groups": my_groups},
+    ...     scoring=weighted_acc,
+    ... )
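The boolean flags used above are shorthand for the ``RequestType`` values
described in the API Interface section below; as a rough sketch, the same
configuration could also be spelled out explicitly (assuming ``RequestType``
is importable from ``sklearn.utils.metadata_requests``, as the new tests in
this change do)::

    >>> from sklearn.utils.metadata_requests import RequestType
    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
    ...     sample_weight=RequestType.REQUESTED
    ... )
    >>> lr = LogisticRegressionCV(
    ...     cv=GroupKFold(), scoring=weighted_acc,
    ... ).fit_requests(sample_weight=RequestType.UNREQUESTED)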
+
+Unweighted feature selection
+----------------------------
+
+Unlike ``LogisticRegressionCV``, ``SelectKBest`` doesn't accept weights and
+therefore ``"sample_weight"`` is not routed to it::
+
+    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
+    ...     sample_weight=True
+    ... )
+    >>> lr = LogisticRegressionCV(
+    ...     cv=GroupKFold(), scoring=weighted_acc,
+    ... ).fit_requests(sample_weight=True)
+    >>> sel = SelectKBest(k=2)
+    >>> pipe = make_pipeline(sel, lr)
+    >>> cv_results = cross_validate(
+    ...     pipe,
+    ...     X,
+    ...     y,
+    ...     cv=GroupKFold(),
+    ...     props={"sample_weight": my_weights, "groups": my_groups},
+    ...     scoring=weighted_acc,
+    ... )
+
+Different scoring and fitting weights
+-------------------------------------
+
+Despite ``make_scorer`` and ``LogisticRegressionCV`` both expecting the key
+``sample_weight``, we can use aliases to pass different weights to different
+consumers. In this example, we pass ``scoring_weight`` to the scorer and
+``fitting_weight`` to ``LogisticRegressionCV``::
+
+    >>> weighted_acc = make_scorer(accuracy_score).score_requests(
+    ...     sample_weight="scoring_weight"
+    ... )
+    >>> lr = LogisticRegressionCV(
+    ...     cv=GroupKFold(), scoring=weighted_acc,
+    ... ).fit_requests(sample_weight="fitting_weight")
+    >>> cv_results = cross_validate(
+    ...     lr,
+    ...     X,
+    ...     y,
+    ...     cv=GroupKFold(),
+    ...     props={
+    ...         "scoring_weight": my_weights,
+    ...         "fitting_weight": my_other_weights,
+    ...         "groups": my_groups,
+    ...     },
+    ...     scoring=weighted_acc,
+    ... )
+
+API Interface
+*************
+
+A *consumer* is an object (estimator, meta-estimator, scorer, or splitter)
+which accepts and uses some metadata in at least one of its methods (``fit``,
+``predict``, ``inverse_transform``, ``transform``, ``score``, ``split``).
+Meta-estimators which only forward the metadata to other objects (the child
+estimator, scorers, or splitters) and don't use the metadata themselves are not
+consumers. (Meta)Estimators which route metadata to other objects are routers.
+A (meta)estimator can be a consumer and a router at the same time.
+(Meta)Estimators and splitters expose a ``*_requests`` method for each method
+which accepts at least one metadata parameter. For instance, if an estimator
+supports ``sample_weight`` in ``fit`` and ``score``, it exposes
+``estimator.fit_requests(sample_weight=value)`` and
+``estimator.score_requests(sample_weight=value)``. Here ``value`` can be:
+
+- ``RequestType.REQUESTED`` or ``True``: the method requests a ``sample_weight``.
+  This means that if the metadata is provided it will be used; otherwise no
+  error is raised.
+- ``RequestType.UNREQUESTED`` or ``False``: the method does not request a
+  ``sample_weight``.
+- ``RequestType.ERROR_IF_PASSED`` or ``None``: the router will raise an error
+  if ``sample_weight`` is passed. This is in almost all cases the default value
+  when an object is instantiated and ensures the user sets the metadata
+  request explicitly whenever that metadata is passed.
+- ``"param_name"``: if this estimator is used in a meta-estimator, the
+  meta-estimator should forward ``"param_name"`` as ``sample_weight`` to this
+  estimator. This means the mapping between the metadata required by the
+  object, e.g. ``sample_weight``, and what is provided by the user, e.g.
+  ``my_weights``, is done at the router level, not by the object (e.g. the
+  estimator) itself.
+
+For scorers, this is done in the same way, using the ``score_requests`` method.
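Conceptually, a router takes the metadata provided by the user and, for each
consumer and each method, keeps only the entries that consumer requested,
renaming them when an alias was used; unset (``ERROR_IF_PASSED``) requests are
what trigger the errors described below. The following is only a minimal,
self-contained sketch of that selection logic using plain dictionaries, not
the actual ``MetadataRequest`` machinery added in this change::

    >>> def route(requests, metadata):
    ...     """Pick the entries of ``metadata`` that a consumer requested.
    ...
    ...     ``requests`` maps a parameter name the consumer understands
    ...     (e.g. "sample_weight") to True, False, or an alias string naming
    ...     the key under which the user provided the value.
    ...     """
    ...     routed = {}
    ...     for param, request in requests.items():
    ...         if request is True and param in metadata:
    ...             routed[param] = metadata[param]
    ...         elif isinstance(request, str):
    ...             routed[param] = metadata[request]
    ...     return routed
    >>> route({"sample_weight": "fitting_weight"},
    ...       {"fitting_weight": [0.1, 0.9], "groups": [0, 1]})
    {'sample_weight': [0.1, 0.9]}

+If metadata such as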
``sample_weight`` is passed by the user, the metadata +request for all objects which potentially can accept ``sample_weight`` should +be set by the user, otherwise an error is raised by the router object. For +example, the following code would raise, since it hasn't been explicitly set +whether ``sample_weight`` should be passed to the estimator's scorer or not:: + + >>> param_grid = {"C": [0.1, 1]} + >>> lr = LogisticRegression().fit_requests(sample_weight=True) + >>> try: + ... GridSearchCV( + ... estimator=lr, param_grid=param_grid + ... ).fit(X, y, sample_weight=my_weights) + ... except ValueError as e: + ... print(e) + sample_weight is passed but is not explicitly set as requested or not. In + method: score diff --git a/doc/user_guide.rst b/doc/user_guide.rst index 7d48934d32727..7e656567f3249 100644 --- a/doc/user_guide.rst +++ b/doc/user_guide.rst @@ -30,3 +30,4 @@ User Guide computing.rst modules/model_persistence.rst common_pitfalls.rst + metadata_routing.rst diff --git a/sklearn/base.py b/sklearn/base.py index 60fc82eff6088..bde80fc65b3fb 100644 --- a/sklearn/base.py +++ b/sklearn/base.py @@ -25,6 +25,7 @@ from .utils.validation import _num_features from .utils.validation import _check_feature_names_in from .utils._estimator_html_repr import estimator_html_repr +from .utils.metadata_requests import _MetadataRequester from .utils.validation import _get_feature_names @@ -79,7 +80,13 @@ def clone(estimator, *, safe=True): new_object_params = estimator.get_params(deep=False) for name, param in new_object_params.items(): new_object_params[name] = clone(param, safe=False) + new_object = klass(**new_object_params) + try: + new_object._metadata_request = copy.deepcopy(estimator._metadata_request) + except AttributeError: + pass + params_set = new_object.get_params(deep=False) # quick sanity check of the parameters of the clone @@ -144,7 +151,7 @@ def _pprint(params, offset=0, printer=repr): return lines -class BaseEstimator: +class BaseEstimator(_MetadataRequester): """Base class for all estimators in scikit-learn. Notes diff --git a/sklearn/compose/_target.py b/sklearn/compose/_target.py index 8ca158890c17c..649b20c9ec47d 100644 --- a/sklearn/compose/_target.py +++ b/sklearn/compose/_target.py @@ -307,3 +307,19 @@ def n_features_in_(self): ) from nfe return self.regressor_.n_features_in_ + + def get_metadata_request(self): + """Get requested data properties. + + This method mirrors the given regressor's metadata request. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. 
+ """ + return self.regressor.get_metadata_request() diff --git a/sklearn/compose/tests/test_target.py b/sklearn/compose/tests/test_target.py index f0d63c00c2772..f530ba527265c 100644 --- a/sklearn/compose/tests/test_target.py +++ b/sklearn/compose/tests/test_target.py @@ -346,6 +346,8 @@ def test_transform_target_regressor_count_fit(check_inverse): class DummyRegressorWithExtraFitParams(DummyRegressor): + _metadata_request__check_input = {"fit": "check_input"} + def fit(self, X, y, sample_weight=None, check_input=True): # on the test below we force this to false, we make sure this is # actually passed to the regressor @@ -356,7 +358,10 @@ def fit(self, X, y, sample_weight=None, check_input=True): def test_transform_target_regressor_pass_fit_parameters(): X, y = friedman regr = TransformedTargetRegressor( - regressor=DummyRegressorWithExtraFitParams(), transformer=DummyTransformer() + regressor=DummyRegressorWithExtraFitParams().fit_requests( + sample_weight=True, check_input=True + ), + transformer=DummyTransformer(), ) regr.fit(X, y, check_input=False) @@ -367,12 +372,13 @@ def test_transform_target_regressor_route_pipeline(): X, y = friedman regr = TransformedTargetRegressor( - regressor=DummyRegressorWithExtraFitParams(), transformer=DummyTransformer() + regressor=DummyRegressorWithExtraFitParams().fit_requests(check_input=True), + transformer=DummyTransformer(), ) estimators = [("normalize", StandardScaler()), ("est", regr)] pip = Pipeline(estimators) - pip.fit(X, y, **{"est__check_input": False}) + pip.fit(X, y, **{"check_input": False}) assert regr.transformer_.fit_counter == 1 diff --git a/sklearn/ensemble/_gb.py b/sklearn/ensemble/_gb.py index 7b66324d1f08b..0ff790a70a9b1 100644 --- a/sklearn/ensemble/_gb.py +++ b/sklearn/ensemble/_gb.py @@ -54,6 +54,8 @@ from ..utils.validation import check_is_fitted, _check_sample_weight from ..utils.multiclass import check_classification_targets from ..exceptions import NotFittedError +from ..utils import metadata_request_factory +from ..utils import MetadataRouter class VerboseReporter: @@ -263,18 +265,14 @@ def _fit_stage( return raw_predictions - def _check_params(self): - """Check validity of parameters and raise ValueError if not valid.""" - if self.n_estimators <= 0: - raise ValueError( - "n_estimators must be greater than 0 but was %r" % self.n_estimators - ) - - if self.learning_rate <= 0.0: - raise ValueError( - "learning_rate must be greater than 0 but was %r" % self.learning_rate - ) + def _get_loss(self, n_classes=None): + """Return the right loss object. + Parameters + ---------- + n_classes : int, default=None + Relevant if loss is "deviance". self.classes_ is used if not given. 
+ """ if ( self.loss not in self._SUPPORTED_LOSS or self.loss not in _gb_losses.LOSS_FUNCTIONS @@ -298,20 +296,38 @@ def _check_params(self): ) if self.loss == "deviance": + if n_classes is None: + n_classes = len(self.classes_) loss_class = ( _gb_losses.MultinomialDeviance - if len(self.classes_) > 2 + if n_classes > 2 else _gb_losses.BinomialDeviance ) else: loss_class = _gb_losses.LOSS_FUNCTIONS[self.loss] if is_classifier(self): - self.loss_ = loss_class(self.n_classes_) + if n_classes is None: + n_classes = len(self.classes_) + return loss_class(n_classes) elif self.loss in ("huber", "quantile"): - self.loss_ = loss_class(self.alpha) + return loss_class(self.alpha) else: - self.loss_ = loss_class() + return loss_class() + + def _check_params(self): + """Check validity of parameters and raise ValueError if not valid.""" + if self.n_estimators <= 0: + raise ValueError( + "n_estimators must be greater than 0 but was %r" % self.n_estimators + ) + + if self.learning_rate <= 0.0: + raise ValueError( + "learning_rate must be greater than 0 but was %r" % self.learning_rate + ) + + self.loss_ = self._get_loss() if not (0.0 < self.subsample <= 1.0): raise ValueError("subsample must be in (0,1] but was %r" % self.subsample) @@ -369,7 +385,7 @@ def _init_state(self): self.init_ = self.init if self.init_ is None: - self.init_ = self.loss_.init_estimator() + self.init_ = self.loss_.init_estimator().fit_requests(sample_weight=True) self.estimators_ = np.empty((self.n_estimators, self.loss_.K), dtype=object) self.train_score_ = np.zeros((self.n_estimators,), dtype=np.float64) @@ -426,7 +442,7 @@ def _check_initialized(self): def _warn_mae_for_criterion(self): pass - def fit(self, X, y, sample_weight=None, monitor=None): + def fit(self, X, y, sample_weight=None, monitor=None, **fit_params): """Fit the gradient boosting model. Parameters @@ -457,11 +473,21 @@ def fit(self, X, y, sample_weight=None, monitor=None): computing held-out estimates, early stopping, model introspect, and snapshoting. + **fit_params : dict + Other parameters required by ``init.fit(...)``. If ``init`` is an + estimator and requests certain metadata, they should be included + in ``fit_params``. + + .. versionadded:: 1.1 + Returns ------- self : object Fitted estimator. 
""" + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight + if self.criterion in ("absolute_error", "mae"): # TODO: This should raise an error from 1.1 self._warn_mae_for_criterion() @@ -487,8 +513,6 @@ def fit(self, X, y, sample_weight=None, monitor=None): X, y, accept_sparse=["csr", "csc", "coo"], dtype=DTYPE, multi_output=True ) - sample_weight_is_none = sample_weight is None - sample_weight = _check_sample_weight(sample_weight, X) y = column_or_1d(y, warn=True) @@ -534,29 +558,10 @@ def fit(self, X, y, sample_weight=None, monitor=None): shape=(X.shape[0], self.loss_.K), dtype=np.float64 ) else: - # XXX clean this once we have a support_sample_weight tag - if sample_weight_is_none: - self.init_.fit(X, y) - else: - msg = ( - "The initial estimator {} does not support sample " - "weights.".format(self.init_.__class__.__name__) - ) - try: - self.init_.fit(X, y, sample_weight=sample_weight) - except TypeError as e: - # regular estimator without SW support - raise ValueError(msg) from e - except ValueError as e: - if ( - "pass parameters to specific steps of " - "your pipeline using the " - "stepname__parameter" - in str(e) - ): # pipeline - raise ValueError(msg) from e - else: # regular estimator whose input checking failed - raise + init_fit_params = metadata_request_factory( + self.init_ + ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + self.init_.fit(X, y, **init_fit_params) raw_predictions = self.loss_.get_init_raw_predictions(X, self.init_) @@ -892,6 +897,35 @@ def apply(self, X): def n_features_(self): return self.n_features_in_ + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + router = MetadataRouter().add(super(), mask=False) + init = self.init + if self.init is None: + # we pass n_classes=2 since the estimators are the same regardless. + init = ( + self._get_loss(n_classes=2) + .init_estimator() + .fit_requests(sample_weight=True) + ) + # here overwrite="ignore" because we should not expose a + # `sample_weight=True` if init is None. + router.add(init, mapping={"fit": "fit"}, mask=True, overwrite="ignore") + else: + router.add(init, mapping={"fit": "fit"}, mask=True, overwrite="smart") + + return router.get_metadata_request() + class GradientBoostingClassifier(ClassifierMixin, BaseGradientBoosting): """Gradient Boosting for classification. 
diff --git a/sklearn/ensemble/tests/test_gradient_boosting.py b/sklearn/ensemble/tests/test_gradient_boosting.py index 410f4086bb7c4..01ef24a2b488f 100644 --- a/sklearn/ensemble/tests/test_gradient_boosting.py +++ b/sklearn/ensemble/tests/test_gradient_boosting.py @@ -23,7 +23,6 @@ from sklearn.metrics import mean_squared_error from sklearn.model_selection import train_test_split from sklearn.utils import check_random_state, tosequence -from sklearn.utils._mocking import NoSampleWeightWrapper from sklearn.utils._testing import assert_almost_equal from sklearn.utils._testing import assert_array_almost_equal from sklearn.utils._testing import assert_array_equal @@ -34,6 +33,8 @@ from sklearn.pipeline import make_pipeline from sklearn.linear_model import LinearRegression from sklearn.svm import NuSVR +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory GRADIENT_BOOSTING_ESTIMATORS = [GradientBoostingClassifier, GradientBoostingRegressor] @@ -1109,7 +1110,11 @@ def test_non_uniform_weights_toy_edge_case_clf(): # ignore the first 2 training samples by setting their weight to 0 sample_weight = [0, 0, 1, 1] for loss in ("deviance", "exponential"): - gb = GradientBoostingClassifier(n_estimators=5, loss=loss) + gb = GradientBoostingClassifier( + n_estimators=5, + loss=loss, + init=DummyClassifier().fit_requests(sample_weight=True), + ) gb.fit(X, y, sample_weight=sample_weight) assert_array_equal(gb.predict([[1, 0]]), [1]) @@ -1283,15 +1288,9 @@ def test_gradient_boosting_with_init(gb, dataset_maker, init_estimator): sample_weight = np.random.RandomState(42).rand(100) # init supports sample weights - init_est = init_estimator() + init_est = init_estimator().fit_requests(sample_weight=True) gb(init=init_est).fit(X, y, sample_weight=sample_weight) - # init does not support sample weights - init_est = NoSampleWeightWrapper(init_estimator()) - gb(init=init_est).fit(X, y) # ok no sample weights - with pytest.raises(ValueError, match="estimator.*does not support sample weights"): - gb(init=init_est).fit(X, y, sample_weight=sample_weight) - def test_gradient_boosting_with_init_pipeline(): # Check that the init estimator can be a pipeline (see issue #13466) @@ -1303,7 +1302,7 @@ def test_gradient_boosting_with_init_pipeline(): with pytest.raises( ValueError, - match="The initial estimator Pipeline does not support sample weights", + match="sample_weight is passed but is not explicitly set as requested or not", ): gb.fit(X, y, sample_weight=np.ones(X.shape[0])) @@ -1313,7 +1312,7 @@ def test_gradient_boosting_with_init_pipeline(): # whose input checking failed. 
with pytest.raises(ValueError, match="nu <= 0 or nu > 1"): # Note that NuSVR properly supports sample_weight - init = NuSVR(gamma="auto", nu=1.5) + init = NuSVR(gamma="auto", nu=1.5).fit_requests(sample_weight=True) gb = GradientBoostingRegressor(init=init) gb.fit(X, y, sample_weight=np.ones(X.shape[0])) @@ -1458,3 +1457,12 @@ def test_loss_deprecated(old_loss, new_loss): est2 = GradientBoostingRegressor(loss=new_loss, random_state=0) est2.fit(X, y) assert_allclose(est1.predict(X), est2.predict(X)) + + +@pytest.mark.parametrize("Estimator", GRADIENT_BOOSTING_ESTIMATORS) +def test_metadata_request(Estimator): + est = Estimator() + assert ( + metadata_request_factory(est).fit.requests["sample_weight"] + == RequestType.ERROR_IF_PASSED + ) diff --git a/sklearn/externals/_sentinels.py b/sklearn/externals/_sentinels.py new file mode 100644 index 0000000000000..662a82864ca8d --- /dev/null +++ b/sklearn/externals/_sentinels.py @@ -0,0 +1,82 @@ +# type: ignore +""" +Copied from https://github.com/taleinat/python-stdlib-sentinels +PEP-0661: Status: Draft +""" +import sys as _sys +from typing import Optional + + +__all__ = ["sentinel"] + + +def sentinel( + name: str, + repr: Optional[str] = None, + module: Optional[str] = None, +): + """Create a unique sentinel object. + + *name* should be the fully-qualified name of the variable to which the + return value shall be assigned. + + *repr*, if supplied, will be used for the repr of the sentinel object. + If not provided, "" will be used (with any leading class names + removed). + + *module*, if supplied, will be used as the module name for the purpose + of setting a unique name for the sentinels unique class. The class is + set as an attribute of this name on the "sentinels" module, so that it + may be found by the pickling mechanism. In most cases, the module name + does not need to be provided, and it will be found by inspecting the + stack frame. 
+ """ + name = _sys.intern(str(name)) + repr = repr or f'<{name.rsplit(".", 1)[-1]}>' + + if module is None: + try: + module = _get_parent_frame().f_globals.get("__name__", "__main__") + except (AttributeError, ValueError): + pass + class_name = _sys.intern(_get_class_name(name, module)) + + class_namespace = { + "__repr__": lambda self: repr, + } + cls = type(class_name, (), class_namespace) + cls.__module__ = __name__ + globals()[class_name] = cls + + sentinel = cls() + + def __new__(cls): + return sentinel + + __new__.__qualname__ = f"{class_name}.__new__" + cls.__new__ = __new__ + + return sentinel + + +if hasattr(_sys, "_getframe"): + _get_parent_frame = lambda: _sys._getframe(2) +else: # pragma: no cover + + def _get_parent_frame(): + """Return the frame object for the caller's parent stack frame.""" + try: + raise Exception + except Exception: + return _sys.exc_info()[2].tb_frame.f_back.f_back + + +def _get_class_name( + sentinel_qualname: str, + module_name: Optional[str] = None, +) -> str: + return ( + "_sentinel_type__" + f'{module_name.replace(".", "_") + "__" if module_name else ""}' + f'{sentinel_qualname.replace(".", "_")}' + ) diff --git a/sklearn/feature_selection/_rfe.py b/sklearn/feature_selection/_rfe.py index d458936c5267d..a88a5ecad1f16 100644 --- a/sklearn/feature_selection/_rfe.py +++ b/sklearn/feature_selection/_rfe.py @@ -16,6 +16,7 @@ from ..utils._tags import _safe_tags from ..utils.validation import check_is_fitted from ..utils.fixes import delayed +from ..utils import metadata_request_factory from ..utils.deprecation import deprecated from ..base import BaseEstimator from ..base import MetaEstimatorMixin @@ -23,12 +24,13 @@ from ..base import is_classifier from ..model_selection import check_cv from ..model_selection._validation import _score +from ..model_selection._search import CVMetadataRequester from ..metrics import check_scoring from ._base import SelectorMixin from ._base import _get_feature_importances -def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer): +def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer, score_params): """ Return the score for a fit across one fold. """ @@ -38,7 +40,7 @@ def _rfe_single_fit(rfe, estimator, X, y, train, test, scorer): X_train, y_train, lambda estimator, features: _score( - estimator, X_test[:, features], y_test, scorer + estimator, X_test[:, features], y_test, scorer, score_params=score_params ), ).scores_ @@ -270,6 +272,10 @@ def _fit(self, X, y, step_score=None, **fit_params): if step_score: self.scores_ = [] + # since ignore_extras=False, validation is done as well + estimator_fit_params = metadata_request_factory( + self.estimator + ).fit.get_method_input(ignore_extras=False, kwargs=fit_params) # Elimination while np.sum(support_) > n_features_to_select: # Remaining features @@ -280,7 +286,7 @@ def _fit(self, X, y, step_score=None, **fit_params): if self.verbose > 0: print("Fitting estimator with %d features." 
% np.sum(support_)) - estimator.fit(X[:, features], y, **fit_params) + estimator.fit(X[:, features], y, **estimator_fit_params) # Get importance and rank them importances = _get_feature_importances( @@ -307,7 +313,7 @@ def _fit(self, X, y, step_score=None, **fit_params): # Set final attributes features = np.arange(n_features)[support_] self.estimator_ = clone(self.estimator) - self.estimator_.fit(X[:, features], y, **fit_params) + self.estimator_.fit(X[:, features], y, **estimator_fit_params) # Compute step score when only n_features_to_select features left if step_score: @@ -336,7 +342,7 @@ def predict(self, X): return self.estimator_.predict(self.transform(X)) @if_delegate_has_method(delegate="estimator") - def score(self, X, y, **fit_params): + def score(self, X, y, **score_params): """Reduce X to the selected features and return the score of the underlying estimator. Parameters @@ -347,7 +353,7 @@ def score(self, X, y, **fit_params): y : array of shape [n_samples] The target values. - **fit_params : dict + **score_params : dict Parameters to pass to the `score` method of the underlying estimator. @@ -360,7 +366,11 @@ def score(self, X, y, **fit_params): features returned by `rfe.transform(X)` and `y`. """ check_is_fitted(self) - return self.estimator_.score(self.transform(X), y, **fit_params) + # since ignore_extras=False, validation is done as well + score_params = metadata_request_factory(self.estimator).score.get_method_input( + ignore_extras=False, kwargs=score_params + ) + return self.estimator_.score(self.transform(X), y, **score_params) def _get_support_mask(self): check_is_fitted(self) @@ -433,8 +443,20 @@ def _more_tags(self): "requires_y": True, } + def get_metadata_request(self): + """Get requested data properties. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + return self.estimator.get_metadata_request() + -class RFECV(RFE): +class RFECV(CVMetadataRequester, RFE): """Recursive feature elimination with cross-validation to select the number of features. See glossary entry for :term:`cross-validation estimator`. @@ -632,7 +654,7 @@ def __init__( self.n_jobs = n_jobs self.min_features_to_select = min_features_to_select - def fit(self, X, y, groups=None): + def fit(self, X, y, groups=None, **kwargs): """Fit the RFE model and automatically tune the number of selected features. Parameters @@ -652,6 +674,11 @@ def fit(self, X, y, groups=None): .. versionadded:: 0.20 + **kwargs : dict + Extra parameteres passed to the underlying scorer. + + .. 
versionadded:: 1.1 + Returns ------- self : object @@ -666,9 +693,17 @@ def fit(self, X, y, groups=None): force_all_finite=not tags.get("allow_nan", True), multi_output=True, ) - + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=kwargs + ) + score_params = metadata_request_factory(self).score.get_method_input( + ignore_extras=True, kwargs=kwargs + ) # Initialization - cv = check_cv(self.cv, y, classifier=is_classifier(self.estimator)) + cv_params = metadata_request_factory(self.cv).split.get_method_input( + ignore_extras=True, kwargs=kwargs + ) + cv = check_cv(self.cv, y, classifier=is_classifier(self.estimator), **cv_params) scorer = check_scoring(self.estimator, scoring=self.scoring) n_features = X.shape[1] @@ -708,7 +743,16 @@ def fit(self, X, y, groups=None): func = delayed(_rfe_single_fit) scores = parallel( - func(rfe, self.estimator, X, y, train, test, scorer) + func( + rfe, + self.estimator, + X, + y, + train, + test, + scorer, + score_params=score_params, + ) for train, test in cv.split(X, y, groups) ) diff --git a/sklearn/feature_selection/tests/test_rfe.py b/sklearn/feature_selection/tests/test_rfe.py index a8eef65049bd6..4dbc57225d302 100644 --- a/sklearn/feature_selection/tests/test_rfe.py +++ b/sklearn/feature_selection/tests/test_rfe.py @@ -126,12 +126,29 @@ def score(self, X, y, prop=None): return self.svc_.score(X, y) X, y = load_iris(return_X_y=True) - with pytest.raises(ValueError, match="fit: prop cannot be None"): - RFE(estimator=TestEstimator()).fit(X, y) - with pytest.raises(ValueError, match="score: prop cannot be None"): - RFE(estimator=TestEstimator()).fit(X, y, prop="foo").score(X, y) + with pytest.raises( + ValueError, + match=( + "prop is passed but is not explicitly set as requested or not. In" + " method: fit" + ), + ): + RFE(estimator=TestEstimator()).fit(X, y, prop="foo") + + with pytest.raises( + ValueError, + match=( + "prop is passed but is not explicitly set as requested or not. 
In method:" + " score" + ), + ): + RFE(estimator=TestEstimator().fit_requests(prop=True)).fit( + X, y, prop="foo" + ).score(X, y, prop="bar") - RFE(estimator=TestEstimator()).fit(X, y, prop="foo").score(X, y, prop="foo") + RFE( + estimator=TestEstimator().fit_requests(prop=True).score_requests(prop=True) + ).fit(X, y, prop="foo").score(X, y, prop="foo") @pytest.mark.parametrize("n_features_to_select", [-1, 2.1]) diff --git a/sklearn/inspection/_permutation_importance.py b/sklearn/inspection/_permutation_importance.py index b095cbec9ee49..c6aed5a167464 100644 --- a/sklearn/inspection/_permutation_importance.py +++ b/sklearn/inspection/_permutation_importance.py @@ -15,7 +15,7 @@ def _weights_scorer(scorer, estimator, X, y, sample_weight): if sample_weight is not None: - return scorer(estimator, X, y, sample_weight) + return scorer(estimator, X, y, sample_weight=sample_weight) return scorer(estimator, X, y) diff --git a/sklearn/inspection/tests/test_permutation_importance.py b/sklearn/inspection/tests/test_permutation_importance.py index d68fc718da8b5..789c7cb4224e8 100644 --- a/sklearn/inspection/tests/test_permutation_importance.py +++ b/sklearn/inspection/tests/test_permutation_importance.py @@ -405,7 +405,7 @@ def test_permutation_importance_sample_weight(): y[n_half_samples:] = x[n_half_samples:, 0] + 2 * x[n_half_samples:, 1] # Fitting linear regression with perfect prediction - lr = LinearRegression(fit_intercept=False) + lr = LinearRegression(fit_intercept=False).fit_requests(sample_weight=True) lr.fit(x, y) # When all samples are weighted with the same weights, the ratio of diff --git a/sklearn/linear_model/_coordinate_descent.py b/sklearn/linear_model/_coordinate_descent.py index 91b6b1e584469..e3803bb304498 100644 --- a/sklearn/linear_model/_coordinate_descent.py +++ b/sklearn/linear_model/_coordinate_descent.py @@ -28,6 +28,8 @@ check_is_fitted, column_or_1d, ) +from ..utils.metadata_requests import metadata_request_factory +from ..utils.metadata_requests import MetadataRouter from ..utils.fixes import delayed # mypy error: Module 'sklearn.linear_model' has no attribute '_cd_fast' @@ -1476,7 +1478,7 @@ def _is_multitask(self): def path(X, y, **kwargs): """Compute path with coordinate descent.""" - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, **kwargs): """Fit linear model with coordinate descent. Fit is on grid of alphas and best alpha estimated by cross-validation. @@ -1498,11 +1500,26 @@ def fit(self, X, y, sample_weight=None): MSE that is finally used to find the best model is the unweighted mean over the (weighted) MSEs of each test fold. + **kwargs : dict + Other arguments to be passed to the underlying score and CV methods. + + .. versionadded:: 1.1 + Returns ------- self : object Returns an instance of fitted model. """ + if sample_weight is not None: + kwargs["sample_weight"] = sample_weight + router = ( + MetadataRouter() + .add(self._get_estimator()) + .add(check_cv(self.cv), mapping={"fit": "split"}) + ) + metadata_request_factory(router).fit.validate_metadata( + ignore_extras=False, kwargs=kwargs + ) # Do as _deprecate_normalize but without warning as it's raised # below during the refitting on the best alpha. @@ -1527,7 +1544,7 @@ def fit(self, X, y, sample_weight=None): # by the model fitting itself # Need to validate separately here. - # We can't pass multi_ouput=True because that would allow y to be + # We can't pass multi_output=True because that would allow y to be # csr. 
We also want to allow y to be 64 or 32 but check_X_y only # allows to convert for 64. check_X_params = dict( @@ -1585,8 +1602,6 @@ def fit(self, X, y, sample_weight=None): raise ValueError("Sample weights do not (yet) support sparse matrices.") sample_weight = _check_sample_weight(sample_weight, X, dtype=X.dtype) - model = self._get_estimator() - if self.selection not in ["random", "cyclic"]: raise ValueError("selection should be either random or cyclic.") @@ -1642,9 +1657,11 @@ def fit(self, X, y, sample_weight=None): # init cross-validation generator cv = check_cv(self.cv) - + cv_params = metadata_request_factory(cv).split.get_method_input( + ignore_extras=True, kwargs=kwargs + ) # Compute path for all folds and compute MSE to get the best alpha - folds = list(cv.split(X, y)) + folds = list(cv.split(X, y, **cv_params)) best_mse = np.inf # We do a double for loop folded in one, in order to be able to @@ -1695,6 +1712,10 @@ def fit(self, X, y, sample_weight=None): else: self.alphas_ = np.asarray(alphas[0]) + model = self._get_estimator() + model_fit_params = metadata_request_factory(model).fit.get_method_input( + ignore_extras=True, kwargs=kwargs + ) # Refit the model with the parameters selected common_params = { name: value @@ -1708,13 +1729,7 @@ def fit(self, X, y, sample_weight=None): precompute = getattr(self, "precompute", None) if isinstance(precompute, str) and precompute == "auto": model.precompute = False - - if sample_weight is None: - # MultiTaskElasticNetCV does not (yet) support sample_weight, even - # not sample_weight=None. - model.fit(X, y) - else: - model.fit(X, y, sample_weight=sample_weight) + model.fit(X, y, **model_fit_params) if not hasattr(self, "l1_ratio"): del self.l1_ratio_ self.coef_ = model.coef_ @@ -1734,6 +1749,25 @@ def _more_tags(self): } } + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=True, mask=False) + .add(check_cv(self.cv), mapping={"fit": "split"}, overwrite=True, mask=True) + ) + return router.get_metadata_request() + class LassoCV(RegressorMixin, LinearModelCV): """Lasso linear model with iterative fitting along a regularization path. 
@@ -1945,7 +1979,7 @@ def __init__( ) def _get_estimator(self): - return Lasso() + return Lasso().fit_requests(sample_weight=True) def _is_multitask(self): return False @@ -2190,7 +2224,7 @@ def __init__( self.selection = selection def _get_estimator(self): - return ElasticNet() + return ElasticNet().fit_requests(sample_weight=True) def _is_multitask(self): return False @@ -2831,7 +2865,7 @@ def __init__( self.selection = selection def _get_estimator(self): - return MultiTaskElasticNet() + return MultiTaskElasticNet().fit_requests(sample_weight=True) def _is_multitask(self): return True @@ -3060,7 +3094,7 @@ def __init__( ) def _get_estimator(self): - return MultiTaskLasso() + return MultiTaskLasso().fit_requests(sample_weight=True) def _is_multitask(self): return True diff --git a/sklearn/linear_model/_logistic.py b/sklearn/linear_model/_logistic.py index 08e71edbc69ab..fe55a911b9359 100644 --- a/sklearn/linear_model/_logistic.py +++ b/sklearn/linear_model/_logistic.py @@ -24,10 +24,14 @@ from ..svm._base import _fit_liblinear from ..utils import check_array, check_consistent_length, compute_class_weight from ..utils import check_random_state +from ..utils import MetadataRouter +from ..utils import metadata_request_factory +from ..utils.metadata_requests import METHODS from ..utils.extmath import log_logistic, safe_sparse_dot, softmax, squared_norm from ..utils.extmath import row_norms from ..utils.optimize import _newton_cg, _check_optimize_result from ..utils.validation import check_is_fitted, _check_sample_weight +from ..utils.validation import _check_fit_params from ..utils.multiclass import check_classification_targets from ..utils.fixes import _joblib_parallel_args from ..utils.fixes import delayed @@ -904,6 +908,7 @@ def _log_reg_scoring_path( y, train, test, + *, pos_class=None, Cs=10, scoring=None, @@ -920,6 +925,7 @@ def _log_reg_scoring_path( random_state=None, max_squared_sum=None, sample_weight=None, + score_params=None, l1_ratio=None, ): """Computes scores across logistic_regression_path @@ -1112,7 +1118,10 @@ def _log_reg_scoring_path( if scoring is None: scores.append(log_reg.score(X_test, y_test)) else: - scores.append(scoring(log_reg, X_test, y_test)) + if score_params is None: + score_params = {} + score_params_test = _check_fit_params(X, score_params, test) + scores.append(scoring(log_reg, X_test, y_test, **score_params_test)) return coefs, Cs, np.array(scores), n_iter @@ -2007,7 +2016,7 @@ def __init__( self.random_state = random_state self.l1_ratios = l1_ratios - def fit(self, X, y, sample_weight=None): + def fit(self, X, y, sample_weight=None, **fit_params): """Fit the model according to the given training data. Parameters @@ -2023,11 +2032,18 @@ def fit(self, X, y, sample_weight=None): Array of weights that are assigned to individual samples. If not provided, then each sample is given unit weight. + **fit_params : dict of str -> array-like + Parameters requested by the scorer and the CV splitter. + + .. versionadded:: 1.1 + Returns ------- self : object Fitted LogisticRegressionCV estimator. 
""" + if sample_weight is not None: + fit_params["sample_weight"] = sample_weight solver = _check_solver(self.solver, self.penalty, self.dual) if not isinstance(self.max_iter, numbers.Number) or self.max_iter < 0: @@ -2107,7 +2123,17 @@ def fit(self, X, y, sample_weight=None): # init cross-validation generator cv = check_cv(self.cv, y, classifier=True) - folds = list(cv.split(X, y)) + scorer = get_scorer(self.scoring) + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, self_metadata=super(), kwargs=fit_params + ) + cv_params = metadata_request_factory(cv).split.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + score_params = metadata_request_factory(scorer).score.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + folds = list(cv.split(X, y, **cv_params)) # Use the label encoded classes n_classes = len(encoded_labels) @@ -2178,6 +2204,7 @@ def fit(self, X, y, sample_weight=None): max_squared_sum=max_squared_sum, sample_weight=sample_weight, l1_ratio=l1_ratio, + score_params=score_params, ) for label in iter_encoded_labels for train, test in folds @@ -2391,3 +2418,31 @@ def _more_tags(self): ), } } + + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + router = ( + MetadataRouter() + .add( + super(), + mapping={m: m for m in METHODS if m != "score"}, + mask=False, + ) + .add(check_cv(self.cv), mapping={"fit": "split"}, mask=True) + .add( + get_scorer(self.score), + mapping={"fit": "score", "score": "score"}, + mask=True, + ) + ) + return router.get_metadata_request() diff --git a/sklearn/linear_model/_ridge.py b/sklearn/linear_model/_ridge.py index 1dcc81e3b988f..f9cb78572a4af 100644 --- a/sklearn/linear_model/_ridge.py +++ b/sklearn/linear_model/_ridge.py @@ -1977,13 +1977,21 @@ def fit(self, X, y, sample_weight=None): raise ValueError("cv!=None and alpha_per_target=True are incompatible") parameters = {"alpha": self.alphas} solver = "sparse_cg" if sparse.issparse(X) else "auto" - model = RidgeClassifier if is_classifier(self) else Ridge - gs = GridSearchCV( - model( + est = RidgeClassifier if is_classifier(self) else Ridge + est = ( + est( fit_intercept=self.fit_intercept, normalize=self.normalize, solver=solver, - ), + ) + .fit_requests(sample_weight=True) + .score_requests(sample_weight=True) + ) + # The old behavior would be sample_weight=False for "score" + # Do we want to "fix" the issue, or keep the old behavior? 
+ + gs = GridSearchCV( + est, parameters, cv=cv, scoring=self.scoring, diff --git a/sklearn/linear_model/tests/test_coordinate_descent.py b/sklearn/linear_model/tests/test_coordinate_descent.py index dd67c49585bad..75f6b95a18e53 100644 --- a/sklearn/linear_model/tests/test_coordinate_descent.py +++ b/sklearn/linear_model/tests/test_coordinate_descent.py @@ -10,6 +10,7 @@ from sklearn.base import is_classifier from sklearn.base import clone +from sklearn.metrics import get_scorer from sklearn.datasets import load_diabetes from sklearn.datasets import make_regression from sklearn.model_selection import ( @@ -520,14 +521,16 @@ def test_linear_model_sample_weights_normalize_in_pipeline( _scale_alpha_inplace(linear_regressor, sample_weight.sum()) reg_with_scaler = Pipeline( [ - ("scaler", StandardScaler(with_mean=with_mean)), - ("linear_regressor", linear_regressor), + ( + "scaler", + StandardScaler(with_mean=with_mean).fit_requests(sample_weight=True), + ), + ("linear_regressor", linear_regressor.fit_requests(sample_weight=True)), ] ) fit_params = { - "scaler__sample_weight": sample_weight, - "linear_regressor__sample_weight": sample_weight, + "sample_weight": sample_weight, } reg_with_scaler.fit(X_train, y_train, **fit_params) @@ -1558,10 +1561,12 @@ def test_enet_cv_grid_search(sample_weight): param = {"alpha": alphas, "l1_ratio": l1_ratios} gs = GridSearchCV( - estimator=ElasticNet(), + estimator=ElasticNet().fit_requests(sample_weight=True), param_grid=param, cv=cv, - scoring="neg_mean_squared_error", + scoring=get_scorer("neg_mean_squared_error").score_requests( + sample_weight=False + ), ).fit(X, y, sample_weight=sample_weight) assert reg.l1_ratio_ == pytest.approx(gs.best_params_["l1_ratio"]) diff --git a/sklearn/linear_model/tests/test_logistic.py b/sklearn/linear_model/tests/test_logistic.py index 1171613eb3718..41f4ace2c03a6 100644 --- a/sklearn/linear_model/tests/test_logistic.py +++ b/sklearn/linear_model/tests/test_logistic.py @@ -13,7 +13,7 @@ from sklearn.datasets import load_iris, make_classification from sklearn.metrics import log_loss from sklearn.metrics import get_scorer -from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import StratifiedKFold, GroupKFold from sklearn.model_selection import GridSearchCV from sklearn.model_selection import train_test_split from sklearn.model_selection import cross_val_score @@ -2239,6 +2239,66 @@ def test_sample_weight_not_modified(multi_class, class_weight): assert_allclose(expected, W) +def test_lrcv_metadata_routing(): + X, y = make_classification(n_samples=20, random_state=0) + sample_weight = y + 1 + rng = np.random.RandomState(0) + groups = rng.randint(low=0, high=10, size=len(y)) + err_message = "Metadata passed which is not understood: {param}. 
In method: fit" + + lrcv = LogisticRegressionCV( + random_state=0, + max_iter=1, + ) + lrcv.fit(X, y) + lrcv.fit(X, y, sample_weight=sample_weight) + lrcv.fit(X, y, sample_weight=None) + + with pytest.raises( + ValueError, + match=re.escape(err_message.format(param=["my_weights"])), + ): + lrcv.fit(X, y, my_weights=sample_weight) + + lrcv = LogisticRegressionCV( + random_state=0, + max_iter=1, + ).fit_requests(sample_weight="my_weight") + lrcv.fit(X, y) + lrcv.fit(X, y, sample_weight=sample_weight) + lrcv.fit(X, y, sample_weight=None) + + with pytest.raises( + ValueError, + match=re.escape(err_message.format(param=["my_weights"])), + ): + lrcv.fit(X, y, my_weights=sample_weight) + + lrcv = LogisticRegressionCV( + random_state=0, + max_iter=1, + cv=GroupKFold(), + ).fit_requests(sample_weight="my_weight") + lrcv.fit(X, y, groups=groups) + lrcv.fit(X, y, sample_weight=sample_weight, groups=groups) + lrcv.fit(X, y, sample_weight=None, groups=groups) + + with pytest.raises( + ValueError, + match=re.escape(err_message.format(param=["my_weights"])), + ): + lrcv.fit(X, y, my_weights=sample_weight, groups=groups) + + with pytest.raises(ValueError, match="The 'groups' parameter should not be None."): + lrcv.fit(X, y) + + with pytest.raises( + ValueError, + match=re.escape(err_message.format(param=["my_groups"])), + ): + lrcv.fit(X, y, my_groups=groups) + + @pytest.mark.parametrize("solver", ["liblinear", "lbfgs", "newton-cg", "sag", "saga"]) def test_large_sparse_matrix(solver): # Solvers either accept large sparse matrices, or raise helpful error. diff --git a/sklearn/linear_model/tests/test_ridge.py b/sklearn/linear_model/tests/test_ridge.py index bfc6722737bd8..6e6ce7e4dd399 100644 --- a/sklearn/linear_model/tests/test_ridge.py +++ b/sklearn/linear_model/tests/test_ridge.py @@ -1122,12 +1122,16 @@ def test_ridgecv_sample_weight(): sample_weight = 1.0 + rng.rand(n_samples) cv = KFold(5) - ridgecv = RidgeCV(alphas=alphas, cv=cv) + ridgecv = RidgeCV(alphas=alphas, cv=cv).fit_requests(sample_weight=True) ridgecv.fit(X, y, sample_weight=sample_weight) # Check using GridSearchCV directly parameters = {"alpha": alphas} - gs = GridSearchCV(Ridge(), parameters, cv=cv) + gs = GridSearchCV( + Ridge().fit_requests(sample_weight=True).score_requests(sample_weight=True), + parameters, + cv=cv, + ) gs.fit(X, y, sample_weight=sample_weight) assert ridgecv.alpha_ == gs.best_estimator_.alpha diff --git a/sklearn/metrics/_scorer.py b/sklearn/metrics/_scorer.py index c5a725ad3a13b..cc46c6eaa3fef 100644 --- a/sklearn/metrics/_scorer.py +++ b/sklearn/metrics/_scorer.py @@ -18,6 +18,7 @@ # Arnaud Joly # License: Simplified BSD +import copy from collections.abc import Iterable from functools import partial from collections import Counter @@ -59,7 +60,10 @@ from .cluster import fowlkes_mallows_score from ..utils.multiclass import type_of_target -from ..base import is_regressor +from ..base import is_regressor, _MetadataRequester +from ..utils import metadata_request_factory +from ..utils import MetadataRequest +from ..utils import MetadataRouter def _cached_call(cache, estimator, method, *args, **kwargs): @@ -98,11 +102,18 @@ def __call__(self, estimator, *args, **kwargs): cache = {} if self._use_cache(estimator) else None cached_call = partial(_cached_call, cache) + metadata_request_factory(self).score.validate_metadata( + ignore_extras=False, kwargs=kwargs + ) + for name, scorer in self._scorers.items(): + params = metadata_request_factory(scorer).score.get_method_input( + ignore_extras=True, kwargs=kwargs + ) if 
isinstance(scorer, _BaseScorer): - score = scorer._score(cached_call, estimator, *args, **kwargs) + score = scorer._score(cached_call, estimator, *args, **params) else: - score = scorer(estimator, *args, **kwargs) + score = scorer(estimator, *args, **params) scores[name] = score return scores @@ -137,8 +148,28 @@ def _use_cache(self, estimator): return True return False + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 -class _BaseScorer: + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + router = MetadataRouter().add( + *self._scorers.values(), + mapping={"score": "score"}, + mask=True, + overwrite="smart", + ) + return router.get_metadata_request() + + +class _BaseScorer(_MetadataRequester): def __init__(self, score_func, sign, kwargs): self._kwargs = kwargs self._score_func = score_func @@ -190,7 +221,7 @@ def __repr__(self): kwargs_string, ) - def __call__(self, estimator, X, y_true, sample_weight=None): + def __call__(self, estimator, X, y_true, **kwargs): """Evaluate predicted target values for X relative to y_true. Parameters @@ -205,29 +236,43 @@ def __call__(self, estimator, X, y_true, sample_weight=None): y_true : array-like Gold standard target values for X. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer. + + .. versionadded:: 1.1 Returns ------- score : float Score function applied to prediction of estimator on X. """ - return self._score( - partial(_cached_call, None), - estimator, - X, - y_true, - sample_weight=sample_weight, - ) + return self._score(partial(_cached_call, None), estimator, X, y_true, **kwargs) def _factory_args(self): """Return non-default make_scorer arguments for repr.""" return "" + def score_requests(self, **kwargs): + """Set requested parameters by the scorer. + + Note that this method returns a new instance of the scorer, and does + **not** change the original scorer object. + + .. versionadded:: 1.1 + + Parameters + ---------- + kwargs : dict + Arguments should be of the form param_name={True, False, None, str}. + The value can also be of the form RequestType + """ + res = copy.deepcopy(self) + res._metadata_request = MetadataRequest({"score": kwargs}) + return res + class _PredictScorer(_BaseScorer): - def _score(self, method_caller, estimator, X, y_true, sample_weight=None): + def _score(self, method_caller, estimator, X, y_true, **kwargs): """Evaluate predicted target values for X relative to y_true. Parameters @@ -246,8 +291,10 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): y_true : array-like Gold standard target values for X. - sample_weight : array-like of shape (n_samples,), default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer. + + .. 
versionadded:: 1.1 Returns ------- @@ -256,16 +303,13 @@ def _score(self, method_caller, estimator, X, y_true, sample_weight=None): """ y_pred = method_caller(estimator, "predict", X) - if sample_weight is not None: - return self._sign * self._score_func( - y_true, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y_true, y_pred, **self._kwargs) + scoring_kwargs = copy.deepcopy(self._kwargs) + scoring_kwargs.update(kwargs) + return self._sign * self._score_func(y_true, y_pred, **scoring_kwargs) class _ProbaScorer(_BaseScorer): - def _score(self, method_caller, clf, X, y, sample_weight=None): + def _score(self, method_caller, clf, X, y, **kwargs): """Evaluate predicted probabilities for X relative to y_true. Parameters @@ -285,8 +329,10 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): Gold standard target values for X. These must be class labels, not probabilities. - sample_weight : array-like, default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer. + + .. versionadded:: 1.1 Returns ------- @@ -301,19 +347,22 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): # problem: (when only 2 class are given to `y_true` during scoring) # Thus, we need to check for the shape of `y_pred`. y_pred = self._select_proba_binary(y_pred, clf.classes_) - if sample_weight is not None: - return self._sign * self._score_func( - y, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y, y_pred, **self._kwargs) + + scoring_kwargs = copy.deepcopy(self._kwargs) + scoring_kwargs.update(kwargs) + # this is for backward compatibility to avoid passing sample_weight + # to the scorer if it's None + if scoring_kwargs.get("sample_weight", -1) is None: + del scoring_kwargs["sample_weight"] + + return self._sign * self._score_func(y, y_pred, **scoring_kwargs) def _factory_args(self): return ", needs_proba=True" class _ThresholdScorer(_BaseScorer): - def _score(self, method_caller, clf, X, y, sample_weight=None): + def _score(self, method_caller, clf, X, y, **kwargs): """Evaluate decision function output for X relative to y_true. Parameters @@ -335,8 +384,10 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): Gold standard target values for X. These must be class labels, not decision function values. - sample_weight : array-like, default=None - Sample weights. + **kwargs : dict + Other parameters passed to the scorer. + + .. 
versionadded:: 1.1 Returns ------- @@ -373,12 +424,13 @@ def _score(self, method_caller, clf, X, y, sample_weight=None): elif isinstance(y_pred, list): y_pred = np.vstack([p[:, -1] for p in y_pred]).T - if sample_weight is not None: - return self._sign * self._score_func( - y, y_pred, sample_weight=sample_weight, **self._kwargs - ) - else: - return self._sign * self._score_func(y, y_pred, **self._kwargs) + scoring_kwargs = copy.deepcopy(self._kwargs) + scoring_kwargs.update(kwargs) + # this is for backward compatibility to avoid passing sample_weight + # to the scorer if it's None + if scoring_kwargs.get("sample_weight", -1) is None: + del scoring_kwargs["sample_weight"] + return self._sign * self._score_func(y, y_pred, **scoring_kwargs) def _factory_args(self): return ", needs_threshold=True" @@ -413,9 +465,30 @@ def get_scorer(scoring): return scorer -def _passthrough_scorer(estimator, *args, **kwargs): - """Function that wraps estimator.score""" - return estimator.score(*args, **kwargs) +class _passthrough_scorer: + def __init__(self, estimator): + self._estimator = estimator + + def __call__(self, estimator, *args, **kwargs): + """Function that wraps estimator.score""" + return estimator.score(*args, **kwargs) + + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + router = MetadataRouter().add( + self._estimator, mapping={"score": "score"}, mask=False + ) + return router.get_metadata_request() def check_scoring(estimator, scoring=None, *, allow_none=False): @@ -470,7 +543,7 @@ def check_scoring(estimator, scoring=None, *, allow_none=False): return get_scorer(scoring) elif scoring is None: if hasattr(estimator, "score"): - return _passthrough_scorer + return _passthrough_scorer(estimator) elif allow_none: return None else: @@ -660,6 +733,8 @@ def make_scorer( output of :term:`decision_function` or :term:`predict_proba` when :term:`decision_function` is not present. 
""" + if "score_params" in kwargs: + raise Exception("aaaaaaagh") sign = 1 if greater_is_better else -1 if needs_proba and needs_threshold: raise ValueError( @@ -671,7 +746,8 @@ def make_scorer( cls = _ThresholdScorer else: cls = _PredictScorer - return cls(score_func, sign, kwargs) + res = cls(score_func, sign, kwargs) + return res # Standard regression scores @@ -686,13 +762,16 @@ def make_scorer( mean_absolute_error, greater_is_better=False ) neg_mean_absolute_percentage_error_scorer = make_scorer( - mean_absolute_percentage_error, greater_is_better=False + mean_absolute_percentage_error, + greater_is_better=False, ) neg_median_absolute_error_scorer = make_scorer( median_absolute_error, greater_is_better=False ) neg_root_mean_squared_error_scorer = make_scorer( - mean_squared_error, greater_is_better=False, squared=False + mean_squared_error, + greater_is_better=False, + squared=False, ) neg_mean_poisson_deviance_scorer = make_scorer( mean_poisson_deviance, greater_is_better=False @@ -708,28 +787,42 @@ def make_scorer( # Score functions that need decision values top_k_accuracy_scorer = make_scorer( - top_k_accuracy_score, greater_is_better=True, needs_threshold=True + top_k_accuracy_score, + greater_is_better=True, + needs_threshold=True, ) roc_auc_scorer = make_scorer( - roc_auc_score, greater_is_better=True, needs_threshold=True + roc_auc_score, + greater_is_better=True, + needs_threshold=True, ) average_precision_scorer = make_scorer(average_precision_score, needs_threshold=True) roc_auc_ovo_scorer = make_scorer(roc_auc_score, needs_proba=True, multi_class="ovo") roc_auc_ovo_weighted_scorer = make_scorer( - roc_auc_score, needs_proba=True, multi_class="ovo", average="weighted" + roc_auc_score, + needs_proba=True, + multi_class="ovo", + average="weighted", ) roc_auc_ovr_scorer = make_scorer(roc_auc_score, needs_proba=True, multi_class="ovr") roc_auc_ovr_weighted_scorer = make_scorer( - roc_auc_score, needs_proba=True, multi_class="ovr", average="weighted" + roc_auc_score, + needs_proba=True, + multi_class="ovr", + average="weighted", ) # Score function for probabilistic classification neg_log_loss_scorer = make_scorer(log_loss, greater_is_better=False, needs_proba=True) neg_brier_score_scorer = make_scorer( - brier_score_loss, greater_is_better=False, needs_proba=True + brier_score_loss, + greater_is_better=False, + needs_proba=True, ) brier_score_loss_scorer = make_scorer( - brier_score_loss, greater_is_better=False, needs_proba=True + brier_score_loss, + greater_is_better=False, + needs_proba=True, ) diff --git a/sklearn/metrics/tests/test_score_objects.py b/sklearn/metrics/tests/test_score_objects.py index 65d8efebe775f..b62f952e04e2f 100644 --- a/sklearn/metrics/tests/test_score_objects.py +++ b/sklearn/metrics/tests/test_score_objects.py @@ -234,7 +234,7 @@ def check_scoring_validator_for_single_metric_usecases(scoring_validator): estimator = EstimatorWithFitAndScore() estimator.fit([[1]], [1]) scorer = scoring_validator(estimator) - assert scorer is _passthrough_scorer + assert isinstance(scorer, _passthrough_scorer) assert_almost_equal(scorer(estimator, [[1]], [1]), 1.0) estimator = EstimatorWithFitAndPredict() @@ -626,11 +626,14 @@ def test_classification_scorer_sample_weight(): else: target = y_test try: + scorer = scorer.score_requests(sample_weight=True) weighted = scorer( estimator[name], X_test, target, sample_weight=sample_weight ) ignored = scorer(estimator[name], X_test[10:], target[10:]) unweighted = scorer(estimator[name], X_test, target) + # this should not raise. 
sample_weight should be ignored if None. + _ = scorer(estimator[name], X_test[:10], target[:10], sample_weight=None) assert weighted != unweighted, ( f"scorer {name} behaves identically when called with " f"sample weights: {weighted} vs {unweighted}" @@ -674,6 +677,7 @@ def test_regression_scorer_sample_weight(): # skip classification scorers continue try: + scorer = scorer.score_requests(sample_weight=True) weighted = scorer(reg, X_test, y_test, sample_weight=sample_weight) ignored = scorer(reg, X_test[11:], y_test[11:]) unweighted = scorer(reg, X_test, y_test) diff --git a/sklearn/model_selection/_search.py b/sklearn/model_selection/_search.py index c13e5b6643ce1..4163acd4aaf4b 100644 --- a/sklearn/model_selection/_search.py +++ b/sklearn/model_selection/_search.py @@ -35,16 +35,68 @@ from ..exceptions import NotFittedError from joblib import Parallel from ..utils import check_random_state +from ..utils import MetadataRouter +from ..utils import metadata_request_factory +from ..utils.metadata_requests import METHODS from ..utils.random import sample_without_replacement from ..utils._tags import _safe_tags from ..utils.validation import indexable, check_is_fitted, _check_fit_params from ..utils.metaestimators import available_if from ..utils.fixes import delayed from ..metrics._scorer import _check_multimetric_scoring +from ..metrics._scorer import _MultimetricScorer from ..metrics import check_scoring from ..utils import deprecated -__all__ = ["GridSearchCV", "ParameterGrid", "ParameterSampler", "RandomizedSearchCV"] +__all__ = [ + "GridSearchCV", + "ParameterGrid", + "ParameterSampler", + "RandomizedSearchCV", + "CVMetadataRequester", +] + + +class CVMetadataRequester: + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + # if the *CV estimator doesn't take any scoring function, we take it from the + # estimator. + scoring = getattr(self, "scoring", None) + if callable(scoring): + scorers = [scoring] + elif scoring is None or isinstance(scoring, str): + scorers = [check_scoring(self.estimator, scoring)] + else: + scorers = _check_multimetric_scoring(self.estimator, scoring).values() + + router = ( + MetadataRouter() + .add(*scorers, mapping={"fit": "score", "score": "score"}, mask=True) + .add( + self.estimator, + mapping={m: m for m in METHODS if m != "score"}, + mask=True, + overwrite="smart", + ) + .add( + check_cv(self.cv), + mapping={"fit": "split"}, + mask=True, + overwrite="smart", + ) + ) + return router.get_metadata_request() class ParameterGrid: @@ -379,7 +431,9 @@ def check(self): return check -class BaseSearchCV(MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta): +class BaseSearchCV( + CVMetadataRequester, MetaEstimatorMixin, BaseEstimator, metaclass=ABCMeta +): """Abstract base class for hyper parameter search with cross-validation.""" @abstractmethod @@ -784,23 +838,47 @@ def fit(self, X, y=None, *, groups=None, **fit_params): self : object Instance of fitted estimator. 
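# A hedged usage sketch of the routing composed by CVMetadataRequester and
# consumed in BaseSearchCV.fit above, mirroring test_invalid_arg_given added
# later in this diff; it assumes the fit_requests/score_requests API of this
# branch.
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, make_scorer
from sklearn.model_selection import GridSearchCV, GroupKFold

X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)
my_weights = np.random.rand(100)
my_groups = np.random.randint(0, 10, size=100)

weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True)
model = LogisticRegression().fit_requests(sample_weight=True)
gs = GridSearchCV(
    estimator=model,
    param_grid={"C": [0.1, 1]},
    cv=GroupKFold(),
    scoring=weighted_acc,
)
# sample_weight is routed to the estimator and the scorer, groups to the
# GroupKFold splitter; an unknown key would raise
# "Metadata passed which is not understood".
gs.fit(X, y, sample_weight=my_weights, groups=my_groups)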
""" + if groups is not None: + fit_params.update({"groups": groups}) estimator = self.estimator refit_metric = "score" if callable(self.scoring): - scorers = self.scoring + scorers = score_router = self.scoring elif self.scoring is None or isinstance(self.scoring, str): - scorers = check_scoring(self.estimator, self.scoring) + scorers = score_router = check_scoring(self.estimator, self.scoring) else: scorers = _check_multimetric_scoring(self.estimator, self.scoring) self._check_refit_for_multimetric(scorers) refit_metric = self.refit - - X, y, groups = indexable(X, y, groups) - fit_params = _check_fit_params(X, fit_params) + score_router = _MultimetricScorer(**scorers) cv_orig = check_cv(self.cv, y, classifier=is_classifier(estimator)) - n_splits = cv_orig.get_n_splits(X, y, groups) + + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=fit_params + ) + _fit_params = metadata_request_factory(estimator).fit.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + _score_params = metadata_request_factory(score_router).score.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + _cv_params = metadata_request_factory(cv_orig).split.get_method_input( + ignore_extras=True, kwargs=fit_params + ) + + _cv_param_values = _cv_params.values() + _cv_param_names = _cv_params.keys() + indexables = indexable(X, y, *_cv_param_values) + X, y = indexables[0], indexables[1] + _cv_param_values = indexables[2:] if len(indexables) > 2 else [] + _cv_params = { + name: value for name, value in zip(_cv_param_names, _cv_param_values) + } + _fit_params = _check_fit_params(X, _fit_params) + + n_splits = cv_orig.get_n_splits(X, y, **_cv_params) base_estimator = clone(self.estimator) @@ -808,7 +886,8 @@ def fit(self, X, y=None, *, groups=None, **fit_params): fit_and_score_kwargs = dict( scorer=scorers, - fit_params=fit_params, + fit_params=_fit_params, + score_params=_score_params, return_train_score=self.return_train_score, return_n_test_samples=True, return_times=True, @@ -923,9 +1002,9 @@ def evaluate_candidates(candidate_params, cv=None, more_results=None): ) refit_start_time = time.time() if y is not None: - self.best_estimator_.fit(X, y, **fit_params) + self.best_estimator_.fit(X, y, **_fit_params) else: - self.best_estimator_.fit(X, **fit_params) + self.best_estimator_.fit(X, **_fit_params) refit_end_time = time.time() self.refit_time_ = refit_end_time - refit_start_time diff --git a/sklearn/model_selection/_split.py b/sklearn/model_selection/_split.py index 7ef00124b5038..84026efcc1c45 100644 --- a/sklearn/model_selection/_split.py +++ b/sklearn/model_selection/_split.py @@ -28,7 +28,8 @@ from ..utils.validation import _num_samples, column_or_1d from ..utils.validation import check_array from ..utils.multiclass import type_of_target -from ..base import _pprint +from ..utils.metadata_requests import RequestType +from ..base import _pprint, _MetadataRequester __all__ = [ "BaseCrossValidator", @@ -51,12 +52,25 @@ ] -class BaseCrossValidator(metaclass=ABCMeta): +class GroupsComsumerMixin(_MetadataRequester): + """A Mixin to add support for ``groups`` + + .. versionadded:: 1.1 + """ + + _metadata_request__groups = {"split": {"groups": RequestType.REQUESTED}} + + +class BaseCrossValidator(_MetadataRequester, metaclass=ABCMeta): """Base class for all cross-validators Implementations must define `_iter_test_masks` or `_iter_test_indices`. 
""" + # This indicates that by default CV splitters don't have a "groups" kwarg, + # unless indicated by inheriting from ``GroupsComsumerMixin``. + _metadata_request__groups = {"split": {"groups": RequestType.UNUSED}} + def split(self, X, y=None, groups=None): """Generate indices to split data into training and test set. @@ -450,7 +464,7 @@ def _iter_test_indices(self, X, y=None, groups=None): current = stop -class GroupKFold(_BaseKFold): +class GroupKFold(GroupsComsumerMixin, _BaseKFold): """K-fold iterator variant with non-overlapping groups. The same group will not appear in two different folds (the number of @@ -1098,7 +1112,7 @@ def split(self, X, y=None, groups=None): ) -class LeaveOneGroupOut(BaseCrossValidator): +class LeaveOneGroupOut(GroupsComsumerMixin, BaseCrossValidator): """Leave One Group Out cross-validator Provides train/test indices to split data according to a third-party @@ -1208,7 +1222,7 @@ def split(self, X, y=None, groups=None): return super().split(X, y, groups) -class LeavePGroupsOut(BaseCrossValidator): +class LeavePGroupsOut(GroupsComsumerMixin, BaseCrossValidator): """Leave P Group(s) Out cross-validator Provides train/test indices to split data according to a third-party @@ -1340,7 +1354,7 @@ def split(self, X, y=None, groups=None): return super().split(X, y, groups) -class _RepeatedSplits(metaclass=ABCMeta): +class _RepeatedSplits(_MetadataRequester, metaclass=ABCMeta): """Repeated splits for an arbitrary randomized CV splitter. Repeats splits for cross-validators n times with different randomization @@ -1554,9 +1568,13 @@ def __init__(self, *, n_splits=5, n_repeats=10, random_state=None): ) -class BaseShuffleSplit(metaclass=ABCMeta): +class BaseShuffleSplit(_MetadataRequester, metaclass=ABCMeta): """Base class for ShuffleSplit and StratifiedShuffleSplit""" + # This indicates that by default CV splitters don't have a "groups" kwarg, + # unless indicated by inheriting from ``GroupsComsumerMixin``. + _metadata_request__groups = {"split": {"groups": RequestType.UNUSED}} + def __init__( self, n_splits=10, *, test_size=None, train_size=None, random_state=None ): @@ -1721,7 +1739,7 @@ def _iter_indices(self, X, y=None, groups=None): yield ind_train, ind_test -class GroupShuffleSplit(ShuffleSplit): +class GroupShuffleSplit(GroupsComsumerMixin, ShuffleSplit): """Shuffle-Group(s)-Out cross-validation iterator Provides randomized train/test indices to split data according to a diff --git a/sklearn/model_selection/_validation.py b/sklearn/model_selection/_validation.py index 760418b7d8f54..5d5b09f69ef2b 100644 --- a/sklearn/model_selection/_validation.py +++ b/sklearn/model_selection/_validation.py @@ -26,6 +26,7 @@ from ..utils import indexable, check_random_state, _safe_indexing from ..utils.validation import _check_fit_params from ..utils.validation import _num_samples +from ..utils import metadata_request_factory from ..utils.fixes import delayed from ..utils.metaestimators import _safe_split from ..metrics import check_scoring @@ -56,6 +57,7 @@ def cross_validate( n_jobs=None, verbose=0, fit_params=None, + props=None, pre_dispatch="2*n_jobs", return_train_score=False, return_estimator=False, @@ -134,6 +136,11 @@ def cross_validate( fit_params : dict, default=None Parameters to pass to the fit method of the estimator. + props : dict, default=None + The metadata required to be passed to the underlying relevant methods. + + .. versionadded:: 1.1 + pre_dispatch : int or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. 
Reducing this number can be useful to avoid an @@ -250,17 +257,43 @@ def cross_validate( loss function. """ - X, y, groups = indexable(X, y, groups) - cv = check_cv(cv, y, classifier=is_classifier(estimator)) + if fit_params is not None: + warnings.warn( + "fit_params is deprecated and will be removed in version 1.3. " + "Please use props instead.", + FutureWarning, + ) + props = {} if props is None else props + props.update(fit_params) + if callable(scoring): - scorers = scoring + scorers = score_router = scoring elif scoring is None or isinstance(scoring, str): - scorers = check_scoring(estimator, scoring) + scorers = score_router = check_scoring(estimator, scoring) else: scorers = _check_multimetric_scoring(estimator, scoring) + score_router = _MultimetricScorer(**scorers) + props = {} if props is None else props + _fit_params = metadata_request_factory(estimator).fit.get_method_input( + ignore_extras=True, kwargs=props + ) + _score_params = metadata_request_factory(score_router).score.get_method_input( + ignore_extras=True, kwargs=props + ) + _cv_params = metadata_request_factory(cv).split.get_method_input( + ignore_extras=True, kwargs=props + ) + + _cv_param_values = _cv_params.values() + _cv_param_names = _cv_params.keys() + indexables = indexable(X, y, *_cv_param_values) + X, y = indexables[0], indexables[1] + _cv_param_values = indexables[2:] if len(indexables) > 2 else [] + _cv_params = {name: value for name, value in zip(_cv_param_names, _cv_param_values)} + _fit_params = _check_fit_params(X, _fit_params) # We clone the estimator to make sure that all the folds are # independent, and that it is pickle-able. parallel = Parallel(n_jobs=n_jobs, verbose=verbose, pre_dispatch=pre_dispatch) @@ -274,13 +307,14 @@ def cross_validate( test, verbose, None, - fit_params, + fit_params=_fit_params, + score_params=_score_params, return_train_score=return_train_score, return_times=True, return_estimator=return_estimator, error_score=error_score, ) - for train, test in cv.split(X, y, groups) + for train, test in cv.split(X, y, **_cv_params) ) _warn_about_fit_failures(results, error_score) @@ -383,6 +417,7 @@ def cross_val_score( n_jobs=None, verbose=0, fit_params=None, + props=None, pre_dispatch="2*n_jobs", error_score=np.nan, ): @@ -452,6 +487,11 @@ def cross_val_score( fit_params : dict, default=None Parameters to pass to the fit method of the estimator. + props : dict, default=None + The metadata required to be passed to the underlying relevant methods. + + .. versionadded:: 1.1 + pre_dispatch : int or str, default='2*n_jobs' Controls the number of jobs that get dispatched during parallel execution. Reducing this number can be useful to avoid an @@ -517,6 +557,7 @@ def cross_val_score( n_jobs=n_jobs, verbose=verbose, fit_params=fit_params, + props=props, pre_dispatch=pre_dispatch, error_score=error_score, ) @@ -533,6 +574,7 @@ def _fit_and_score( verbose, parameters, fit_params, + score_params, return_train_score=False, return_parameters=False, return_n_test_samples=False, @@ -587,6 +629,11 @@ def _fit_and_score( fit_params : dict or None Parameters that will be passed to ``estimator.fit``. + score_params : dict or None + Parameters that will be passed to the scorer. + + .. versionadded:: 1.1 + return_train_score : bool, default=False Compute and return score on training set. 
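# A minimal sketch of the request-based filtering used by cross_validate and
# _fit_and_score above: a method request keeps only the metadata that was
# requested, and ignore_extras=True silently drops everything else.
from sklearn.utils.metadata_requests import MethodMetadataRequest

fit_request = MethodMetadataRequest.from_dict({"sample_weight": True}, name="fit")
props = {"sample_weight": [0.5, 1.0], "groups": [0, 1]}
assert fit_request.get_method_input(ignore_extras=True, kwargs=props) == {
    "sample_weight": [0.5, 1.0]
}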
@@ -657,6 +704,9 @@ def _fit_and_score( # Adjust length of sample weights fit_params = fit_params if fit_params is not None else {} fit_params = _check_fit_params(X, fit_params, train) + score_params_train = score_params if score_params is not None else {} + score_params_train = _check_fit_params(X, score_params, train) + score_params_test = _check_fit_params(X, score_params, test) if parameters is not None: # clone after setting parameters in case any parameters @@ -700,10 +750,24 @@ def _fit_and_score( result["fit_error"] = None fit_time = time.time() - start_time - test_scores = _score(estimator, X_test, y_test, scorer, error_score) + test_scores = _score( + estimator, + X_test, + y_test, + scorer, + score_params=score_params_test, + error_score=error_score, + ) score_time = time.time() - start_time - fit_time if return_train_score: - train_scores = _score(estimator, X_train, y_train, scorer, error_score) + train_scores = _score( + estimator, + X_train, + y_train, + scorer, + score_params=score_params_train, + error_score=error_score, + ) if verbose > 1: total_time = score_time + fit_time @@ -745,21 +809,23 @@ def _fit_and_score( return result -def _score(estimator, X_test, y_test, scorer, error_score="raise"): +def _score(estimator, X_test, y_test, scorer, *, score_params, error_score="raise"): """Compute the score(s) of an estimator on a given test set. Will return a dict of floats if `scorer` is a dict, otherwise a single float is returned. """ + if score_params is None: + score_params = {} if isinstance(scorer, dict): # will cache method calls if needed. scorer() returns a dict scorer = _MultimetricScorer(**scorer) try: if y_test is None: - scores = scorer(estimator, X_test) + scores = scorer(estimator, X_test, **score_params) else: - scores = scorer(estimator, X_test, y_test) + scores = scorer(estimator, X_test, y_test, **score_params) except Exception: if error_score == "raise": raise @@ -1560,8 +1626,10 @@ def learning_curve( train, test, verbose, + # TODO: support score_params here, hint: cross_validate parameters=None, fit_params=fit_params, + score_params=None, return_train_score=True, error_score=error_score, return_times=return_times, @@ -1690,8 +1758,27 @@ def _incremental_fit_estimator( start_score = time.time() - test_scores.append(_score(estimator, X_test, y_test, scorer, error_score)) - train_scores.append(_score(estimator, X_train, y_train, scorer, error_score)) + # TODO: support score_params here + test_scores.append( + _score( + estimator, + X_test, + y_test, + scorer, + score_params=None, + error_score=error_score, + ) + ) + train_scores.append( + _score( + estimator, + X_train, + y_train, + scorer, + score_params=None, + error_score=error_score, + ) + ) score_time = time.time() - start_score score_times.append(score_time) @@ -1837,8 +1924,10 @@ def validation_curve( train, test, verbose, + # TODO: support score_params here, hint: cross_validate parameters={param_name: v}, fit_params=fit_params, + score_params=None, return_train_score=True, error_score=error_score, ) diff --git a/sklearn/model_selection/tests/test_search.py b/sklearn/model_selection/tests/test_search.py index 6960a17fb629b..5bf973f76e86b 100644 --- a/sklearn/model_selection/tests/test_search.py +++ b/sklearn/model_selection/tests/test_search.py @@ -23,6 +23,7 @@ MinimalRegressor, MinimalTransformer, ) +from sklearn.utils import MetadataRequest from sklearn.utils._mocking import CheckingClassifier, MockDataFrame from scipy.stats import bernoulli, expon, uniform @@ -50,6 +51,7 @@ from 
sklearn.model_selection._validation import FitFailedWarning +from sklearn.base import _MetadataRequester from sklearn.svm import LinearSVC, SVC from sklearn.tree import DecisionTreeRegressor from sklearn.tree import DecisionTreeClassifier @@ -75,7 +77,7 @@ # Neither of the following two estimators inherit from BaseEstimator, # to test hyperparameter search on user-defined classifiers. -class MockClassifier: +class MockClassifier(_MetadataRequester): """Dummy classifier to test the parameter search algorithms""" def __init__(self, foo_param=0): @@ -1878,7 +1880,13 @@ def _run_search(self, evaluate): attr[0].islower() and attr[-1:] == "_" and attr - not in {"cv_results_", "best_estimator_", "refit_time_", "classes_"} + not in { + "cv_results_", + "best_estimator_", + "refit_time_", + "classes_", + "scorer_", + } ): assert getattr(gscv, attr) == getattr(mycv, attr), ( "Attribute %s not equal" % attr @@ -1987,6 +1995,9 @@ def __call__(self, estimator, X, y): return np.nan return 1 + def get_metadata_request(self): + return MetadataRequest().to_dict() + grid = SearchCV( DecisionTreeClassifier(), scoring=FailingScorer(), @@ -2258,7 +2269,7 @@ def fit(self, X, y, r=None): def predict(self, X): return np.zeros(shape=(len(X))) - model = SearchCV(TestEstimator(), param_search) + model = SearchCV(TestEstimator().fit_requests(r=True), param_search) X, y = make_classification(random_state=42) model.fit(X, y, r=42) assert model.best_estimator_.r_ == 42 @@ -2305,7 +2316,12 @@ def fit( def _fit_param_callable(): pass - model = SearchCV(_FitParamClassifier(), param_search) + model = SearchCV( + _FitParamClassifier().fit_requests( + tuple_of_arrays=True, callable_param=True, scalar_param=True + ), + param_search, + ) # NOTE: `fit_params` should be data dependent (e.g. `sample_weight`) which # is not the case for the following parameters. 
But this abuse is common in diff --git a/sklearn/model_selection/tests/test_split.py b/sklearn/model_selection/tests/test_split.py index af8338792ad73..9bfb8066219e2 100644 --- a/sklearn/model_selection/tests/test_split.py +++ b/sklearn/model_selection/tests/test_split.py @@ -1783,7 +1783,7 @@ def test_nested_cv(): Ridge(), param_grid={"alpha": [1, 0.1]}, cv=inner_cv, error_score="raise" ) cross_val_score( - gs, X=X, y=y, groups=groups, cv=outer_cv, fit_params={"groups": groups} + gs, X=X, y=y, groups=groups, cv=outer_cv, props={"groups": groups} ) diff --git a/sklearn/model_selection/tests/test_validation.py b/sklearn/model_selection/tests/test_validation.py index 215ceb5877669..20975dd70ab56 100644 --- a/sklearn/model_selection/tests/test_validation.py +++ b/sklearn/model_selection/tests/test_validation.py @@ -693,7 +693,7 @@ def assert_fit_params(clf): "dummy_obj": DUMMY_OBJ, "callback": assert_fit_params, } - cross_val_score(clf, X, y, fit_params=fit_params) + cross_val_score(clf, X, y, props=fit_params) def test_cross_val_score_score_func(): @@ -1155,7 +1155,7 @@ def test_cross_val_score_sparse_fit_params(): X, y = iris.data, iris.target clf = MockClassifier() fit_params = {"sparse_sample_weight": coo_matrix(np.eye(X.shape[0]))} - a = cross_val_score(clf, X, y, fit_params=fit_params, cv=3) + a = cross_val_score(clf, X, y, props=fit_params, cv=3) assert_array_equal(a, np.ones(3)) @@ -2080,7 +2080,7 @@ def test_fit_and_score_failing(): # dummy X data X = np.arange(1, 10) y = np.ones(9) - fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None] + fit_and_score_args = [failing_clf, X, None, dict(), None, None, 0, None, None, None] # passing error score to trigger the warning message fit_and_score_kwargs = {"error_score": "raise"} # check if exception was raised, with default error_score='raise' @@ -2124,6 +2124,7 @@ def test_fit_and_score_working(): fit_and_score_kwargs = { "parameters": {"max_iter": 100, "tol": 0.1}, "fit_params": None, + "score_params": None, "return_parameters": True, } result = _fit_and_score(*fit_and_score_args, **fit_and_score_kwargs) @@ -2293,7 +2294,7 @@ def test_fit_and_score_verbosity( train, test = next(ShuffleSplit().split(X)) # test print without train score - fit_and_score_args = [clf, X, y, scorer, train, test, verbose, None, None] + fit_and_score_args = [clf, X, y, scorer, train, test, verbose, None, None, None] fit_and_score_kwargs = { "return_train_score": train_score, "split_progress": split_prg, @@ -2314,9 +2315,15 @@ def test_score(): def two_params_scorer(estimator, X_test): return None - fit_and_score_args = [None, None, None, two_params_scorer] + fit_and_score_args = { + "estimator": None, + "X_test": None, + "y_test": None, + "scorer": two_params_scorer, + "score_params": None, + } with pytest.raises(ValueError, match=error_message): - _score(*fit_and_score_args, error_score=np.nan) + _score(**fit_and_score_args, error_score=np.nan) def test_callable_multimetric_confusion_matrix_cross_validate(): @@ -2357,3 +2364,11 @@ def _more_tags(self): msg = "_pairwise was deprecated in 0.24 and will be removed in 1.1" with pytest.warns(FutureWarning, match=msg): cross_validate(svm, linear_kernel, y, cv=2) + + +def test_cross_validate_deprecated_fit_params(): + X, y = make_classification(random_state=0) + estimator = MockClassifier() + + with pytest.warns(FutureWarning, match="fit_params is deprecated"): + cross_validate(estimator, X, y, fit_params={"param": "value"}) diff --git a/sklearn/pipeline.py b/sklearn/pipeline.py index 
03c5a709d1f82..1c5b61265e213 100644 --- a/sklearn/pipeline.py +++ b/sklearn/pipeline.py @@ -9,6 +9,7 @@ # Lars Buitinck # License: BSD +import warnings from collections import defaultdict from itertools import islice @@ -27,10 +28,11 @@ from .utils.deprecation import deprecated from .utils._tags import _safe_tags from .utils.validation import check_memory +from .utils import MetadataRouter +from .utils import metadata_request_factory from .utils.validation import check_is_fitted from .utils.fixes import delayed from .exceptions import NotFittedError - from .utils.metaestimators import _BaseComposition __all__ = ["Pipeline", "FeatureUnion", "make_pipeline", "make_union"] @@ -294,19 +296,64 @@ def _log_message(self, step_idx): return "(step %d of %d) Processing %s" % (step_idx + 1, len(self.steps), name) + def get_metadata_request(self): + """Get requested data properties. + + .. versionadded:: 1.1 + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + _, estimators = zip(*self.steps) + return ( + MetadataRouter() + .add(*estimators, mapping="one-to-one", overwrite="smart", mask=True) + .get_metadata_request() + ) + def _check_fit_params(self, **fit_params): fit_params_steps = {name: {} for name, step in self.steps if step is not None} - for pname, pval in fit_params.items(): - if "__" not in pname: + # Remove old_behavior in 1.3 + old_behavior = np.any(["__" in pname for pname in fit_params.keys()]) + if old_behavior: + warnings.warn( + "It seems fit_params are using the deprecated " + "way to route parameters using '__'. This behavior " + "is deprecated in 0.24 and won't be accepted in " + "0.26. Please use `set_request_metadata` to route " + "parameters.", + FutureWarning, + ) + for pname, pval in fit_params.items(): + if "__" not in pname: + raise ValueError( + "Pipeline.fit does not accept the {} parameter. " + "You can pass parameters to specific steps of your " + "pipeline using the stepname__parameter format, e.g. " + "`Pipeline.fit(X, y, logisticregression__sample_weight" + "=sample_weight)`.".format(pname) + ) + step, param = pname.split("__", 1) + fit_params_steps[step][param] = pval + return fit_params_steps + + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=fit_params + ) + for _, name, transformer in self._iter(filter_passthrough=True): + try: + fit_params_steps[name] = metadata_request_factory( + transformer + ).fit.get_method_input(ignore_extras=True, kwargs=fit_params) + except Exception as e: raise ValueError( - "Pipeline.fit does not accept the {} parameter. " - "You can pass parameters to specific steps of your " - "pipeline using the stepname__parameter format, e.g. " - "`Pipeline.fit(X, y, logisticregression__sample_weight" - "=sample_weight)`.".format(pname) + f"Error while validating fit parameters for {name}. 
" + f"The underlying error message: {e}" ) - step, param = pname.split("__", 1) - fit_params_steps[step][param] = pval return fit_params_steps # Estimator interface diff --git a/sklearn/tests/test_pipeline.py b/sklearn/tests/test_pipeline.py index fa01b6e834b11..ce9a843fb407d 100644 --- a/sklearn/tests/test_pipeline.py +++ b/sklearn/tests/test_pipeline.py @@ -273,17 +273,18 @@ def test_pipeline_methods_anova(): def test_pipeline_fit_params(): # Test that the pipeline can take fit parameters pipe = Pipeline([("transf", Transf()), ("clf", FitParamT())]) - pipe.fit(X=None, y=None, clf__should_succeed=True) + with pytest.warns(FutureWarning, match="It seems fit_params are"): + pipe.fit(X=None, y=None, clf__should_succeed=True) # classifier should return True assert pipe.predict(None) # and transformer params should not be changed assert pipe.named_steps["transf"].a is None assert pipe.named_steps["transf"].b is None # invalid parameters should raise an error message - msg = re.escape("fit() got an unexpected keyword argument 'bad'") with pytest.raises(TypeError, match=msg): - pipe.fit(None, None, clf__bad=True) + with pytest.warns(FutureWarning, match="It seems fit_params are"): + pipe.fit(None, None, clf__bad=True) def test_pipeline_sample_weight_supported(): @@ -442,9 +443,10 @@ def test_fit_predict_with_intermediate_fit_params(): # tests that Pipeline passes fit_params to intermediate steps # when fit_predict is invoked pipe = Pipeline([("transf", TransfFitParams()), ("clf", FitParamT())]) - pipe.fit_predict( - X=None, y=None, transf__should_get_this=True, clf__should_succeed=True - ) + with pytest.warns(FutureWarning, match="It seems fit_params are using "): + pipe.fit_predict( + X=None, y=None, transf__should_get_this=True, clf__should_succeed=True + ) assert pipe.named_steps["transf"].fit_params["should_get_this"] assert pipe.named_steps["clf"].successful assert "should_succeed" not in pipe.named_steps["transf"].fit_params @@ -1283,10 +1285,28 @@ def test_pipeline_feature_names_out_error_without_definition(): def test_pipeline_param_error(): clf = make_pipeline(LogisticRegression()) with pytest.raises( - ValueError, match="Pipeline.fit does not accept the sample_weight parameter" + ValueError, + match="sample_weight is passed but is not explicitly set as requested or not", ): clf.fit([[0], [0]], [0, 1], sample_weight=[1, 1]) + # until the deprecation, the pipeline uses the "old" style or routing if at + # least one parameter has "__" in it, and therefore `sample_weight` would + # not be accepted + with pytest.raises( + ValueError, + match="Pipeline.fit does not accept the sample_weight parameter", + ): + with pytest.warns( + FutureWarning, match="It seems fit_params are using the deprecated way" + ): + clf.fit( + [[0], [0]], + [0, 1], + sample_weight=[1, 1], + logisticregression__sample_weight=[1, 1], + ) + parameter_grid_test_verbose = ( (est, pattern, method) diff --git a/sklearn/tests/test_props.py b/sklearn/tests/test_props.py new file mode 100644 index 0000000000000..bb7c60a17c372 --- /dev/null +++ b/sklearn/tests/test_props.py @@ -0,0 +1,717 @@ +import re +import numpy as np +import pytest + +from sklearn.base import BaseEstimator, ClassifierMixin, TransformerMixin +from sklearn.datasets import make_classification +from sklearn.feature_selection import SelectKBest +from sklearn.metrics import balanced_accuracy_score +from sklearn.metrics import accuracy_score +from sklearn.metrics import make_scorer +from sklearn.linear_model import LogisticRegression +from sklearn.linear_model 
import LogisticRegressionCV +from sklearn.model_selection import GridSearchCV +from sklearn.model_selection import cross_validate +from sklearn.model_selection._split import check_cv +from sklearn.pipeline import make_pipeline +from sklearn.svm import SVC +from sklearn.utils import MetadataRequest +from sklearn.utils.metadata_requests import RequestType +from sklearn.utils.metadata_requests import metadata_request_factory +from sklearn.utils.metadata_requests import MetadataRouter +from sklearn.utils.metadata_requests import MethodMetadataRequest +from sklearn.model_selection import KFold +from sklearn.model_selection import GroupKFold +from sklearn.model_selection import StratifiedKFold +from sklearn.model_selection import TimeSeriesSplit +from sklearn.model_selection import LeaveOneGroupOut +from sklearn.model_selection import LeaveOneOut +from sklearn.model_selection import LeavePGroupsOut +from sklearn.model_selection import LeavePOut +from sklearn.model_selection import RepeatedKFold +from sklearn.model_selection import RepeatedStratifiedKFold +from sklearn.model_selection import ShuffleSplit +from sklearn.model_selection import GroupShuffleSplit +from sklearn.model_selection import StratifiedShuffleSplit +from sklearn.model_selection import PredefinedSplit + +from sklearn.base import _MetadataRequester + +NonGroupCVs = [ + KFold, + LeaveOneOut, + LeavePOut, + RepeatedStratifiedKFold, + RepeatedKFold, + ShuffleSplit, + StratifiedKFold, + StratifiedShuffleSplit, + PredefinedSplit, + TimeSeriesSplit, +] + +GroupCVs = [ + GroupKFold, + LeaveOneGroupOut, + LeavePGroupsOut, + GroupShuffleSplit, +] + + +N, M = 100, 4 +X = np.random.rand(N, M) +y = np.random.randint(0, 2, size=N) +my_groups = np.random.randint(0, 10, size=N) +my_weights = np.random.rand(N) +my_other_weights = np.random.rand(N) + + +def assert_request_is_empty(metadata_request, exclude=None): + if isinstance(metadata_request, MetadataRequest): + metadata_request = metadata_request.to_dict() + if exclude is None: + exclude = [] + for method, request in metadata_request.items(): + if method in exclude: + continue + props = [ + prop + for prop, alias in request.items() + if isinstance(alias, str) + or RequestType(alias) != RequestType.ERROR_IF_PASSED + ] + assert not len(props) + + +class MyEst(ClassifierMixin, BaseEstimator): + _metadata_request__sample_weight = { + "fit": {"sample_weight": RequestType.REQUESTED} # type: ignore + } + _metadata_request__brand = {"fit": {"brand": RequestType.REQUESTED}} + + def __init__(self, C=1.0): + self.C = C + + def fit(self, X, y, **fit_params): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, kwargs=fit_params + ) + self.svc_ = SVC(C=self.C).fit(X, y) + return self + + def predict(self, X): + return self.svc_.predict(X) + + +class MyTrs(TransformerMixin, BaseEstimator): + def fit(self, X, y=None, brand=None, new_param=None, sample_weight=None): + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, + kwargs={ + "brand": brand, + "new_param": new_param, + "sample_weight": sample_weight, + }, + ) + self._estimator = SelectKBest().fit(X, y) + return self + + def transform(self, X, y=None): + return self._estimator.transform(X) + + +def my_metric(y, y_pred, new_param): + return balanced_accuracy_score(y, y_pred) + + +def test_defaults(): + assert_request_is_empty(LogisticRegression().get_metadata_request()) + # check default requests for dummy estimators + trs_request = metadata_request_factory(MyTrs()) + assert trs_request.fit.requests == 
{ + "sample_weight": RequestType(None), + "brand": RequestType(None), + "new_param": RequestType(None), + } + assert_request_is_empty(trs_request) + + est_request = metadata_request_factory(MyEst()) + assert est_request.fit.requests == { + "sample_weight": RequestType(True), + "brand": RequestType(True), + } + assert_request_is_empty(est_request, exclude={"fit"}) + + +def test_pipeline(): + X, y = make_classification() + sw = np.random.rand(len(X)) + my_data = [5, 6] + brand = ["my brand"] + + # MyEst is requesting "brand" but MyTrs has it as ERROR_IF_PASSED + with pytest.raises( + ValueError, + match="brand is passed but is not explicitly set as requested or not.", + ): + clf = make_pipeline(MyTrs(), MyEst()) + clf.fit(X, y, sample_weight=sw, brand=brand) + + clf = make_pipeline(MyTrs().fit_requests(brand=True, sample_weight=True), MyEst()) + clf.fit(X, y, sample_weight=sw, brand=brand) + + with pytest.raises(ValueError, match="Metadata passed which is not understood"): + clf.fit(X, y, sample_weight=sw, brand=brand, other_param=sw) + + trs = MyTrs().fit_requests(new_param="my_sw", sample_weight=False) + + trs_request = metadata_request_factory(trs) + assert trs_request.fit.requests == { + "new_param": "my_sw", + "brand": RequestType.ERROR_IF_PASSED, + "sample_weight": RequestType.UNREQUESTED, + } + assert_request_is_empty(trs_request, exclude={"fit"}) + + clf = make_pipeline(trs, MyEst()) + clf.get_metadata_request() + + clf = make_pipeline(MyTrs().fit_requests(new_param="my_sw"), MyEst()) + pipe_request = metadata_request_factory(clf) + assert pipe_request.fit.requests == { + "my_sw": RequestType.REQUESTED, + "sample_weight": RequestType.REQUESTED, + "brand": RequestType.REQUESTED, + } + assert_request_is_empty(pipe_request, exclude={"fit"}) + with pytest.raises( + ValueError, + match=( + "Error while validating fit parameters for mytrs. The underlying " + "error message: brand is passed but is not explicitly set as " + "requested or not." + ), + ): + clf.fit(X, y, sample_weight=sw, brand=brand, my_sw=my_data) + + clf.named_steps["mytrs"].fit_requests(brand=True, sample_weight=False) + clf.fit(X, y, sample_weight=sw, brand=brand, my_sw=my_data) + + # TODO: assert that trs did *not* receive sample_weight, but did receive + # my_sw + + # If requested metadata is not given, no warning or error is raised + with pytest.warns(None) as record: + clf.fit(X, y, brand=brand) + assert not record.list + + scorer = make_scorer(my_metric).score_requests(new_param=True) + + param_grid = {"myest__C": [0.1, 1]} + + gs = GridSearchCV(clf, param_grid=param_grid, scoring=scorer) + gs.get_metadata_request() + gs.fit(X, y, new_param=brand, sample_weight=sw, my_sw=sw, brand=brand) + + +def test_slep_caseA(): + # Case A: weighted scoring and fitting + + # Here we presume that GroupKFold requests `groups` by default. + # We need to explicitly request weights in make_scorer and for + # LogisticRegressionCV. Both of these consumers understand the meaning + # of the key "sample_weight". 
+ + weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True) + lr = LogisticRegressionCV( + cv=GroupKFold(), + scoring=weighted_acc, + ).fit_requests(sample_weight=True) + lr.fit(X, y, sample_weight=my_weights, groups=my_groups) + cross_validate( + lr, + X, + y, + cv=GroupKFold(), + props={"sample_weight": my_weights, "groups": my_groups}, + scoring=weighted_acc, + ) + + # Error handling: if props={'sample_eight': my_weights, ...} was passed, + # cross_validate would raise an error, since 'sample_eight' was not + # requested by any of its children. + + +def test_slep_caseB(): + # Case B: weighted scoring and unweighted fitting + + # Since LogisticRegressionCV requires that weights explicitly be requested, + # removing that request means the fitting is unweighted. + + weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True) + lr = LogisticRegressionCV( + cv=GroupKFold(), + scoring=weighted_acc, + ) + with pytest.raises( + ValueError, + match="sample_weight is passed but is not explicitly set as requested or not", + ): + cross_validate( + lr, + X, + y, + cv=GroupKFold(), + props={"sample_weight": my_weights, "groups": my_groups}, + scoring=weighted_acc, + ) + + lr.fit_requests(sample_weight=False) + cross_validate( + lr, + X, + y, + cv=GroupKFold(), + props={"sample_weight": my_weights, "groups": my_groups}, + scoring=weighted_acc, + ) + + +def test_slep_caseC(): + # Case C: unweighted feature selection + + # Like LogisticRegressionCV, SelectKBest needs to request weights + # explicitly. Here it does not request them. + + weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True) + lr = LogisticRegressionCV( + cv=GroupKFold(), + scoring=weighted_acc, + ).fit_requests(sample_weight=True) + sel = SelectKBest(k=2) + pipe = make_pipeline(sel, lr) + cross_validate( + pipe, + X, + y, + cv=GroupKFold(), + props={"sample_weight": my_weights, "groups": my_groups}, + scoring=weighted_acc, + ) + + +def test_slep_caseD(): + # Case D: different scoring and fitting weights + + # Despite make_scorer and LogisticRegressionCV both expecting a key + # sample_weight, we can use aliases to pass different weights to different + # consumers. 
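# A small sketch of the aliasing exercised by case D below, assuming the
# request API of this branch: the estimator keeps consuming sample_weight,
# but the router looks up the user-facing key fitting_weight and forwards it
# under the original name.
from sklearn.linear_model import LogisticRegression
from sklearn.utils import metadata_request_factory

lr = LogisticRegression().fit_requests(sample_weight="fitting_weight")
routed = metadata_request_factory(lr).fit.get_method_input(
    ignore_extras=True, kwargs={"fitting_weight": [0.5, 1.0]}
)
assert routed == {"sample_weight": [0.5, 1.0]}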
+ + weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True) + + lr = LogisticRegressionCV( + cv=GroupKFold(), + scoring=weighted_acc, + ).fit_requests(sample_weight="fitting_weight") + cross_validate( + lr, + X, + y, + cv=GroupKFold(), + props={ + "scoring_weight": my_weights, + "fitting_weight": my_other_weights, + "groups": my_groups, + }, + scoring=weighted_acc, + ) + + +@pytest.mark.parametrize("Klass", GroupCVs) +def test_group_splitter_metadata_requests(Klass): + if Klass is LeavePGroupsOut: + cv = Klass(n_groups=2) + else: + cv = Klass() + # check the default metadata_request + assert metadata_request_factory(cv).split.requests == { + "groups": RequestType.REQUESTED + } + + # test that setting split to False empties the metadata_request + cv.split_requests(groups=None) + assert_request_is_empty(cv.get_metadata_request()) + + # set a different input name and test + cv.split_requests(groups="my_groups") + assert metadata_request_factory(cv).split.requests == {"groups": "my_groups"} + + +@pytest.mark.parametrize("Klass", NonGroupCVs) +def test_nongroup_splitter_metadata_requests(Klass): + if Klass is LeavePOut: + cv = Klass(p=2) + elif Klass is PredefinedSplit: + cv = Klass(test_fold=[1, 1, 0]) + else: + cv = Klass() + + # check the default metadata_request + assert_request_is_empty(cv.get_metadata_request()) + + # test that setting split to False empties the metadata_request + assert not hasattr(cv, "request_groups") + + +def test_invalid_arg_given(): + # tests that passing an invalid argument would raise an error + weighted_acc = make_scorer(accuracy_score).score_requests(sample_weight=True) + model = LogisticRegression().fit_requests(sample_weight=True) + param_grid = {"C": [0.1, 1]} + gs = GridSearchCV( + estimator=model, + cv=GroupKFold(), + scoring=weighted_acc, + param_grid=param_grid, + ) + gs.get_metadata_request() + gs.fit(X, y, sample_weight=my_weights, groups=my_groups) + with pytest.raises(ValueError, match="Metadata passed which is not understood"): + gs.fit( + X, + y, + sample_weigh=my_weights, + groups=my_groups, + sample_weight=my_weights, + ) + + with pytest.raises(ValueError, match="Metadata passed which is not understood"): + gs.fit( + X, + y, + sample_weigh=my_weights, + groups=my_groups, + ) + + +def test_get_metadata_request(): + class TestDefaultsBadMetadataName(_MetadataRequester): + _metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + _metadata_request__my_param = { + "score": {"my_param": True}, + # the following method raise an error + "other_method": {"my_param": True}, + } + + _metadata_request__my_other_param = { + "score": "my_other_param", + # this should raise since the name is different than the metadata + "fit": "my_param", + } + + class TestDefaultsBadMethodName(_MetadataRequester): + _metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + _metadata_request__my_param = { + "score": {"my_param": True}, + # the following method raise an error + "other_method": {"my_param": True}, + } + + _metadata_request__my_other_param = { + "score": "my_other_param", + "fit": "my_other_param", + } + + class TestDefaults(_MetadataRequester): + _metadata_request__sample_weight = { + "fit": "sample_weight", + "score": "sample_weight", + } + + _metadata_request__my_param = { + "score": {"my_param": True}, + "predict": {"my_param": True}, + } + + _metadata_request__my_other_param = { + "score": "my_other_param", + "fit": "my_other_param", + } + + with 
pytest.raises(ValueError, match="Expected all metadata to be called"): + TestDefaultsBadMetadataName().get_metadata_request() + + with pytest.raises(ValueError, match="other_method is not supported as a method"): + TestDefaultsBadMethodName().get_metadata_request() + + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert TestDefaults().get_metadata_request() == expected + + est = TestDefaults().score_requests(my_param="other_param") + expected = { + "score": { + "my_param": "other_param", + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + est = TestDefaults().fit_requests(sample_weight=True) + expected = { + "score": { + "my_param": RequestType(True), + "my_other_param": RequestType(None), + "sample_weight": RequestType(None), + }, + "fit": { + "my_other_param": RequestType(None), + "sample_weight": RequestType(True), + }, + "partial_fit": {}, + "predict": {"my_param": RequestType(True)}, + "transform": {}, + "inverse_transform": {}, + "split": {}, + } + assert est.get_metadata_request() == expected + + +def test__get_default_requests(): + class ExplicitRequest(BaseEstimator): + _metadata_request__prop = {"fit": "prop"} + + def fit(self, X, y): + return self + + assert metadata_request_factory(ExplicitRequest()).fit.requests == { + "prop": RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ExplicitRequest().get_metadata_request(), exclude="fit") + + class ExplicitRequestOverwrite(BaseEstimator): + _metadata_request__prop = {"fit": {"prop": RequestType.REQUESTED}} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ExplicitRequestOverwrite()).fit.requests == { + "prop": RequestType.REQUESTED + } + assert_request_is_empty( + ExplicitRequestOverwrite().get_metadata_request(), exclude="fit" + ) + + class ImplicitRequest(BaseEstimator): + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequest()).fit.requests == { + "prop": RequestType.ERROR_IF_PASSED + } + assert_request_is_empty(ImplicitRequest().get_metadata_request(), exclude="fit") + + class ImplicitRequestRemoval(BaseEstimator): + _metadata_request__prop = {"fit": {"prop": RequestType.UNUSED}} + + def fit(self, X, y, prop=None, **kwargs): + return self + + assert metadata_request_factory(ImplicitRequestRemoval()).fit.requests == {} + assert_request_is_empty(ImplicitRequestRemoval().get_metadata_request()) + + +def test_validate(): + class ConsumerRouter(BaseEstimator): + def __init__(self, cv=None): + self.cv = cv + + def fit(self, X, y, sample_weight=None, **kwargs): + kwargs["sample_weight"] = sample_weight + metadata_request_factory(self).fit.validate_metadata( + ignore_extras=False, + self_metadata=super(), + kwargs=kwargs, + ) + return self + + def get_metadata_request(self): + router = ( + MetadataRouter() + .add(super(), mapping="one-to-one", overwrite=False, mask=False) + .add( + check_cv(self.cv), + 
mapping={"fit": "split"}, + overwrite=False, + mask=True, + ) + ) + return router.get_metadata_request() + + err_message = "Metadata passed which is not understood: {param}. In method: fit" + + est = ConsumerRouter() + est.fit(X=None, y=None) + est.fit(X=None, y=None, sample_weight="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["my_weight"])) + ): + est.fit(X=None, y=None, my_weight="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["my_weight"])) + ): + est.fit(X=None, y=None, sample_weight="test", my_weight="test") + + est = ConsumerRouter(cv=GroupKFold()) + est.fit(X=None, y=None, groups="test") + est.fit(X=None, y=None, sample_weight="test", groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["my_weight"])) + ): + est.fit(X=None, y=None, my_weight="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["my_weight"])) + ): + est.fit(X=None, y=None, sample_weight="test", my_weight="test") + + est = ConsumerRouter(cv=GroupKFold().split_requests(groups="my_groups")) + est.fit(X=None, y=None, sample_weight="test", my_groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, sample_weight="test", my_groups="test", groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, sample_weight="test", groups="test") + + est = ConsumerRouter( + cv=GroupKFold().split_requests(groups="my_groups") + ).fit_requests(sample_weight="my_weight") + est.fit(X=None, y=None, sample_weight="test", my_groups="test") + est.fit(X=None, y=None, my_groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, sample_weight="test", my_groups="test", groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups"])) + ): + est.fit(X=None, y=None, sample_weight="test", groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["my_weights"])) + ): + est.fit(X=None, y=None, my_weights="test", my_groups="test") + + with pytest.raises( + ValueError, match=re.escape(err_message.format(param=["groups", "my_weights"])) + ): + est.fit(X=None, y=None, my_weights="test", groups="test") + + +def test_method_metadata_request(): + mmr = MethodMetadataRequest(name="fit") + with pytest.raises( + ValueError, + match="overwrite can only be one of {True, False, 'smart', 'ignore'}.", + ): + mmr.add_request(prop="test", alias=None, overwrite="test") + + with pytest.raises(ValueError, match="Expected all metadata to be called test"): + mmr.add_request(prop="foo", alias="bar", expected_metadata="test") + + with pytest.raises(ValueError, match="Aliasing is not allowed"): + mmr.add_request(prop="foo", alias="bar", allow_aliasing=False) + + with pytest.raises(ValueError, match="alias should be either a string or"): + mmr.add_request(prop="foo", alias=1.4) + + mmr.add_request(prop="foo", alias=None) + assert mmr.requests == {"foo": RequestType.ERROR_IF_PASSED} + with pytest.raises(ValueError, match="foo is already 
requested"): + mmr.add_request(prop="foo", alias=True) + with pytest.raises(ValueError, match="foo is already requested"): + mmr.add_request(prop="foo", alias=True) + mmr.add_request(prop="foo", alias=True, overwrite="smart") + assert mmr.requests == {"foo": RequestType.REQUESTED} + + with pytest.raises(ValueError, match="Can only add another MethodMetadataRequest"): + mmr.merge_method_request({}) + + assert MethodMetadataRequest.from_dict(None, name="fit").requests == {} + assert MethodMetadataRequest.from_dict("foo", name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED + } + assert MethodMetadataRequest.from_dict(["foo", "bar"], name="fit").requests == { + "foo": RequestType.ERROR_IF_PASSED, + "bar": RequestType.ERROR_IF_PASSED, + } + + +def test_metadata_request_factory(): + class Consumer(BaseEstimator): + _metadata_request__prop = {"fit": "prop"} + + assert_request_is_empty(metadata_request_factory(None)) + assert_request_is_empty(metadata_request_factory({})) + assert_request_is_empty(metadata_request_factory(object())) + + mr = MetadataRequest({"fit": "foo"}, default="bar") + mr_factory = metadata_request_factory(mr) + assert_request_is_empty(mr_factory, exclude="fit") + assert mr_factory.fit.requests == {"foo": "bar"} + + mr = metadata_request_factory(Consumer()) + assert_request_is_empty(mr, exclude="fit") + assert mr.fit.requests == {"prop": RequestType.ERROR_IF_PASSED} diff --git a/sklearn/utils/__init__.py b/sklearn/utils/__init__.py index 8290318d35deb..30ef0b0c3cb7f 100644 --- a/sklearn/utils/__init__.py +++ b/sklearn/utils/__init__.py @@ -40,6 +40,10 @@ check_scalar, ) from .. import get_config +from .metadata_requests import MetadataRequest +from .metadata_requests import MethodMetadataRequest +from .metadata_requests import metadata_request_factory +from .metadata_requests import MetadataRouter # Do not deprecate parallel_backend and register_parallel_backend as they are @@ -74,6 +78,10 @@ "all_estimators", "DataConversionWarning", "estimator_html_repr", + "MetadataRequest", + "metadata_request_factory", + "MetadataRouter", + "MethodMetadataRequest", ] IS_PYPY = platform.python_implementation() == "PyPy" diff --git a/sklearn/utils/_mocking.py b/sklearn/utils/_mocking.py index 33a73f77d2d47..6980799bac053 100644 --- a/sklearn/utils/_mocking.py +++ b/sklearn/utils/_mocking.py @@ -2,6 +2,7 @@ from ..base import BaseEstimator, ClassifierMixin from .validation import _num_samples, check_array, check_is_fitted +from ..utils import MetadataRequest class ArraySlicingWrapper: @@ -133,6 +134,9 @@ def __init__( self.methods_to_check = methods_to_check self.foo_param = foo_param self.expected_fit_params = expected_fit_params + self._metadata_request = MetadataRequest( + {"fit": expected_fit_params}, default=True + ) def _check_X_y(self, X, y=None, should_be_fitted=True): """Validate X and y and make extra check. diff --git a/sklearn/utils/estimator_checks.py b/sklearn/utils/estimator_checks.py index 914b4e6168247..85b87655a4dc3 100644 --- a/sklearn/utils/estimator_checks.py +++ b/sklearn/utils/estimator_checks.py @@ -2941,6 +2941,8 @@ def check_no_attributes_set_in_init(name, estimator_orig): # Test for no setting apart from parameters during init invalid_attr = set(vars(estimator)) - set(init_params) - set(parents_init_params) + # Ignore private attributes + invalid_attr = set([attr for attr in invalid_attr if not attr.startswith("_")]) assert not invalid_attr, ( "Estimator %s should not set any attribute apart" " from parameters during init. Found attributes %s." 
diff --git a/sklearn/utils/metadata_requests.py b/sklearn/utils/metadata_requests.py new file mode 100644 index 0000000000000..da62af0b482ff --- /dev/null +++ b/sklearn/utils/metadata_requests.py @@ -0,0 +1,826 @@ +# from copy import deepcopy +import inspect +from enum import Enum +from collections import defaultdict +from typing import Union, Optional +from ..externals._sentinels import sentinel # type: ignore # mypy error!!! + + +class RequestType(Enum): + UNREQUESTED = False + REQUESTED = True + ERROR_IF_PASSED = None + # this sentinel is used in `_metadata_request__*` attributes to indicate + # that a metadata is not present even though it may be present in the + # corresponding method's signature. + UNUSED = sentinel("UNUSED") + + +# this sentinel is the default used in `{method}_requests` methods to indicate +# no change requested by the user. +UNCHANGED = sentinel("UNCHANGED") + +METHODS = [ + "fit", + "partial_fit", + "predict", + "score", + "split", + "transform", + "inverse_transform", +] + + +REQUESTER_DOC = """ Request metadata passed to the ``{method}`` method. + + Parameters + ---------- +""" +REQUESTER_DOC_PARAM = """ {metadata} : RequestType, str, True, False, or None, \ + default=UNCHANGED + Whether {metadata} should be passed to {method} by meta-estimators or + not, and if yes, should it have an alias. + + - True or RequestType.REQUESTED: {metadata} is requested, and passed to \ +{method} if provided. + + - False or RequestType.UNREQUESTED: {metadata} is not requested and the \ +meta-estimator will not pass it to {method}. + + - None or RequestType.ERROR_IF_PASSED: {metadata} is not requested, and \ +the meta-estimator will raise an error if the user provides {metadata} + + - str: {metadata} should be passed to the meta-estimator with this given \ +alias instead of the original name. + +""" +REQUESTER_DOC_RETURN = """ Returns + ------- + self + Returns the object itself. +""" + + +class MethodMetadataRequest: + """Contains the metadata request info for a single method. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method to which these requests belong. + """ + + def __init__(self, name): + self.requests = dict() + self.name = name + + def add_request( + self, + *, + prop, + alias, + allow_aliasing=True, + overwrite=False, + expected_metadata=None, + ): + """Add request info for a prop. + + Parameters + ---------- + prop : str + The property for which a request is set. + + alias : str, RequestType, or {True, False, None} + The alias which is routed to `prop` + + - str: the name which should be used as an alias when a meta-estimator + routes the metadata. + + - True or RequestType.REQUESTED: requested + + - False or RequestType.UNREQUESTED: not requested + + - None or RequestType.ERROR_IF_PASSED: error if passed + + allow_aliasing : bool, default=True + If False, alias should be the same as prop if it's a string. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. 
+ """ + if overwrite not in {True, False, "smart", "ignore"}: + raise ValueError( + "overwrite can only be one of {True, False, 'smart', 'ignore'}; " + f"but f{overwrite} is given." + ) + if expected_metadata is not None and expected_metadata != prop: + raise ValueError( + f"Expected all metadata to be called {expected_metadata} but " + f"{prop} was passed." + ) + if not allow_aliasing and isinstance(alias, str) and prop != alias: + raise ValueError( + "Aliasing is not allowed, prop and alias should " + "be the same strings if alias is a string." + ) + + if not isinstance(alias, str): + try: + alias = RequestType(alias) + except ValueError: + raise ValueError( + "alias should be either a string or one of " + "{None, True, False}, or a RequestType." + ) + + if alias == prop: + alias = RequestType.REQUESTED + + if alias == RequestType.UNUSED and prop in self.requests: + del self.requests[prop] + elif prop not in self.requests or overwrite is True: + self.requests[prop] = alias + elif prop in self.requests and overwrite == "ignore": + pass + elif overwrite == "smart": + current = self.requests[prop] + if isinstance(current, str): + raise ValueError( + "Cannot overwrite f{current} with f{alias} when overwrite=smart." + ) + current = RequestType(current) + + # REQUESTED > UNREQUESTED > ERROR_IF_PASSED + if alias == RequestType.REQUESTED and current in { + RequestType.ERROR_IF_PASSED, + RequestType.UNREQUESTED, + }: + self.requests[prop] = alias + elif ( + alias == RequestType.UNREQUESTED + and current == RequestType.ERROR_IF_PASSED + ): + self.requests[prop] = alias + elif self.requests[prop] != alias: + raise ValueError( + f"{prop} is already requested as {self.requests[prop]}, " + f"which is not the same as the one given: {alias}. Cannot " + "overwrite when overwrite=False." + ) + + def merge_method_request(self, other, overwrite=False, expected_metadata=None): + """Merge the metadata request info of two methods. + + The methods can be the same, or different. For example, merging + fit and score info of the same object, or merging fit request info + from two different sub estimators. + + Parameters + ---------- + other : MethodMetadataRequest + The other object to be merged with this instance. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if not isinstance(other, MethodMetadataRequest): + raise ValueError("Can only add another MethodMetadataRequest.") + for prop, alias in other.requests.items(): + self.add_request( + prop=prop, + alias=alias, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def validate_metadata(self, ignore_extras=False, self_metadata=None, kwargs=None): + """Validate the given arguments against the requested ones. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + self_metadata : MetadataRequest-like, default=None + This parameter can be anything which can be an input to + ``metadata_request_factory``. 
Only the part of the metadata which + is the same as ``name`` is used. + + Consumers don't validate their own metadata. Validation is always + done by routers (i.e. usually meta-estimators). But sometimes an + object is a consumer and a router, e.g. ``LogisticRegressionCV`` + which consumes ``sample_weight``, but also routes metadata to the + given scorer(s) and CV object, and therefore is also a router. In + such a case, ``sample_weight`` is the metadata being consumed. A + router can get its own required metadata, as opposed to the ones + required by its sub-objects, using + ``metadata_request_factory(super())``. ``validate_metadata`` then + uses the part which is relevant to this validation. Since this + object knows which method is relevant using its ``name``, passing + ``super()`` here would be sufficient. + + kwargs : dict + Provided metadata. + + Returns + ------- + None + """ + kwargs = {} if kwargs is None else kwargs + self_metadata = getattr( + metadata_request_factory(self_metadata), self.name + ).requests + # we then remove self metadata from kwargs, since they should not be + # validated. + kwargs = {v: k for v, k in kwargs.items() if v not in self_metadata} + args = {arg for arg, value in kwargs.items() if value is not None} + if not ignore_extras and args - set(self.requests.keys()): + raise ValueError( + "Metadata passed which is not understood: " + f"{sorted(args - set(self.requests.keys()))}. In method: " + f"{self.name}" + ) + + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + if alias == RequestType.UNREQUESTED: + continue + elif alias == RequestType.REQUESTED or isinstance(alias, str): + # we ignore what the given alias here is, since aliases are + # checked at the parent meta-estimator level, and the child + # still expects the original names for the metadata. + # If a metadata is requested but not passed, no error is raised + continue + elif alias == RequestType.ERROR_IF_PASSED: + if prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + + def get_method_input(self, ignore_extras=False, kwargs=None): + """Return the input parameters requested by the method. + + The output of this method can be used directly as the input to the + corresponding method as extra props. + + Parameters + ---------- + ignore_extras : bool, default=False + If ``True``, no error is raised if extra unknown args are passed. + + kwargs : dict + A dictionary of provided metadata. + + Returns + ------- + kwargs : dict + A dictionary of {prop: value} which can be given to the + corresponding method. + """ + kwargs = {} if kwargs is None else kwargs + args = {arg: value for arg, value in kwargs.items() if value is not None} + res = dict() + for prop, alias in self.requests.items(): + if not isinstance(alias, str): + alias = RequestType(alias) + + if alias == RequestType.UNREQUESTED: + continue + elif alias == RequestType.REQUESTED and prop in args: + res[prop] = args[prop] + elif alias == RequestType.ERROR_IF_PASSED and prop in args: + raise ValueError( + f"{prop} is passed but is not explicitly set as " + f"requested or not. In method: {self.name}" + ) + elif alias in args: + res[prop] = args[alias] + self.validate_metadata(ignore_extras=ignore_extras, kwargs=res) + return res + + def masked(self): + """Return a masked version of the requests. 
+ + Returns + ------- + masked : MethodMetadataRequest + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. + """ + res = MethodMetadataRequest(name=self.name) + for prop, alias in self.requests.items(): + if isinstance(alias, str): + res.add_request( + prop=alias, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + else: + res.add_request( + prop=prop, + alias=alias, + allow_aliasing=False, + overwrite=False, + ) + return res + + @classmethod + def from_dict( + cls, requests, name, allow_aliasing=True, default=RequestType.ERROR_IF_PASSED + ): + """Construct a MethodMetadataRequest from a given dictionary. + + Parameters + ---------- + requests : dict + A dictionary representing the requests. + + name : str + The name of the method to which these requests belong. + + allow_aliasing : bool, default=True + If false, only aliases with the same name as the parameter are + allowed. This is useful when handling the default values. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string + or list instead of the fully specifying dict. + + Returns + ------- + requests: MethodMetadataRequest + A :class:`MethodMetadataRequest` object. + """ + if requests is None: + requests = dict() + elif isinstance(requests, str): + requests = {requests: default} + elif isinstance(requests, (list, set)): + requests = {r: default for r in requests} + result = cls(name=name) + for prop, alias in requests.items(): + result.add_request(prop=prop, alias=alias, allow_aliasing=allow_aliasing) + return result + + def __repr__(self): + return str(self.requests) + + def __str__(self): + return str(self.requests) + + +class MetadataRequest: + """Contains the metadata request info of an object. + + .. versionadded:: 1.1 + + Parameters + ---------- + requests : dict of dict of {str: str}, default=None + A dictionary where the keys are the names of the methods, and the values are + a dictionary of the form ``{"required_metadata": "provided_metadata"}``. + ``"provided_metadata"`` can also be a ``RequestType`` or {True, False, None}. + + default : RequestType, True, False, None, or str, \ + default=RequestType.ERROR_IF_PASSED + The default value to be used if parameters are provided as a string instead of + the usual second layer dict. + """ + + def __init__(self, requests=None, default=RequestType.ERROR_IF_PASSED): + for method in METHODS: + setattr(self, method, MethodMetadataRequest(name=method)) + + if requests is None: + return + elif not isinstance(requests, dict): + raise ValueError( + "Can only construct an instance from a dict. Please call " + "metadata_request_factory for other types of input." + ) + + for method, method_requests in requests.items(): + if method not in METHODS: + raise ValueError(f"{method} is not supported as a method.") + setattr( + self, + method, + MethodMetadataRequest.from_dict( + method_requests, name=method, default=default + ), + ) + + def add_requests( + self, + obj, + mapping="one-to-one", + overwrite=False, + expected_metadata=None, + ): + """Add request info from the given object with the desired mapping. + + Parameters + ---------- + obj : object + An object from which a MetadataRequest can be constructed. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this object's + methods. 
If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. If a dict, the mapping is of + the form ``{"destination_method": "source_method"}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + expected_metadata : str, default=None + If provided, all props should be the same as this value. It used to + handle default values. + """ + if not isinstance(mapping, dict) and mapping != "one-to-one": + raise ValueError( + "mapping can only be a dict or the literal 'one-to-one'. " + f"Given value: {mapping}" + ) + if mapping == "one-to-one": + mapping = {method: method for method in METHODS} + other = metadata_request_factory(obj) + for destination, source in mapping.items(): + my_method = getattr(self, destination) + other_method = getattr(other, source) + my_method.merge_method_request( + other_method, + overwrite=overwrite, + expected_metadata=expected_metadata, + ) + + def masked(self): + """Return a masked version of the requests. + + A masked version is one which converts a ``{'prop': 'alias'}`` to + ``{'alias': True}``. This is desired in meta-estimators passing + requests to their parent estimators. + """ + res = MetadataRequest() + for method in METHODS: + setattr(res, method, getattr(self, method).masked()) + return res + + def to_dict(self): + """Return dictionary representation of this object.""" + output = dict() + for method in METHODS: + output[method] = getattr(self, method).requests + return output + + def __repr__(self): + return str(self.to_dict()) + + def __str__(self): + return str(self.to_dict()) + + +def metadata_request_factory(obj=None): + """Get a MetadataRequest instance from the given object. + + .. versionadded:: 1.1 + + Parameters + ---------- + obj : object + If the object is already a MetadataRequest, return that. + If the object is an estimator, try to call `get_metadata_request` and get + an instance from that method. + If the object is a dict, create a MetadataRequest from that. + + Returns + ------- + metadata_requests : MetadataRequest + A ``MetadataRequest`` taken or created from the given object. + """ + if obj is None: + return MetadataRequest() + + if isinstance(obj, MetadataRequest): + return obj + + if isinstance(obj, dict): + return MetadataRequest(obj) + + try: + return MetadataRequest(obj.get_metadata_request()) + except AttributeError: + # The object doesn't have a `get_metadata_request` method. + return MetadataRequest() + + +class MetadataRouter: + """Route the metadata to child objects. + + .. versionadded:: 1.1 + """ + + def __init__(self): + self.requests = MetadataRequest() + + def add(self, *obj, mapping="one-to-one", overwrite=False, mask=False): + """Add a set of requests to the existing ones. + + Parameters + ---------- + *obj : objects + A set of objects from which the requests are extracted. Passed as + arguments to this method. + + mapping : dict or str, default="one-to-one" + The mapping between the ``obj``'s methods and this routing object's + methods. If ``"one-to-one"`` all methods' requests from ``obj`` are + merged into this instance's methods. 
If a dict, the mapping is of + the form ``{"destination_method": "source_method"}``. + + overwrite : bool or str, default=False + + - True: ``alias`` replaces the existing routing. + + - False: a ``ValueError`` is raised if the given value conflicts + with an existing one. + + - "smart": overwrite in this order: + ``RequestType.REQUESTED`` over ``RequestType.UNREQUESTED`` over + ``RequestType.ERROR_IF_PASSED``, and error if existing value is + a string. + + - "ignore": ignore the requested metadata if it already exists. + + mask : bool, default=False + If the requested metadata should be masked by the alias. If + ``True``, then a request of the form + ``{'sample_weight' : 'my_weight'}`` is converted to + ``{'my_weight': 'my_weight'}``. This is required for meta-estimators + which should expose the requested parameters and not the ones + expected by the objects' methods. + """ + for x in obj: + if mask: + x = metadata_request_factory(x).masked() + self.requests.add_requests(x, mapping=mapping, overwrite=overwrite) + return self + + def get_metadata_request(self): + """Get requested data properties. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + return self.requests.to_dict() + + +class RequestMethod: + """ + A descriptor for request methods. + + .. versionadded:: 1.1 + + Parameters + ---------- + name : str + The name of the method for which the request function should be + created, e.g. ``"fit"`` would create a ``fit_requests`` function. + + keys : list of str + A list of strings which are accepted parameters by the created + function, e.g. ``["sample_weight"]`` if the corresponding method + accepts it as a metadata. + + Notes + ----- + This class is a descriptor [1]_ and uses PEP-362 to set the signature of + the returned function [2]_. + + References + ---------- + .. [1] https://docs.python.org/3/howto/descriptor.html + + .. [2] https://www.python.org/dev/peps/pep-0362/ + """ + + def __init__(self, name, keys): + self.name = name + self.keys = keys + + def __get__(self, instance, owner): + # we would want to have a method which accepts only the expected args + def func(**kw): + if set(kw) - set(self.keys): + raise TypeError(f"Unexpected args: {set(kw) - set(self.keys)}") + + requests = metadata_request_factory(instance) + + try: + method_metadata_request = getattr(requests, self.name) + except AttributeError: + raise ValueError(f"{self.name} is not a supported method.") + + for prop, alias in kw.items(): + if alias is not UNCHANGED: + method_metadata_request.add_request( + prop=prop, alias=alias, allow_aliasing=True, overwrite=True + ) + instance._metadata_request = requests.to_dict() + + return instance + + # Now we set the relevant attributes of the function so that it seems + # like a normal method to the end user, with known expected arguments. 
+
+            func.__name__ = f"{self.name}_requests"
+            params = [
+                inspect.Parameter(
+                    name="self",
+                    kind=inspect.Parameter.POSITIONAL_OR_KEYWORD,
+                    annotation=type(instance),
+                )
+            ]
+            params.extend(
+                [
+                    inspect.Parameter(
+                        k,
+                        inspect.Parameter.KEYWORD_ONLY,
+                        default=UNCHANGED,
+                        annotation=Optional[Union[RequestType, str]],
+                    )
+                    for k in self.keys
+                ]
+            )
+            func.__signature__ = inspect.Signature(
+                params,
+                return_annotation=type(instance),
+            )
+            doc = REQUESTER_DOC.format(method=self.name)
+            for metadata in self.keys:
+                doc += REQUESTER_DOC_PARAM.format(metadata=metadata, method=self.name)
+            doc += REQUESTER_DOC_RETURN
+            func.__doc__ = doc
+            return func
+
+
+class _MetadataRequester:
+    """Mixin class for adding metadata request functionality.
+
+    .. versionadded:: 1.1
+    """
+
+    def __init_subclass__(cls, **kwargs):
+        """Set the ``{method}_requests`` methods.
+
+        This uses PEP-487 [1]_ to set the ``{method}_requests`` methods. It
+        looks for the default values which are set using the
+        ``_metadata_request__*`` class attributes.
+
+        References
+        ----------
+        .. [1] https://www.python.org/dev/peps/pep-0487
+        """
+        try:
+            requests = cls._get_default_requests().to_dict()
+        except Exception:
+            # if there are any issues with the default values, they will be
+            # raised when ``get_metadata_request`` is called. Here we ignore
+            # all such issues, e.g. bad defaults.
+            super().__init_subclass__(**kwargs)
+            return
+
+        for request_method, request_keys in requests.items():
+            # set the ``{method}_requests`` methods
+            if not len(request_keys):
+                continue
+            setattr(
+                cls,
+                f"{request_method}_requests",
+                RequestMethod(request_method, sorted(request_keys)),
+            )
+        super().__init_subclass__(**kwargs)
+
+    @classmethod
+    def _get_default_requests(cls):
+        """Collect default request values.
+
+        This method combines the information present in the
+        ``_metadata_request__*`` class attributes.
+        """
+
+        requests = MetadataRequest()
+
+        # need to go through the MRO since this is a class attribute and
+        # ``vars`` doesn't report the parent class attributes. We go through
+        # the reverse of the MRO since cls is the first in the tuple and object
+        # is the last.
+        defaults = defaultdict()
+        for klass in reversed(inspect.getmro(cls)):
+            klass_defaults = {
+                attr: value
+                for attr, value in vars(klass).items()
+                if attr.startswith("_metadata_request__")
+            }
+            defaults.update(klass_defaults)
+        defaults = dict(sorted(defaults.items()))
+
+        # First take all arguments from the method signatures and have them as
+        # ERROR_IF_PASSED, except X, y, *args, and **kwargs.
+        for method in METHODS:
+            # Here we use `isfunction` instead of `ismethod` because calling `getattr`
+            # on a class instead of an instance returns an unbound function.
+            if not hasattr(cls, method) or not inspect.isfunction(getattr(cls, method)):
+                continue
+            # ignore the first parameter of the method, which is usually "self"
+            params = list(inspect.signature(getattr(cls, method)).parameters.items())[
+                1:
+            ]
+            for pname, param in params:
+                if pname in {"X", "y", "Y"}:
+                    continue
+                if param.kind in {param.VAR_POSITIONAL, param.VAR_KEYWORD}:
+                    continue
+                getattr(requests, method).add_request(
+                    prop=pname,
+                    alias=RequestType.ERROR_IF_PASSED,
+                    allow_aliasing=False,
+                    overwrite=False,
+                )
+
+        # Then overwrite those defaults with the ones provided in the
+        # _metadata_request__* attributes, which are then merged into `requests`.
+ + for attr, value in defaults.items(): + requests.add_requests( + value, overwrite=True, expected_metadata="__".join(attr.split("__")[1:]) + ) + return requests + + def get_metadata_request(self): + """Get requested data properties. + + Returns + ------- + request : dict + A dict of dict of str->value. The key to the first dict is the name + of the method, and the key to the second dict is the name of the + argument requested by the method. + """ + if hasattr(self, "_metadata_request"): + requests = metadata_request_factory(self._metadata_request) + else: + requests = self._get_default_requests() + + return requests.to_dict() diff --git a/sklearn/utils/tests/test_estimator_checks.py b/sklearn/utils/tests/test_estimator_checks.py index c4f954790cd26..50b70d85054b9 100644 --- a/sklearn/utils/tests/test_estimator_checks.py +++ b/sklearn/utils/tests/test_estimator_checks.py @@ -651,6 +651,10 @@ class NonConformantEstimatorNoParamSet(BaseEstimator): def __init__(self, you_should_set_this_=None): pass + class ConformantEstimatorClassAttribute(BaseEstimator): + # making sure our _metadata_request__* class attributes are okay! + _metadata_request__foo = {"fit": "foo"} + msg = ( "Estimator estimator_name should not set any" " attribute apart from parameters during init." @@ -670,6 +674,14 @@ def __init__(self, you_should_set_this_=None): "estimator_name", NonConformantEstimatorNoParamSet() ) + # a private class attribute is okay! + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute() + ) + check_no_attributes_set_in_init( + "estimator_name", ConformantEstimatorClassAttribute().fit_requests(foo=True) + ) + def test_check_estimator_pairwise(): # check that check_estimator() works on estimator with _pairwise diff --git a/sklearn/utils/validation.py b/sklearn/utils/validation.py index a2693a44a9f8b..648ee8f2a1602 100644 --- a/sklearn/utils/validation.py +++ b/sklearn/utils/validation.py @@ -1628,6 +1628,7 @@ def _check_fit_params(X, fit_params, indices=None): """ from . import _safe_indexing + fit_params = {} if fit_params is None else fit_params fit_params_validated = {} for param_key, param_value in fit_params.items(): if not _is_arraylike(param_value) or _num_samples(param_value) != _num_samples(
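For reference, the following is a minimal sketch (not part of the patch itself) of how the
pieces added in ``metadata_requests.py`` above interact, assuming the module is importable as
``sklearn.utils.metadata_requests``. The ``WeightedConsumer`` class and the weight values are
invented purely for illustration; the outputs shown are what the implementation above would
produce::

    >>> import numpy as np
    >>> from sklearn.utils.metadata_requests import (
    ...     RequestType,
    ...     metadata_request_factory,
    ...     _MetadataRequester,
    ... )
    >>> class WeightedConsumer(_MetadataRequester):
    ...     # ``__init_subclass__`` sees ``sample_weight`` in the ``fit``
    ...     # signature and generates a ``fit_requests`` method for this class.
    ...     def fit(self, X, y, sample_weight=None):
    ...         return self
    >>> est = WeightedConsumer()
    >>> # by default the metadata is ERROR_IF_PASSED: a router would raise if
    >>> # sample_weight were provided without an explicit request
    >>> est.get_metadata_request()["fit"]
    {'sample_weight': <RequestType.ERROR_IF_PASSED: None>}
    >>> est = est.fit_requests(sample_weight=True)
    >>> est.get_metadata_request()["fit"]
    {'sample_weight': <RequestType.REQUESTED: True>}
    >>> # a router (e.g. a meta-estimator) can then extract the kwargs it
    >>> # should forward to ``fit``:
    >>> metadata_request_factory(est).fit.get_method_input(
    ...     kwargs={"sample_weight": np.array([0.5, 2.0])}
    ... )
    {'sample_weight': array([0.5, 2. ])}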